diff options
author | Allan Sandfeld Jensen <allan.jensen@qt.io> | 2020-07-16 11:45:35 +0200 |
---|---|---|
committer | Allan Sandfeld Jensen <allan.jensen@qt.io> | 2020-07-17 08:59:23 +0000 |
commit | 552906b0f222c5d5dd11b9fd73829d510980461a (patch) | |
tree | 3a11e6ed0538a81dd83b20cf3a4783e297f26d91 /chromium/third_party/openscreen | |
parent | 1b05827804eaf047779b597718c03e7d38344261 (diff) | |
download | qtwebengine-chromium-552906b0f222c5d5dd11b9fd73829d510980461a.tar.gz |
BASELINE: Update Chromium to 83.0.4103.122
Change-Id: Ie3a82f5bb0076eec2a7c6a6162326b4301ee291e
Reviewed-by: Michael Brüning <michael.bruning@qt.io>
Diffstat (limited to 'chromium/third_party/openscreen')
579 files changed, 31941 insertions, 7057 deletions
diff --git a/chromium/third_party/openscreen/src/BUILD.gn b/chromium/third_party/openscreen/src/BUILD.gn index 45c89852612..882422618f7 100644 --- a/chromium/third_party/openscreen/src/BUILD.gn +++ b/chromium/third_party/openscreen/src/BUILD.gn @@ -5,31 +5,43 @@ import("//build_overrides/build.gni") import("osp/build/config/services.gni") +declare_args() { + # Set to true to force building the standalone receiver on Mac. It's currently + # disabled due to build bot struggles, but works fine on local, recent clang + # installations. + # TODO(crbug.com/openscreen/86): Remove when the Mac bots have been upgraded. + force_build_standalone_receiver = false +} + # All compilable non-test targets in the repository (both executables and # source_sets). group("gn_all") { + testonly = true + deps = [ "cast/common:certificate", "cast/common:channel", + "cast/common:public", + "cast/receiver:channel", "cast/sender:channel", "cast/streaming:receiver", "cast/streaming:sender", + "discovery:common", "discovery:dnssd", "discovery:mdns", + "discovery:public", "osp", "osp/msgs", "platform", "third_party/abseil", "third_party/boringssl", "third_party/jsoncpp", + "third_party/mozilla", "third_party/tinycbor", + "tools/cddl($host_toolchain)", "util", ] - if (current_toolchain == host_toolchain) { - deps += [ "tools/cddl" ] - } - if (use_mdns_responder) { deps += [ "osp/impl/discovery/mdns:mdns_demo" ] } @@ -37,8 +49,8 @@ group("gn_all") { if (use_chromium_quic) { deps += [ "third_party/chromium_quic", - "third_party/chromium_quic:quic_demo_client", "third_party/chromium_quic:quic_demo_server", + "third_party/chromium_quic:quic_streaming_playback_controller", ] } @@ -48,14 +60,21 @@ group("gn_all") { if (!build_with_chromium) { deps += [ - "third_party/protobuf:protoc", + "third_party/protobuf:protoc($host_toolchain)", "third_party/zlib", ] + if (is_posix) { + deps += [ "cast/test:make_crl_tests($host_toolchain)" ] + } + # TODO(crbug.com/openscreen/86): Build for Mac too once the mac buildbot # compiler is upgraded. - if (!is_mac) { - deps += [ "cast/standalone_receiver:cast_receiver" ] + if (!is_mac || force_build_standalone_receiver) { + deps += [ + "cast/standalone_receiver:cast_receiver", + "cast/standalone_sender:cast_sender", + ] } } } @@ -64,8 +83,10 @@ source_set("openscreen_unittests_all") { testonly = true public_deps = [ "cast/common:unittests", + "cast/receiver:unittests", "cast/sender:unittests", "cast/streaming:unittests", + "cast/test:unittests", "discovery:unittests", "osp:unittests", "osp/msgs:unittests", @@ -93,3 +114,21 @@ if (!build_with_chromium) { ] } } + +if (!build_with_chromium && is_posix) { + source_set("e2e_tests_all") { + testonly = true + public_deps = [ + "cast/common:discovery_e2e_test", + "cast/test:e2e_tests", + "third_party/googletest:gtest_main", + ] + } + + executable("e2e_tests") { + testonly = true + deps = [ + ":e2e_tests_all", + ] + } +} diff --git a/chromium/third_party/openscreen/src/COMMITTERS b/chromium/third_party/openscreen/src/COMMITTERS index 48b2a70ffb9..770a6ccd23f 100644 --- a/chromium/third_party/openscreen/src/COMMITTERS +++ b/chromium/third_party/openscreen/src/COMMITTERS @@ -5,6 +5,5 @@ btolsch@chromium.org # Additional reviewers jopbha@chromium.org miu@chromium.org -pthatcher@chromium.org rwkeane@chromium.org yakimakha@chromium.org diff --git a/chromium/third_party/openscreen/src/DEPS b/chromium/third_party/openscreen/src/DEPS index 9955ada4180..e4a34906c7e 100644 --- a/chromium/third_party/openscreen/src/DEPS +++ b/chromium/third_party/openscreen/src/DEPS @@ -8,140 +8,188 @@ # to list the dependency's destination directory. use_relative_paths = True +use_relative_hooks = True vars = { - 'boringssl_git': 'https://boringssl.googlesource.com', - 'chromium_git': 'https://chromium.googlesource.com', + 'boringssl_git': 'https://boringssl.googlesource.com', + 'chromium_git': 'https://chromium.googlesource.com', - # TODO(jophba): move to googlesource external for github repos. - 'github': 'https://github.com', + # TODO(jophba): move to googlesource external for github repos. + 'github': 'https://github.com', - # NOTE: Strangely enough, this will be overridden by any _parent_ DEPS, so - # in Chromium it will correctly be True. - 'build_with_chromium': False, + # NOTE: Strangely enough, this will be overridden by any _parent_ DEPS, so + # in Chromium it will correctly be True. + 'build_with_chromium': False, - 'gn_version': 'git_revision:0790d3043387c762a6bacb1ae0a9ebe883188ab2', - 'checkout_chromium_quic_boringssl': False, + 'checkout_chromium_quic_boringssl': False, - # By default, do not check out openscreen/cast. This can be overridden - # by custom_vars in .gclient. - 'checkout_openscreen_cast_internal': False + # Needed to download additional clang binaries for processing coverage data + # (from binaries with GN arg `use_coverage=true`). + 'checkout_clang_coverage_tools': False, } deps = { - 'cast/internal': { - 'url': 'https://chrome-internal.googlesource.com/openscreen/cast.git' + - '@' + '703984f9d1674c2cfc259904a5a7fba4990cca4b', - 'condition': 'checkout_openscreen_cast_internal', - }, - - 'buildtools': { - 'url': Var('chromium_git')+ '/chromium/src/buildtools' + - '@' + '140e4d7c45ffb55ce5dc4d11a0c3938363cd8257', - 'condition': 'not build_with_chromium', - }, - - 'third_party/protobuf/src': { - 'url': Var('chromium_git') + - '/external/github.com/protocolbuffers/protobuf.git' + - '@' + 'd09d649aea36f02c03f8396ba39a8d4db8a607e4', # version 3.10.1 - 'condition': 'not build_with_chromium', - }, - - 'third_party/zlib/src': { - 'url': Var('github') + - '/madler/zlib.git' + - '@' + 'cacf7f1d4e3d44d871b605da3b647f07d718623f', # version 1.2.11 - 'condition': 'not build_with_chromium', - }, - - 'third_party/jsoncpp/src': { - 'url': Var('chromium_git') + - '/external/github.com/open-source-parsers/jsoncpp.git' + - '@' + '2eb20a938c454411c1d416caeeb2a6511daab5cb', # version 1.9.0 - 'condition': 'not build_with_chromium', - }, - - 'third_party/googletest/src': { - 'url': Var('chromium_git') + - '/external/github.com/google/googletest.git' + - '@' + '8697709e0308af4cd5b09dc108480804e5447cf0', - 'condition': 'not build_with_chromium', - }, - - 'third_party/mDNSResponder/src': { - 'url': Var('github') + '/jevinskie/mDNSResponder.git' + - '@' + '2942dde61f920fbbf96ff9a3840567ebbe7cb1b6', - 'condition': 'not build_with_chromium', - }, - - 'third_party/boringssl/src': { - 'url' : Var('boringssl_git') + '/boringssl.git' + - '@' + '6410e18e9190b6b0c71955119fbf3cae1b9eedb7', - 'condition': 'not build_with_chromium', - }, - - 'third_party/chromium_quic/src': { - 'url': Var('chromium_git') + '/openscreen/quic.git' + - '@' + 'b73bd98ac9eaedf01a732b1933f97112cf247d93', - 'condition': 'not build_with_chromium', - }, - - 'third_party/tinycbor/src': - Var('chromium_git') + '/external/github.com/intel/tinycbor.git' + - '@' + '755f9ef932f9830a63a712fd2ac971d838b131f1', - - 'third_party/abseil/src': { - 'url': Var('chromium_git') + - '/external/github.com/abseil/abseil-cpp.git' + - '@' + '20de2db748ca0471cfb61cb53e813dd12938c12b', - 'condition': 'not build_with_chromium', - }, + # NOTE: This commit hash here references a repository/branch that is a mirror + # of the commits to the buildtools directory in the Chromium repository. This + # should be regularly updated with the tip of the MIRRORED master branch, + # found here: + # https://chromium.googlesource.com/chromium/src/buildtools/+/refs/heads/master. + 'buildtools': { + 'url': Var('chromium_git')+ '/chromium/src/buildtools' + + '@' + '8d2132841536523249669813b928e29144d487f9', + 'condition': 'not build_with_chromium', + }, + + 'third_party/protobuf/src': { + 'url': Var('chromium_git') + + '/external/github.com/protocolbuffers/protobuf.git' + + '@' + 'd09d649aea36f02c03f8396ba39a8d4db8a607e4', # version 3.10.1 + 'condition': 'not build_with_chromium', + }, + + 'third_party/zlib/src': { + 'url': Var('github') + + '/madler/zlib.git' + + '@' + 'cacf7f1d4e3d44d871b605da3b647f07d718623f', # version 1.2.11 + 'condition': 'not build_with_chromium', + }, + + 'third_party/jsoncpp/src': { + 'url': Var('chromium_git') + + '/external/github.com/open-source-parsers/jsoncpp.git' + + '@' + '2eb20a938c454411c1d416caeeb2a6511daab5cb', # version 1.9.0 + 'condition': 'not build_with_chromium', + }, + + 'third_party/googletest/src': { + 'url': Var('chromium_git') + + '/external/github.com/google/googletest.git' + + '@' + '8697709e0308af4cd5b09dc108480804e5447cf0', + 'condition': 'not build_with_chromium', + }, + + 'third_party/mDNSResponder/src': { + 'url': Var('github') + '/jevinskie/mDNSResponder.git' + + '@' + '2942dde61f920fbbf96ff9a3840567ebbe7cb1b6', + 'condition': 'not build_with_chromium', + }, + + 'third_party/boringssl/src': { + 'url' : Var('boringssl_git') + '/boringssl.git' + + '@' + '6410e18e9190b6b0c71955119fbf3cae1b9eedb7', + 'condition': 'not build_with_chromium', + }, + + 'third_party/chromium_quic/src': { + 'url': Var('chromium_git') + '/openscreen/quic.git' + + '@' + 'd2363edc9f2a8561c9d02e836262f2d03de2d6e1', + 'condition': 'not build_with_chromium', + }, + + 'third_party/tinycbor/src': + Var('chromium_git') + '/external/github.com/intel/tinycbor.git' + + '@' + '755f9ef932f9830a63a712fd2ac971d838b131f1', + + 'third_party/abseil/src': { + 'url': Var('chromium_git') + + '/external/github.com/abseil/abseil-cpp.git' + + '@' + '20de2db748ca0471cfb61cb53e813dd12938c12b', + 'condition': 'not build_with_chromium', + }, + 'third_party/libfuzzer/src': { + 'url': Var('chromium_git') + + '/chromium/llvm-project/compiler-rt/lib/fuzzer.git' + + '@' + 'debe7d2d1982e540fbd6bd78604bf001753f9e74', + 'condition': 'not build_with_chromium', + }, } +hooks = [ + { + 'name': 'clang_update_script', + 'pattern': '.', + 'condition': 'not build_with_chromium', + 'action': [ 'python', 'tools/download-clang-update-script.py', + '--output', 'tools/clang/scripts/update.py' ], + # NOTE: This file appears in .gitignore, as it is not a part of the + # openscreen repo. + }, + { + 'name': 'update_clang', + 'pattern': '.', + 'condition': 'not build_with_chromium', + 'action': [ 'python', 'tools/clang/scripts/update.py' ], + }, + { + 'name': 'clang_coverage_tools', + 'pattern': '.', + 'condition': 'not build_with_chromium and checkout_clang_coverage_tools', + 'action': ['python', 'tools/clang/scripts/update.py', + '--package=coverage_tools'], + }, + { + 'name': 'clang_format_linux64', + 'pattern': '.', + 'action': [ 'download_from_google_storage.py', '--no_resume', '--no_auth', + '--bucket', 'chromium-clang-format', + '-s', 'buildtools/linux64/clang-format.sha1' ], + 'condition': 'host_os == "linux" and not build_with_chromium', + }, + { + 'name': 'clang_format_mac', + 'pattern': '.', + 'action': [ 'download_from_google_storage.py', '--no_resume', '--no_auth', + '--bucket', 'chromium-clang-format', + '-s', 'buildtools/mac/clang-format.sha1' ], + 'condition': 'host_os == "mac" and not build_with_chromium', + }, +] + recursedeps = [ - 'third_party/chromium_quic/src', - 'buildtools', + 'third_party/chromium_quic/src', + 'buildtools', ] include_rules = [ - '+build/config/features.h', - '+util', - '+platform/api', - '+platform/base', - '+platform/test', - '+third_party', - - # Don't include abseil from the root so the path can change via include_dirs - # rules when in Chromium. - '-third_party/abseil', - - # Abseil whitelist. - '+absl/algorithm/container.h', - '+absl/base/thread_annotations.h', - '+absl/hash/hash.h', - '+absl/strings/ascii.h', - '+absl/strings/match.h', - '+absl/strings/numbers.h', - '+absl/strings/str_cat.h', - '+absl/strings/str_join.h', - '+absl/strings/str_split.h', - '+absl/strings/string_view.h', - '+absl/strings/substitute.h', - '+absl/types/optional.h', - '+absl/types/span.h', - '+absl/types/variant.h', - - # Similar to abseil, don't include boringssl using root path. Instead, - # explicitly allow 'openssl' where needed. - '-third_party/boringssl', - - # Test framework includes. - "-third_party/googletest", - "+gtest", - "+gmock", + '+build/config/features.h', + '+util', + '+platform/api', + '+platform/base', + '+platform/test', + '+testing/util', + '+third_party', + + # Don't include abseil from the root so the path can change via include_dirs + # rules when in Chromium. + '-third_party/abseil', + + # Abseil whitelist. + '+absl/algorithm/container.h', + '+absl/base/thread_annotations.h', + '+absl/hash/hash.h', + '+absl/strings/ascii.h', + '+absl/strings/match.h', + '+absl/strings/numbers.h', + '+absl/strings/str_cat.h', + '+absl/strings/str_join.h', + '+absl/strings/str_replace.h', + '+absl/strings/str_split.h', + '+absl/strings/string_view.h', + '+absl/strings/substitute.h', + '+absl/types/optional.h', + '+absl/types/span.h', + '+absl/types/variant.h', + + # Similar to abseil, don't include boringssl using root path. Instead, + # explicitly allow 'openssl' where needed. + '-third_party/boringssl', + + # Test framework includes. + "-third_party/googletest", + "+gtest", + "+gmock", ] skip_child_includes = [ - 'third_party/chromium_quic', + 'third_party/chromium_quic', ] diff --git a/chromium/third_party/openscreen/src/PRESUBMIT.sh b/chromium/third_party/openscreen/src/PRESUBMIT.sh index 8181f11550f..6c4e9009f74 100755 --- a/chromium/third_party/openscreen/src/PRESUBMIT.sh +++ b/chromium/third_party/openscreen/src/PRESUBMIT.sh @@ -59,11 +59,6 @@ if [[ "$invoker" != 'python' ]]; then echo "This shouldn't be invoked directly, please use \`git cl presubmit\`." fi -# TODO(jophba): check in a better fix for the build bots. -if command -v clang-format &> /dev/null; then - tools/install-build-tools.sh &> /dev/null -fi - for f in $(git diff --name-only --diff-filter=d origin/master); do # Skip third party files, except our custom BUILD.gns if [[ $f =~ third_party/[^\/]*/src ]]; then diff --git a/chromium/third_party/openscreen/src/README.md b/chromium/third_party/openscreen/src/README.md index 17fb84378c8..a03d6d419ab 100644 --- a/chromium/third_party/openscreen/src/README.md +++ b/chromium/third_party/openscreen/src/README.md @@ -26,22 +26,28 @@ lint` and `git cl upload.` ## Checking out code -From the parent directory of where you want the openscreen checkout, configure -`gclient` and check out openscreen with the following commands: +From the parent directory of where you want the openscreen checkout (e.g., +`~/my_project_dir`), configure `gclient` and check out openscreen with the +following commands: ```bash + cd ~/my_project_dir gclient config https://chromium.googlesource.com/openscreen gclient sync ``` -Now, you should have `openscreen/` repository checked out, with all dependencies -checked out to their appropriate revisions. +The first `gclient` command will create a default .gclient file in +`~/my_project_dir` that describes how to pull down the `openscreen` repository. +The second command creates an `openscreen/` subdirectory, downloads the source +code, all third-party dependencies, and the toolchain needed to build things; +and at their appropriate revisions. ## Syncing your local checkout To update your local checkout from the openscreen master repository, just run ```bash + cd ~/my_project_dir/openscreen git pull gclient sync ``` @@ -51,24 +57,23 @@ dependencies that have changed. # Build setup -## Installing build dependencies - -The following tools are required for building: +The following are the main tools are required for development/builds: - Build file generator: `gn` - - Code formatter (optional): `clang-format` + - Code formatter: `clang-format` - Builder: `ninja` ([GitHub releases](https://github.com/ninja-build/ninja/releases)) + - Compiler/Linker: `clang` (installed by default) or `gcc` (installed by you) -`clang-format` and `ninja` can be downloaded to `buildtools/<platform>` root by -running `./tools/install-build-tools.sh`. - -`clang-format` is only used for presubmit checks and optionally used on -generated code from the CDDL tool. +All of these--except `gcc` as noted above--are automatically downloaded/updated +for the Linux and Mac environments via `gclient sync` as described above. The +first two are installed into `buildtools/<platform>/`. -`gn` will be installed in `buildtools/<platform>/` automatically by `gclient sync`. +Mac only: XCode must be installed on the system, to link against its frameworks. -You also need to ensure that you have the compiler and its toolchain dependencies. -Currently, both Linux and Mac OS X build configurations use clang by default. +`clang-format` is used for maintaining consistent coding style, but it is not a +complete replacement for adhering to Chromium/Google C++ style (that's on you!). +The presubmit script will sanity-check that it has been run on all new/changed +code. ## Linux clang @@ -153,6 +158,11 @@ the working directory for the build. So the same could be done as follows: After editing a file, only `ninja` needs to be rerun, not `gn`. If you have edited a `BUILD.gn` file, `ninja` will re-run `gn` for you. +Unless you like to wait longer than necessary for builds to complete, run +`autoninja` instead of `ninja`, which takes the same command-line arguments. +This will automatically parallelize the build for your system, depending on +number of processor cores, RAM, etc. + For details on running `demo`, see its [README.md](demo/README.md). ## Building other targets @@ -174,6 +184,25 @@ the build flags available. ./out/Default/unittests ``` +## Building and running fuzzers + +In order to build fuzzers, you need the GN arg `use_libfuzzer=true`. It's also +recommended to build with `is_asan=true` to catch additional problems. Building +and running then might look like: +```bash + gn gen out/libfuzzer --args="use_libfuzzer=true is_asan=true is_debug=false" + ninja -C out/libfuzzer some_fuzz_target + out/libfuzzer/some_fuzz_target <args> <corpus_dir> [additional corpus dirs] +``` + +The arguments to the fuzzer binary should be whatever is listed in the GN target +description (e.g. `-max_len=1500`). These arguments may be automatically +scraped by Chromium's ClusterFuzz tool when it runs fuzzers, but they are not +built into the target. You can also look at the file +`out/libfuzzer/some_fuzz_target.options` for what arguments should be used. The +`corpus_dir` is listed as `seed_corpus` in the GN definition of the fuzzer +target. + # Continuous build and try jobs openscreen uses [LUCI builders](https://ci.chromium.org/p/openscreen/builders) @@ -213,11 +242,13 @@ review tool) and is recommended for pushing patches for review. Once you have committed changes locally, simply run: ```bash + git cl format git cl upload ``` -This will run our `PRESUBMIT.sh` script to check style, and if it passes, a new -code review will be posted on `chromium-review.googlesource.com`. +The first command will will auto-format the code changes. Then, the second +command runs the `PRESUBMIT.sh` script to check style and, if it passes, a +newcode review will be posted on `chromium-review.googlesource.com`. If you make additional commits to your local branch, then running `git cl upload` again in the same branch will merge those commits into the ongoing @@ -256,3 +287,76 @@ After your patch has received one or more LGTM commit it by clicking the `SUBMIT` button (or, confusingly, `COMMIT QUEUE +2`) in Gerrit. This will run your patch through the builders again before committing to the main openscreen repository. + +<!-- TODO(mfoltz): split up README.md into more manageable files. --> +## Working with ARM/ARM64/the Raspberry PI + +openscreen supports cross compilation for both arm32 and arm64 platforms, by +using the `gn args` parameter `target_cpu="arm"` or `target_cpu="arm64"` +respectively. Note that quotes are required around the target arch value. + +Setting an arm(64) target_cpu causes GN to pull down a sysroot from openscreen's +public cloud storage bucket. Google employees may update the sysroots stored +by requesting access to the Open Screen pantheon project and uploading a new +tar.xz to the openscreen-sysroots bucket. + +NOTE: The "arm" image is taken from Chromium's debian arm image, however it has +been manually patched to include support for libavcodec and libsdl2. To update +this image, the new image must be manually patched to include the necessary +header and library dependencies. Note that if the versions of libavcodec and +libsdl2 are too out of sync from the copies in the sysroot, compilation will +succeed, but you may experience issues decoding content. + +To install the last known good version of the libavcodec and libsdl packages +on a Raspberry Pi, you can run the following command: + +```bash +sudo ./cast/standalone_receiver/install_demo_deps_raspian.sh +``` + +NOTE: until [Issue 106](http://crbug.com/openscreen/106) is resolved, you may +experience issues streaming to a Raspberry Pi if multiple network interfaces +(e.g. WiFi + Ethernet) are enabled. The workaround is to disable either the WiFi +or ethernet connection. + +## Code Coverage + +Code coverage can be checked using clang's source-based coverage tools. You +must use the GN argument `use_coverage=true`. It's recommended to do this in a +separate output directory since the added instrumentation will affect +performance and generate an output file every time a binary is run. You can +read more about this in [clang's +documentation](http://clang.llvm.org/docs/SourceBasedCodeCoverage.html) but the +bare minimum steps are also outlined below. You will also need to download the +pre-built clang coverage tools, which are not downloaded by default. The +easiest way to do this is to set a custom variable in your `.gclient` file. +Under the "openscreen" solution, add: +```python + "custom_vars": { + "checkout_clang_coverage_tools": True, + }, +``` +then run `gclient runhooks`. You can also run the python command from the +`clang_coverage_tools` hook in `//DEPS` yourself or even download the tools +manually +([link](https://storage.googleapis.com/chromium-browser-clang-staging/)). + +Once you have your GN directory (we'll call it `out/coverage`) and have +downloaded the tools, do the following to generate an HTML coverage report: +```bash +out/coverage/openscreen_unittests +third_party/llvm-build/Release+Asserts/bin/llvm-profdata merge -sparse default.profraw -o foo.profdata +third_party/llvm-build/Release+Asserts/bin/llvm-cov show out/coverage/openscreen_unittests -instr-profile=foo.profdata -format=html -output-dir=<out dir> [filter paths] +``` +There are a few things to note here: + - `default.profraw` is generated by running the instrumented code, but + `foo.profdata` can be any path you want. + - `<out dir>` should be an empty directory for placing the generated HTML + files. You can view the report at `<out dir>/index.html`. + - `[filter paths]` is a list of paths to which you want to limit the coverage + report. For example, you may want to limit it to cast/ or even + cast/streaming/. If this list is empty, all data will be in the report. + +The same process can be used to check the coverage of a fuzzer's corpus. Just +add `-runs=0` to the fuzzer arguments to make sure it only runs the existing +corpus then exits. diff --git a/chromium/third_party/openscreen/src/build/config/BUILD.gn b/chromium/third_party/openscreen/src/build/config/BUILD.gn index 27f29ed5dd2..ec988061c61 100644 --- a/chromium/third_party/openscreen/src/build/config/BUILD.gn +++ b/chromium/third_party/openscreen/src/build/config/BUILD.gn @@ -2,6 +2,9 @@ # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. +import("//build/config/arm.gni") +import("//build/config/sysroot.gni") + declare_args() { # Enable OSP_DCHECKs to be compiled in, even if it's not a debug build. dcheck_always_on = false @@ -11,8 +14,16 @@ declare_args() { # Enable thread sanitizer. is_tsan = false + + # Must be enabled for fuzzing targets. + use_libfuzzer = false + + # Enables clang's source-based coverage (requires is_clang=true). + use_coverage = false } +assert(!use_coverage || is_clang) + config("compiler_defaults") { cflags = [] if (is_posix && !is_mac) { @@ -35,6 +46,31 @@ config("compiler_defaults") { } } +config("compiler_cpu_abi") { + cflags = [] + ldflags = [] + + if (current_cpu == "x64") { + # These are explicitly specified in case of cross-compiling. + cflags += [ "-m64" ] + ldflags += [ "-m64" ] + } else if (current_cpu == "x86") { + cflags += [ "-m32" ] + ldflags += [ "-m32" ] + } else if (current_cpu == "arm") { + cflags += [ + "--target=arm-linux-gnueabihf", + "-march=$arm_arch", + "-mfloat-abi=$arm_float_abi", + "-mtune=$arm_tune", + ] + ldflags += [ "--target=arm-linux-gnueabihf" ] + } else if (current_cpu == "arm64") { + cflags += [ "--target=aarch64-linux-gnu" ] + ldflags += [ "--target=aarch64-linux-gnu" ] + } +} + config("no_exceptions") { # -fno-exceptions causes the compiler to choose the implementation of the STL # that uses abort() calls instead of throws, as well as issue compile errors @@ -166,4 +202,54 @@ config("default_sanitizers") { cflags += [ "-fsanitize=thread" ] ldflags += [ "-fsanitize=thread" ] } + + if (use_libfuzzer) { + cflags += [ "-fsanitize=fuzzer-no-link" ] + if (!is_asan) { + ldflags += [ "-fsanitize=address" ] + } + } +} + +config("default_coverage") { + cflags = [] + ldflags = [] + + if (use_coverage) { + cflags += [ + "-fprofile-instr-generate", + "-fcoverage-mapping", + ] + ldflags += [ "-fprofile-instr-generate" ] + } +} + +config("sysroot_runtime_libraries") { + if (sysroot != "") { + # As other sysroot CPU targets get added, they should be checked here. + assert(current_cpu == "arm" || current_cpu == "arm64") + assert(is_clang) + sysroot_path = rebase_path(sysroot, root_build_dir) + flags = [ "--sysroot=" + sysroot_path ] + hash = exec_script("//build/scripts/install-sysroot.py", + [ + "--print-hash", + "$current_cpu", + ], + "trim string", + [ "//build/scripts/sysroots.json" ]) + + # GN uses this to know that the sysroot is "dirty" + defines = [ "SYSROOT_HASH=$hash" ] + ldflags = flags + cflags = flags + + ld_paths = exec_script("//build/scripts/sysroot_ld_path.py", + [ sysroot_path ], + "list lines") + foreach(ld_path, ld_paths) { + ld_path = rebase_path(ld_path, root_build_dir) + ldflags += [ "-L" + ld_path ] + } + } } diff --git a/chromium/third_party/openscreen/src/build/config/BUILDCONFIG.gn b/chromium/third_party/openscreen/src/build/config/BUILDCONFIG.gn index 359ae15c6fb..4e69143f537 100644 --- a/chromium/third_party/openscreen/src/build/config/BUILDCONFIG.gn +++ b/chromium/third_party/openscreen/src/build/config/BUILDCONFIG.gn @@ -68,33 +68,54 @@ declare_args() { # gcc compiler on Linux instead, set is_gcc to true. is_gcc = false clang_base_path = default_clang_base_path + + # This would not normally be set as a build argument, but rather is used as a + # default value during the first parse of this config. All other toolchains + # that cause this file to be re-parsed will already have this set. For + # further explanation, see + # https://gn.googlesource.com/gn/+/refs/heads/master/docs/reference.md#toolchain-overview + host_toolchain = "" } declare_args() { is_clang = !is_gcc } -# We need to ensure that clang is pulled down using the update script. In -# Chromium, this is done with a gclient hook, but we can just call -if (is_clang) { - exec_script("//tools/clang/scripts/update.py") -} - # ============================================================================== # TOOLCHAIN SETUP # ============================================================================== # -# Here we set the default toolchain. Currently only Mac and POSIX are defined. -host_toolchain = "" -if (current_os == "chromeos" || current_os == "linux") { - host_toolchain = "//build/toolchain/linux:linux" -} else if (current_os == "mac") { - host_toolchain = "//build/toolchain/mac:clang" +# Here we set the host and default toolchains. Currently only Mac and POSIX are +# defined. +if (host_toolchain == "") { + if (current_os == "chromeos" || current_os == "linux") { + if (is_clang) { + host_toolchain = "//build/toolchain/linux:clang_$host_cpu" + } else { + host_toolchain = "//build/toolchain/linux:gcc_$host_cpu" + } + } else if (current_os == "mac") { + host_toolchain = "//build/toolchain/mac:clang" + } else { + # TODO(miu): Windows, and others. + assert(false, "Toolchain for current_os is not defined.") + } +} + +_default_toolchain = "" +if (target_os == "chromeos" || target_os == "linux") { + if (is_clang) { + _default_toolchain = "//build/toolchain/linux:clang_$target_cpu" + } else { + _default_toolchain = "//build/toolchain/linux:gcc_$target_cpu" + } +} else if (target_os == "mac") { + assert(host_os == "mac", "Cross-compiling on Mac is not supported.") + _default_toolchain = "//build/toolchain/mac:clang" } else { - # TODO(miu): Windows, and others. - assert(false, "Toolchain for current_os is not defined.") + assert(false, "Toolchain for target_os is not defined.") } -set_default_toolchain(host_toolchain) +set_default_toolchain(_default_toolchain) # ============================================================================= # OS DEFINITIONS @@ -132,8 +153,11 @@ _shared_binary_target_configs = [ "//build/config:no_rtti", "//build/config:symbol_visibility_hidden", "//build/config:default_sanitizers", + "//build/config:default_coverage", "//build/config:compiler_defaults", + "//build/config:compiler_cpu_abi", "//build/config:default_optimization", + "//build/config:sysroot_runtime_libraries", ] # Apply that default list to the binary target types. diff --git a/chromium/third_party/openscreen/src/build/config/arm.gni b/chromium/third_party/openscreen/src/build/config/arm.gni new file mode 100644 index 00000000000..e52b20d14bd --- /dev/null +++ b/chromium/third_party/openscreen/src/build/config/arm.gni @@ -0,0 +1,45 @@ +# Copyright 2019 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +declare_args() { + # Version of the ARM processor when compiling on ARM. Ignored on non-ARM + # platforms. + arm_version = 7 + + # The ARM architecture. This will be a string like "armv6" or "armv7-a". + # An empty string means to use the default for the arm_version. Getting + # a proper list of supported architectures is challenging for clang, but + # can be found in the triple.h header in the LLVM source, under the + # SubArchType enum. + arm_arch = "armv7-a" + + # The ARM floating point hardware. This will be a string like "neon" or + # "vfpv3". + arm_fpu = "vfpv3-d16" + + # The ARM floating point mode. This is either the string "hard", "soft", or + # "softfp". + arm_float_abi = "hard" + + # The ARM variant-specific tuning mode. This will be a string like "armv6" + # or "cortex-a15". Each cpu-type has a different tuning value. + arm_tune = "generic-armv7-a" +} + +declare_args() { + # Whether to use the neon FPU instruction set or not. Actual value is set + # below, based on the arm_fpu argument. + arm_use_neon = arm_fpu == "neon" +} + +if (current_cpu == "arm64") { + # arm64 supports only "hard". + arm_float_abi = "hard" + arm_fpu = "neon" + arm_use_neon = true +} + +assert(arm_float_abi == "hard" || arm_float_abi == "soft" || + arm_float_abi == "softfp") +assert(arm_version == 6 || arm_version == 7 || arm_version == 8) diff --git a/chromium/third_party/openscreen/src/build/config/external_libraries.gni b/chromium/third_party/openscreen/src/build/config/external_libraries.gni new file mode 100644 index 00000000000..a451add7e19 --- /dev/null +++ b/chromium/third_party/openscreen/src/build/config/external_libraries.gni @@ -0,0 +1,48 @@ +# Copyright 2019 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +declare_args() { + # FFMPEG: If installed on the system, set have_ffmpeg to true. This also + # requires the FFMPEG headers be installed. On Debian-like systems, this can + # be done by running `cast/standalone_receiver/install_demo_deps_debian.sh` + # to install both FFMPEG and libSDL. + have_ffmpeg = false + ffmpeg_libs = [ + "avcodec", + "avformat", + "avutil", + "swresample", + ] + ffmpeg_include_dirs = [] # Add only if headers are at non-standard locations. + ffmpeg_lib_dirs = [] # Add only if libraries are at non-standard locations. + + # libopus: If installed on the system, set have_libopus to true. This also + # requires the libopus headers be installed. For example, on Debian-like + # systems, the following should install everything needed: + # + # sudo apt-get install libopus0 libopus-dev + have_libopus = false + libopus_libs = [ "opus" ] + libopus_include_dirs = [] # Add only if headers are at non-standard locations. + libopus_lib_dirs = [] # Add only if libraries are at non-standard locations. + + # libsdl2: If installed on the system, set have_libsdl2 to true. This also + # requires the libSDL2 headers be installed. On Debian-like systems, this can + # be done by running `cast/standalone_receiver/install_demo_deps_debian.sh` + # to install both FFMPEG and libSDL. + have_libsdl2 = false + libsdl2_libs = [ "SDL2" ] + libsdl2_include_dirs = [] # Add only if headers are at non-standard locations. + libsdl2_lib_dirs = [] # Add only if libraries are at non-standard locations. + + # libvpx: If installed on the system, set have_libvpx to true. This also + # requires the libvpx headers be installed. For example, on Debian-like + # systems, the following should install everything needed: + # + # sudo apt-get install libvpx5 libvpx-dev + have_libvpx = false + libvpx_libs = [ "vpx" ] + libvpx_include_dirs = [] # Add only if headers are at non-standard locations. + libvpx_lib_dirs = [] # Add only if libraries are at non-standard locations. +} diff --git a/chromium/third_party/openscreen/src/build/config/sysroot.gni b/chromium/third_party/openscreen/src/build/config/sysroot.gni new file mode 100644 index 00000000000..1434a5d670d --- /dev/null +++ b/chromium/third_party/openscreen/src/build/config/sysroot.gni @@ -0,0 +1,35 @@ +# Copyright 2019 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file + +# This header file defines the "sysroot" variable which is the absolute path +# of the sysroot. If no sysroot applies, the variable will be an empty string. + +declare_args() { + sysroot = "" + + # The relative path to directory containing sysroot images + target_sysroot_dir = "../" + + use_sysroot = current_cpu == "arm" || current_cpu == "arm64" +} + +if (use_sysroot) { + # By default build against a sysroot image downloaded from Cloud Storage + # during gclient runhooks. + if (current_cpu == "arm") { + sysroot = "$target_sysroot_dir/debian_sid_arm-sysroot" + } else if (current_cpu == "arm64") { + sysroot = "$target_sysroot_dir/debian_sid_arm64-sysroot" + } else { + assert(false, "No linux sysroot for cpu: $target_cpu") + } + _script_arch = current_cpu + + if (exec_script("//build/scripts/dir_exists.py", + [ rebase_path(sysroot) ], + "string") != "True") { + print("Missing sysroot for $current_cpu, downloading...") + exec_script("//build/scripts/install-sysroot.py", [ "$current_cpu" ]) + } +} diff --git a/chromium/third_party/openscreen/src/build/scripts/dir_exists.py b/chromium/third_party/openscreen/src/build/scripts/dir_exists.py new file mode 100755 index 00000000000..1e633d22b00 --- /dev/null +++ b/chromium/third_party/openscreen/src/build/scripts/dir_exists.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python + +# Copyright 2019 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +""" +Writes True if the argument is a directory. +""" + +from __future__ import print_function + +import os.path +import sys + + +def main(): + print(is_dir(sys.argv[1]), end='') + return 0 + + +def is_dir(dir_name): + return str(os.path.isdir(dir_name)) + + +def DoMain(args): + """Hook to be called from gyp without starting a separate python + interpreter.""" + return is_dir(args[0]) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/chromium/third_party/openscreen/src/build/scripts/install-sysroot.py b/chromium/third_party/openscreen/src/build/scripts/install-sysroot.py new file mode 100755 index 00000000000..8bf70b7092d --- /dev/null +++ b/chromium/third_party/openscreen/src/build/scripts/install-sysroot.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python2 + +# Copyright 2019 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +""" +Install Debian sysroots for cross compiling Open Screen. +""" + +# The sysroot is needed to ensure that binaries that get built will run on +# the oldest stable version of Debian that we currently support. +# This script can be run manually but is more often run as part of gclient +# hooks. When run from hooks this script is a no-op on non-linux platforms. + +# The sysroot image could be constructed from scratch based on the current state +# of the Debian archive but for consistency we use a pre-built root image (we +# don't want upstream changes to Debian to affect the build until we +# choose to pull them in). The sysroot images are stored in Chrome's common +# data storage, and the sysroots.json file should be kept in sync with Chrome's +# copy of it. + +from __future__ import print_function + +import hashlib +import json +import platform +import argparse +import os +import re +import shutil +import subprocess +import sys +try: + # For Python 3.0 and later + from urllib.request import urlopen +except ImportError: + # Fall back to Python 2's urllib2 + from urllib2 import urlopen + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +PARENT_DIR = os.path.dirname(SCRIPT_DIR) +URL_PREFIX = 'https://storage.googleapis.com' +URL_PATH = 'openscreen-sysroots' + +VALID_ARCHS = ('arm', 'arm64') + +ARCH_TRANSLATIONS = { + 'x64': 'amd64', + 'x86': 'i386', +} + +DEFAULT_TARGET_PLATFORM = 'sid' + + +class Error(Exception): + pass + + +def GetSha1(filename): + """Generates a SHA1 hash for validating download. Done in chunks to avoid + excess memory usage.""" + BLOCKSIZE = 1024 * 1024 + sha1 = hashlib.sha1() + with open(filename, 'rb') as f: + chunk = f.read(BLOCKSIZE) + while chunk: + sha1.update(chunk) + chunk = f.read(BLOCKSIZE) + return sha1.hexdigest() + + +def GetSysrootDict(target_platform, target_arch): + """Gets the sysroot information for a given platform and arch from the sysroots.json + file.""" + if target_arch not in VALID_ARCHS: + raise Error('Unknown architecture: %s' % target_arch) + + sysroots_file = os.path.join(SCRIPT_DIR, 'sysroots.json') + sysroots = json.load(open(sysroots_file)) + sysroot_key = '%s_%s' % (target_platform, target_arch) + if sysroot_key not in sysroots: + raise Error('No sysroot for: %s' % (sysroot_key)) + return sysroots[sysroot_key] + + +def DownloadFile(url, local_path): + """Uses urllib to download a remote file into local_path.""" + for _ in range(3): + try: + response = urlopen(url) + with open(local_path, "wb") as f: + f.write(response.read()) + break + except Exception: + pass + else: + raise Error('Failed to download %s' % url) + +def ValidateFile(local_path, expected_sum): + """Generates the SHA1 hash of a local file to compare with an expected hashsum.""" + sha1sum = GetSha1(local_path) + if sha1sum != expected_sum: + raise Error('Tarball sha1sum is wrong.' + 'Expected %s, actual: %s' % (expected_sum, sha1sum)) + +def InstallSysroot(target_platform, target_arch): + """Downloads, validates, unpacks, and installs a sysroot image.""" + sysroot_dict = GetSysrootDict(target_platform, target_arch) + tarball_filename = sysroot_dict['Tarball'] + tarball_sha1sum = sysroot_dict['Sha1Sum'] + + sysroot = os.path.join(PARENT_DIR, sysroot_dict['SysrootDir']) + + url = '%s/%s/%s/%s' % (URL_PREFIX, URL_PATH, tarball_sha1sum, + tarball_filename) + + stamp = os.path.join(sysroot, '.stamp') + if os.path.exists(stamp): + with open(stamp) as s: + if s.read() == url: + return + + if os.path.isdir(sysroot): + shutil.rmtree(sysroot) + os.mkdir(sysroot) + + tarball_path = os.path.join(sysroot, tarball_filename) + DownloadFile(url, tarball_path) + ValidateFile(tarball_path, tarball_sha1sum) + subprocess.check_call(['tar', 'xf', tarball_path, '-C', sysroot]) + os.remove(tarball_path) + + with open(stamp, 'w') as s: + s.write(url) + + +def parse_args(args): + """Parses the passed in arguments into an object.""" + p = argparse.ArgumentParser() + p.add_argument( + 'arch', + nargs=1, + help='Sysroot architecture: %s' % ', '.join(VALID_ARCHS)) + p.add_argument( + '--print-hash', action="store_true", + help='Print the hash of the sysroot for the specified arch.') + + return p.parse_args(args) + + +def main(args): + if not (sys.platform.startswith('linux') or sys.platform == 'darwin'): + print('Unsupported platform. Only Linux and Mac OS X are supported.') + return 1 + + parsed_args = parse_args(args) + arch = ARCH_TRANSLATIONS.get(parsed_args.arch[0], parsed_args.arch[0]) + + if parsed_args.print_hash: + print(GetSysrootDict(DEFAULT_TARGET_PLATFORM, arch)['Sha1Sum']) + + InstallSysroot(DEFAULT_TARGET_PLATFORM, arch) + return 0 + + +if __name__ == '__main__': + try: + sys.exit(main(sys.argv[1:])) + except Error as e: + sys.stderr.write('Installing sysroot error: {}\n'.format(e)) + sys.exit(1) diff --git a/chromium/third_party/openscreen/src/build/scripts/sysroot_ld_path.py b/chromium/third_party/openscreen/src/build/scripts/sysroot_ld_path.py new file mode 100755 index 00000000000..8c65861046f --- /dev/null +++ b/chromium/third_party/openscreen/src/build/scripts/sysroot_ld_path.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python + +# Copyright 2019 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file + +# Replacement for the deprecated sysroot_ld_path.sh implementation in Chrome. +""" +Reads etc/ld.so.conf and/or etc/ld.so.conf.d/*.conf and returns the +appropriate linker flags. +""" + +from __future__ import print_function +import argparse +import glob +import os +import sys + +LD_SO_CONF_REL_PATH = "etc/ld.so.conf" +LD_SO_CONF_D_REL_PATH = "etc/ld.so.conf.d" + + +def parse_args(args): + p = argparse.ArgumentParser(__doc__) + p.add_argument('sysroot_path', nargs=1, help='Path to sysroot root folder') + return os.path.abspath(p.parse_args(args).sysroot_path[0]) + + +def process_entry(sysroot_path, entry): + assert (entry.startswith('/')) + print(os.path.join(sysroot_path, entry.strip()[1:])) + +def process_ld_conf_file(sysroot_path, conf_file_path): + with open(conf_file_path, 'r') as f: + for line in f.readlines(): + if line.startswith('#'): + continue + process_entry(sysroot_path, line) + + +def process_ld_conf_folder(sysroot_path, ld_conf_path): + files = glob.glob(os.path.join(ld_conf_path, '*.conf')) + for file in files: + process_ld_conf_file(sysroot_path, file) + + +def process_ld_conf_files(sysroot_path): + conf_path = os.path.join(sysroot_path, LD_SO_CONF_REL_PATH) + conf_d_path = os.path.join(sysroot_path, LD_SO_CONF_D_REL_PATH) + + if os.path.isdir(conf_path): + process_ld_conf_folder(sysroot_path, conf_path) + elif os.path.isdir(conf_d_path): + process_ld_conf_folder(sysroot_path, conf_d_path) + + +def main(args): + sysroot_path = parse_args(args) + process_ld_conf_files(sysroot_path) + + +if __name__ == '__main__': + try: + sys.exit(main(sys.argv[1:])) + except Exception as e: + sys.stderr.write(str(e) + '\n') + sys.exit(1) diff --git a/chromium/third_party/openscreen/src/build/scripts/sysroots.json b/chromium/third_party/openscreen/src/build/scripts/sysroots.json new file mode 100644 index 00000000000..81b98648d0f --- /dev/null +++ b/chromium/third_party/openscreen/src/build/scripts/sysroots.json @@ -0,0 +1,13 @@ +{ + "sid_arm": { + "Sha1Sum": "07ef353ec66ca510a9b24a927db823e44435c764", + "SysrootDir": "debian_sid_arm-sysroot", + "Tarball": "debian_sid_arm-sysroot-demo-libs.tar.xz" + }, + "sid_arm64": { + "Sha1Sum": "2ee3fb715f031df95b079ce6d2c8f5e8ba2cc6e4", + "SysrootDir": "debian_sid_arm64-sysroot", + "Tarball": "debian_sid_arm64_sysroot.tar.xz" + } + +} diff --git a/chromium/third_party/openscreen/src/build/toolchain/linux/BUILD.gn b/chromium/third_party/openscreen/src/build/toolchain/linux/BUILD.gn index e2205142c51..2767c7c41f7 100644 --- a/chromium/third_party/openscreen/src/build/toolchain/linux/BUILD.gn +++ b/chromium/third_party/openscreen/src/build/toolchain/linux/BUILD.gn @@ -2,110 +2,220 @@ # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -toolchain("linux") { - prefix = rebase_path("$clang_base_path/bin", root_build_dir) +template("gcc_toolchain") { + toolchain(target_name) { + assert(defined(invoker.ar), "Caller must define ar command.") + assert(defined(invoker.cc), "Caller must define cc command.") + assert(defined(invoker.cxx), "Caller must define cxx command.") + assert(defined(invoker.ld), "Caller must define ld command.") + forward_variables_from(invoker, + [ + "ar", + "cc", + "cxx", + "ld", + ]) + + toolchain_args = { + forward_variables_from(invoker.toolchain_args, "*") + + # The host toolchain needs to be preserved by all secondary toolchains. + # For futher explanation, see + # https://gn.googlesource.com/gn/+/refs/heads/master/docs/reference.md#toolchain-overview + host_toolchain = host_toolchain + } + + lib_switch = "-l" + lib_dir_switch = "-L" + + object_prefix = "{{source_out_dir}}/{{label_name}}." + + tool("cc") { + depfile = "{{output}}.d" + command = "$cc -MMD -MF $depfile {{defines}} {{include_dirs}} {{cflags}} {{cflags_c}} -c {{source}} -o {{output}}" + depsformat = "gcc" + description = "CC {{output}}" + outputs = [ + "$object_prefix{{source_name_part}}.o", + ] + } + + tool("cxx") { + depfile = "{{output}}.d" + command = "$cxx -MMD -MF $depfile {{defines}} {{include_dirs}} {{cflags}} {{cflags_cc}} -c {{source}} -o {{output}}" + depsformat = "gcc" + description = "CXX {{output}}" + outputs = [ + "$object_prefix{{source_name_part}}.o", + ] + } + + tool("asm") { + depfile = "{{output}}.d" + command = "$cc -MMD -MF $depfile {{defines}} {{include_dirs}} {{asmflags}} -c {{source}} -o {{output}}" + depsformat = "gcc" + description = "ASM {{output}}" + outputs = [ + "$object_prefix{{source_name_part}}.o", + ] + } + + tool("alink") { + rspfile = "{{output}}.rsp" + command = "rm -f {{output}} && $ar rcs {{output}} @$rspfile" + description = "AR {{target_output_name}}{{output_extension}}" + rspfile_content = "{{inputs}}" + outputs = [ + "{{output_dir}}/{{target_output_name}}{{output_extension}}", + ] + default_output_dir = "{{target_out_dir}}" + default_output_extension = ".a" + output_prefix = "lib" + } + + tool("solink") { + soname = "{{target_output_name}}{{output_extension}}" # e.g. "libfoo.so". + sofile = "{{output_dir}}/$soname" + rspfile = soname + ".rsp" + + command = + "$ld -shared {{ldflags}} -o $sofile -Wl,-soname=$soname @$rspfile" + rspfile_content = "-Wl,--whole-archive {{inputs}} {{solibs}} -Wl,--no-whole-archive {{libs}}" + + description = "SOLINK {{output}}" + + # Use this for {{output_extension}} expansions unless a target manually + # overrides it (in which case {{output_extension}} will be what the target + # specifies). + default_output_extension = ".so" + + # Use this for {{output_dir}} expansions unless a target manually overrides + # it (in which case {{output_dir}} will be what the target specifies). + default_output_dir = "{{root_out_dir}}" + + outputs = [ + sofile, + ] + link_output = sofile + depend_output = sofile + output_prefix = "lib" + } + + tool("link") { + outfile = "{{output_dir}}/{{target_output_name}}{{output_extension}}" + rspfile = "$outfile.rsp" + + # These extra ldflags allow an executable to search for shared libraries in + # the current working directory. + additional_executable_ldflags = "-Wl,-rpath=\$ORIGIN/ -Wl,-rpath-link=" + command = "$ld {{ldflags}} $additional_executable_ldflags -o $outfile -Wl,--start-group @$rspfile {{solibs}} -Wl,--end-group {{libs}}" + description = "LINK $outfile" + default_output_dir = "{{root_out_dir}}" + rspfile_content = "{{inputs}}" + outputs = [ + outfile, + ] + } + + tool("stamp") { + command = "touch {{output}}" + description = "STAMP {{output}}" + } + + tool("copy") { + command = "ln -f {{source}} {{output}} 2>/dev/null || (rm -rf {{output}} && cp -af {{source}} {{output}})" + description = "COPY {{source}} {{output}}" + } + } +} - c_command = "$prefix/clang" - cpp_command = "$prefix/clang++" +template("clang_toolchain") { + prefix = rebase_path("$clang_base_path/bin", root_build_dir) - if (is_gcc) { - c_command = "gcc" - cpp_command = "g++" + gcc_toolchain(target_name) { + ar = "$prefix/llvm-ar" + cc = "$prefix/clang" + cxx = "$prefix/clang++" + ld = cxx + toolchain_args = { + forward_variables_from(invoker.toolchain_args, "*") + is_clang = true + } } +} - tool("cc") { - depfile = "{{output}}.d" - command = "$c_command -MMD -MF $depfile {{defines}} {{include_dirs}} {{cflags}} {{cflags_c}} -c {{source}} -o {{output}}" - depsformat = "gcc" - description = "CC {{output}}" - outputs = [ - "{{source_out_dir}}/{{target_output_name}}.{{source_name_part}}.o", - ] +clang_toolchain("clang_x64") { + toolchain_args = { + current_cpu = "x64" + current_os = "linux" } +} - tool("cxx") { - depfile = "{{output}}.d" - command = "$cpp_command -MMD -MF $depfile {{defines}} {{include_dirs}} {{cflags}} {{cflags_cc}} -c {{source}} -o {{output}}" - depsformat = "gcc" - description = "CXX {{output}}" - outputs = [ - "{{source_out_dir}}/{{target_output_name}}.{{source_name_part}}.o", - ] +clang_toolchain("clang_x86") { + toolchain_args = { + current_cpu = "x86" + current_os = "linux" } +} - tool("asm") { - depfile = "{{output}}.d" - command = "$c_command -MMD -MF $depfile {{defines}} {{include_dirs}} {{asmflags}} -c {{source}} -o {{output}}" - depsformat = "gcc" - description = "ASM {{output}}" - outputs = [ - "{{source_out_dir}}/{{target_output_name}}.{{source_name_part}}.o", - ] +clang_toolchain("clang_arm") { + toolchain_args = { + current_cpu = "arm" + current_os = "linux" } +} - tool("alink") { - rspfile = "{{output}}.rsp" - command = "rm -f {{output}} && ar rcs {{output}} @$rspfile" - description = "AR {{target_output_name}}{{output_extension}}" - rspfile_content = "{{inputs}}" - outputs = [ - "{{output_dir}}/{{target_output_name}}{{output_extension}}", - ] - default_output_dir = "{{target_out_dir}}" - default_output_extension = ".a" - output_prefix = "lib" +clang_toolchain("clang_arm64") { + toolchain_args = { + current_cpu = "arm64" + current_os = "linux" } +} - tool("solink") { - soname = "{{target_output_name}}{{output_extension}}" # e.g. "libfoo.so". - sofile = "{{output_dir}}/$soname" - rspfile = soname + ".rsp" - - command = "$cpp_command -shared {{ldflags}} -o $sofile -Wl,-soname=$soname @$rspfile" - rspfile_content = "-Wl,--whole-archive {{inputs}} {{solibs}} -Wl,--no-whole-archive {{libs}}" - - description = "SOLINK {{output}}" - - # Use this for {{output_extension}} expansions unless a target manually - # overrides it (in which case {{output_extension}} will be what the target - # specifies). - default_output_extension = ".so" - - # Use this for {{output_dir}} expansions unless a target manually overrides - # it (in which case {{output_dir}} will be what the target specifies). - default_output_dir = "{{root_out_dir}}" - - outputs = [ - sofile, - ] - link_output = sofile - depend_output = sofile - output_prefix = "lib" +gcc_toolchain("gcc_x64") { + ar = "ar" + cc = "gcc" + cxx = "g++" + ld = cxx + toolchain_args = { + current_cpu = "x64" + current_os = "linux" + is_gcc = true } +} - tool("link") { - outfile = "{{target_output_name}}{{output_extension}}" - rspfile = "$outfile.rsp" - - # These extra ldflags allow an executable to search for shared libraries in - # the current working directory. - additional_executable_ldflags = "-Wl,-rpath=\$ORIGIN/ -Wl,-rpath-link=" - command = "$cpp_command {{ldflags}} $additional_executable_ldflags -o $outfile -Wl,--start-group @$rspfile {{solibs}} -Wl,--end-group {{libs}}" - description = "LINK $outfile" - default_output_dir = "{{root_out_dir}}" - rspfile_content = "{{inputs}}" - outputs = [ - outfile, - ] +gcc_toolchain("gcc_x86") { + ar = "ar" + cc = "gcc" + cxx = "g++" + ld = cxx + toolchain_args = { + current_cpu = "x86" + current_os = "linux" + is_gcc = true } +} - tool("stamp") { - command = "touch {{output}}" - description = "STAMP {{output}}" +gcc_toolchain("gcc_arm") { + ar = "ar" + cc = "gcc" + cxx = "g++" + ld = cxx + toolchain_args = { + current_cpu = "arm" + current_os = "linux" + is_gcc = true } +} - tool("copy") { - command = "ln -f {{source}} {{output}} 2>/dev/null || (rm -rf {{output}} && cp -af {{source}} {{output}})" - description = "COPY {{source}} {{output}}" +gcc_toolchain("gcc_arm64") { + ar = "ar" + cc = "gcc" + cxx = "g++" + ld = cxx + toolchain_args = { + current_cpu = "arm64" + current_os = "linux" + is_gcc = true } } diff --git a/chromium/third_party/openscreen/src/build/toolchain/mac/BUILD.gn b/chromium/third_party/openscreen/src/build/toolchain/mac/BUILD.gn index f56dd499c7f..341d6a6356e 100644 --- a/chromium/third_party/openscreen/src/build/toolchain/mac/BUILD.gn +++ b/chromium/third_party/openscreen/src/build/toolchain/mac/BUILD.gn @@ -6,6 +6,9 @@ toolchain("clang") { c_command = "clang" cpp_command = "clang++" + lib_switch = "-l" + lib_dir_switch = "-L" + tool("cc") { depfile = "{{output}}.d" command = "$c_command -MMD -MF $depfile {{defines}} {{include_dirs}} {{cflags}} {{cflags_c}} -c {{source}} -o {{output}}" diff --git a/chromium/third_party/openscreen/src/cast/DEPS b/chromium/third_party/openscreen/src/cast/DEPS index e8615f96677..4e73fd07bcb 100644 --- a/chromium/third_party/openscreen/src/cast/DEPS +++ b/chromium/third_party/openscreen/src/cast/DEPS @@ -4,14 +4,12 @@ include_rules = [ # OSP code is strictly verboten. '-api', '-demo', - '-discovery', '-go', '-msgs', # Intra-libcast dependencies must be explicit. '-cast', - # All libcast code can use platform and cast/third_party. + # All libcast code can use cast/third_party. '+cast/third_party', - '+platform' ] diff --git a/chromium/third_party/openscreen/src/cast/common/BUILD.gn b/chromium/third_party/openscreen/src/cast/common/BUILD.gn index 80260eb65b1..f753198c4fe 100644 --- a/chromium/third_party/openscreen/src/cast/common/BUILD.gn +++ b/chromium/third_party/openscreen/src/cast/common/BUILD.gn @@ -4,6 +4,7 @@ import("//build_overrides/build.gni") import("//third_party/protobuf/proto_library.gni") +import("../../testing/libfuzzer/fuzzer_test.gni") source_set("certificate") { sources = [ @@ -13,6 +14,8 @@ source_set("certificate") { "certificate/cast_cert_validator_internal.h", "certificate/cast_crl.cc", "certificate/cast_crl.h", + "certificate/cast_trust_store.cc", + "certificate/cast_trust_store.h", "certificate/types.cc", "certificate/types.h", ] @@ -32,9 +35,14 @@ source_set("channel") { sources = [ "channel/cast_socket.cc", "channel/cast_socket.h", + "channel/connection_namespace_handler.cc", + "channel/connection_namespace_handler.h", "channel/message_framer.cc", "channel/message_framer.h", + "channel/message_util.cc", "channel/message_util.h", + "channel/namespace_router.cc", + "channel/namespace_router.h", "channel/virtual_connection.h", "channel/virtual_connection_manager.cc", "channel/virtual_connection_manager.h", @@ -43,29 +51,66 @@ source_set("channel") { ] deps = [ - "../../util", "certificate/proto:certificate_proto", - "channel/proto:channel_proto", ] public_deps = [ "../../platform", "../../third_party/abseil", + "../../util", + "channel/proto:channel_proto", + ] +} + +source_set("public") { + sources = [ + "public/service_info.cc", + "public/service_info.h", ] + + deps = [ + "../../discovery:dnssd", + "../../discovery:public", + "../../platform", + "../../third_party/abseil", + ] +} + +if (!build_with_chromium) { + source_set("discovery_e2e_test") { + testonly = true + + sources = [ + "discovery/e2e_test/tests.cc", + ] + + deps = [ + ":public", + "../../discovery:dnssd", + "../../discovery:public", + "../../third_party/googletest:gtest", + ] + } } source_set("test_helpers") { testonly = true + sources = [ - "certificate/test_helpers.cc", - "certificate/test_helpers.h", - "channel/test/fake_cast_socket.h", - "channel/test/mock_cast_message_handler.h", + "certificate/testing/test_helpers.cc", + "certificate/testing/test_helpers.h", + "channel/testing/fake_cast_socket.h", + "channel/testing/mock_cast_message_handler.h", + "channel/testing/mock_socket_error_handler.h", + "public/testing/discovery_utils.cc", + "public/testing/discovery_utils.h", ] public_deps = [ ":certificate", ":channel", + ":public", "../../platform:test", + "../../third_party/abseil", "../../third_party/boringssl", "../../third_party/googletest:gmock", ] @@ -81,16 +126,21 @@ source_set("unittests") { "certificate/cast_cert_validator_unittest.cc", "certificate/cast_crl_unittest.cc", "channel/cast_socket_unittest.cc", + "channel/connection_namespace_handler_unittest.cc", "channel/message_framer_unittest.cc", + "channel/namespace_router_unittest.cc", "channel/virtual_connection_manager_unittest.cc", "channel/virtual_connection_router_unittest.cc", + "public/service_info_unittest.cc", ] deps = [ ":certificate", ":channel", + ":public", ":test_helpers", "../../platform", + "../../testing/util", "../../third_party/boringssl", "../../third_party/googletest:gmock", "../../third_party/googletest:gtest", @@ -98,4 +148,22 @@ source_set("unittests") { "certificate/proto:certificate_unittest_proto", "channel/proto:channel_proto", ] + + data = [ + "../../test/data/cast/common/certificate", + ] +} + +openscreen_fuzzer_test("message_framer_fuzzer") { + sources = [ + "channel/message_framer_fuzzer.cc", + ] + deps = [ + ":channel", + ] + + seed_corpus = "channel/message_framer_fuzzer_seeds" + + # NOTE: 65536 is max _body_ size. + libfuzzer_options = [ "max_len=65600" ] } diff --git a/chromium/third_party/openscreen/src/cast/common/DEPS b/chromium/third_party/openscreen/src/cast/common/DEPS index e1023950518..c31b1081666 100644 --- a/chromium/third_party/openscreen/src/cast/common/DEPS +++ b/chromium/third_party/openscreen/src/cast/common/DEPS @@ -2,5 +2,7 @@ include_rules = [ # libcast common code must depend on neither the sender nor the receiver. - '+cast/common' + '+cast/common', + '+discovery/common', + '+discovery/public', ] diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.cc b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.cc index 6fe9582140f..6d2a12e7674 100644 --- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.cc +++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.cc @@ -17,27 +17,13 @@ #include "cast/common/certificate/cast_cert_validator_internal.h" #include "cast/common/certificate/cast_crl.h" +#include "cast/common/certificate/cast_trust_store.h" +#include "util/logging.h" +namespace openscreen { namespace cast { -namespace certificate { namespace { -using CastCertError = openscreen::Error::Code; - -// ------------------------------------------------------------------------- -// Cast trust anchors. -// ------------------------------------------------------------------------- - -// There are two trusted roots for Cast certificate chains: -// -// (1) CN=Cast Root CA (kCastRootCaDer) -// (2) CN=Eureka Root CA (kEurekaRootCaDer) -// -// These constants are defined by the files included next: - -#include "cast/common/certificate/cast_root_ca_cert_der-inc.h" -#include "cast/common/certificate/eureka_root_ca_der-inc.h" - // Returns the OID for the Audio-Only Cast policy // (1.3.6.1.4.1.11129.2.5.2) in DER form. const ConstDataSpan& AudioOnlyPolicyOid() { @@ -49,8 +35,8 @@ const ConstDataSpan& AudioOnlyPolicyOid() { class CertVerificationContextImpl final : public CertVerificationContext { public: - CertVerificationContextImpl(bssl::UniquePtr<EVP_PKEY>&& cert, - std::string&& common_name) + CertVerificationContextImpl(bssl::UniquePtr<EVP_PKEY> cert, + std::string common_name) : public_key_{std::move(cert)}, common_name_(std::move(common_name)) {} ~CertVerificationContextImpl() override = default; @@ -143,54 +129,31 @@ CastDeviceCertPolicy GetAudioPolicy(const std::vector<X509*>& path) { } // namespace -class CastTrustStore { - public: - // Singleton for the Cast trust store for legacy networkingPrivate use. - static CastTrustStore* GetInstance() { - static CastTrustStore* store = new CastTrustStore(); - return store; - } - - CastTrustStore() { - trust_store_.certs.emplace_back(MakeTrustAnchor(kCastRootCaDer)); - trust_store_.certs.emplace_back(MakeTrustAnchor(kEurekaRootCaDer)); - } - ~CastTrustStore() = default; - - TrustStore* trust_store() { return &trust_store_; } - - private: - TrustStore trust_store_; - OSP_DISALLOW_COPY_AND_ASSIGN(CastTrustStore); -}; - -openscreen::Error VerifyDeviceCert( - const std::vector<std::string>& der_certs, - const DateTime& time, - std::unique_ptr<CertVerificationContext>* context, - CastDeviceCertPolicy* policy, - const CastCRL* crl, - CRLPolicy crl_policy, - TrustStore* trust_store) { +Error VerifyDeviceCert(const std::vector<std::string>& der_certs, + const DateTime& time, + std::unique_ptr<CertVerificationContext>* context, + CastDeviceCertPolicy* policy, + const CastCRL* crl, + CRLPolicy crl_policy, + TrustStore* trust_store) { if (!trust_store) { trust_store = CastTrustStore::GetInstance()->trust_store(); } // Fail early if CRL is required but not provided. if (!crl && crl_policy == CRLPolicy::kCrlRequired) { - return CastCertError::kErrCrlInvalid; + return Error::Code::kErrCrlInvalid; } CertificatePathResult result_path = {}; - openscreen::Error error = - FindCertificatePath(der_certs, time, &result_path, trust_store); + Error error = FindCertificatePath(der_certs, time, &result_path, trust_store); if (!error.ok()) { return error; } if (crl_policy == CRLPolicy::kCrlRequired && !crl->CheckRevocation(result_path.path, time)) { - return CastCertError::kErrCertsRevoked; + return Error::Code::kErrCertsRevoked; } *policy = GetAudioPolicy(result_path.path); @@ -203,7 +166,7 @@ openscreen::Error VerifyDeviceCert( int len = X509_NAME_get_text_by_NID(target_subject, NID_commonName, &common_name[0], common_name.size()); if (len == 0) { - return CastCertError::kErrCertsRestrictions; + return Error::Code::kErrCertsRestrictions; } common_name.resize(len); @@ -211,8 +174,8 @@ openscreen::Error VerifyDeviceCert( bssl::UniquePtr<EVP_PKEY>{X509_get_pubkey(result_path.target_cert.get())}, std::move(common_name))); - return CastCertError::kNone; + return Error::Code::kNone; } -} // namespace certificate } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.h b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.h index 36f5eeff356..c20e42db69a 100644 --- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.h +++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.h @@ -13,8 +13,8 @@ #include "platform/base/error.h" #include "platform/base/macros.h" +namespace openscreen { namespace cast { -namespace certificate { class CastCRL; @@ -70,7 +70,7 @@ class CertVerificationContext { OSP_DISALLOW_COPY_AND_ASSIGN(CertVerificationContext); }; -// Verifies a cast device certficate given a chain of DER-encoded certificates. +// Verifies a cast device certificate given a chain of DER-encoded certificates. // // Inputs: // @@ -95,16 +95,15 @@ class CertVerificationContext { // // Outputs: // -// Returns openscreen::Error::Code::kNone on success. Otherwise, the -// corresponding openscreen::Error::Code. On success, the output parameters are -// filled with more details: +// Returns Error::Code::kNone on success. Otherwise, the corresponding +// Error::Code. On success, the output parameters are filled with more details: // // * |context| is filled with an object that can be used to verify signatures // using the device certificate's public key, as well as to extract other // properties from the device certificate (Common Name). // * |policy| is filled with an indication of the device certificate's policy // (i.e. is it for audio-only devices or is it unrestricted?) -[[nodiscard]] openscreen::Error VerifyDeviceCert( +[[nodiscard]] Error VerifyDeviceCert( const std::vector<std::string>& der_certs, const DateTime& time, std::unique_ptr<CertVerificationContext>* context, @@ -113,7 +112,7 @@ class CertVerificationContext { CRLPolicy crl_policy, TrustStore* trust_store = nullptr); -} // namespace certificate } // namespace cast +} // namespace openscreen #endif // CAST_COMMON_CERTIFICATE_CAST_CERT_VALIDATOR_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.cc b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.cc index 331e71259a2..7e43e02c1fa 100644 --- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.cc +++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.cc @@ -15,14 +15,12 @@ #include "cast/common/certificate/types.h" #include "util/logging.h" +namespace openscreen { namespace cast { -namespace certificate { namespace { constexpr static int32_t kMinRsaModulusLengthBits = 2048; -using CastCertError = openscreen::Error::Code; - // Stores intermediate state while attempting to find a valid certificate chain // from a set of trusted certificates to a target certificate. Together, a // sequence of these forms a certificate chain to be verified as well as a stack @@ -63,16 +61,16 @@ uint8_t ParseAsn1TimeDoubleDigit(ASN1_GENERALIZEDTIME* time, int index) { return (time->data[index] - '0') * 10 + (time->data[index + 1] - '0'); } -CastCertError VerifyCertTime(X509* cert, const DateTime& time) { +Error::Code VerifyCertTime(X509* cert, const DateTime& time) { DateTime not_before; DateTime not_after; if (!GetCertValidTimeRange(cert, ¬_before, ¬_after)) { - return CastCertError::kErrCertsVerifyGeneric; + return Error::Code::kErrCertsVerifyGeneric; } if ((time < not_before) || (not_after < time)) { - return CastCertError::kErrCertsDateInvalid; + return Error::Code::kErrCertsDateInvalid; } - return CastCertError::kNone; + return Error::Code::kNone; } bool VerifyPublicKeyLength(EVP_PKEY* public_key) { @@ -94,27 +92,27 @@ bssl::UniquePtr<ASN1_BIT_STRING> GetKeyUsage(X509* cert) { return bssl::UniquePtr<ASN1_BIT_STRING>{key_usage_bit_string}; } -CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path, - uint32_t step_index, - const DateTime& time) { +Error::Code VerifyCertificateChain(const std::vector<CertPathStep>& path, + uint32_t step_index, + const DateTime& time) { // Default max path length is the number of intermediate certificates. int max_pathlen = path.size() - 2; std::vector<NAME_CONSTRAINTS*> path_name_constraints; - CastCertError error = CastCertError::kNone; + Error::Code error = Error::Code::kNone; uint32_t i = step_index; for (; i < path.size() - 1; ++i) { X509* subject = path[i + 1].cert; X509* issuer = path[i].cert; bool is_root = (i == step_index); if (!is_root) { - if ((error = VerifyCertTime(issuer, time)) != CastCertError::kNone) { + if ((error = VerifyCertTime(issuer, time)) != Error::Code::kNone) { return error; } if (X509_NAME_cmp(X509_get_subject_name(issuer), X509_get_issuer_name(issuer)) != 0) { if (max_pathlen == 0) { - return CastCertError::kErrCertsPathlen; + return Error::Code::kErrCertsPathlen; } --max_pathlen; } else { @@ -129,7 +127,7 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path, const int bit = ASN1_BIT_STRING_get_bit(key_usage.get(), KeyUsageBits::kKeyCertSign); if (bit == 0) { - return CastCertError::kErrCertsVerifyGeneric; + return Error::Code::kErrCertsVerifyGeneric; } } @@ -138,7 +136,7 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path, const int basic_constraints_index = X509_get_ext_by_NID(issuer, NID_basic_constraints, -1); if (basic_constraints_index == -1) { - return CastCertError::kErrCertsVerifyGeneric; + return Error::Code::kErrCertsVerifyGeneric; } X509_EXTENSION* const basic_constraints_extension = X509_get_ext(issuer, basic_constraints_index); @@ -147,16 +145,16 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path, X509V3_EXT_d2i(basic_constraints_extension))}; if (!basic_constraints || !basic_constraints->ca) { - return CastCertError::kErrCertsVerifyGeneric; + return Error::Code::kErrCertsVerifyGeneric; } if (basic_constraints->pathlen) { if (basic_constraints->pathlen->length != 1) { - return CastCertError::kErrCertsVerifyGeneric; + return Error::Code::kErrCertsVerifyGeneric; } else { const int pathlen = *basic_constraints->pathlen->data; if (pathlen < 0) { - return CastCertError::kErrCertsVerifyGeneric; + return Error::Code::kErrCertsVerifyGeneric; } if (pathlen < max_pathlen) { max_pathlen = pathlen; @@ -165,12 +163,12 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path, } if (X509_ALGOR_cmp(issuer->sig_alg, issuer->cert_info->signature) != 0) { - return CastCertError::kErrCertsVerifyGeneric; + return Error::Code::kErrCertsVerifyGeneric; } bssl::UniquePtr<EVP_PKEY> public_key{X509_get_pubkey(issuer)}; if (!VerifyPublicKeyLength(public_key.get())) { - return CastCertError::kErrCertsVerifyGeneric; + return Error::Code::kErrCertsVerifyGeneric; } // NOTE: (!self-issued || target) -> verify name constraints. Target case @@ -179,7 +177,7 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path, if (!is_self_issued) { for (NAME_CONSTRAINTS* name_constraints : path_name_constraints) { if (NAME_CONSTRAINTS_check(subject, name_constraints) != X509_V_OK) { - return CastCertError::kErrCertsVerifyGeneric; + return Error::Code::kErrCertsVerifyGeneric; } } } @@ -195,7 +193,7 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path, issuer->nc = nc; path_name_constraints.push_back(nc); } else { - return CastCertError::kErrCertsVerifyGeneric; + return Error::Code::kErrCertsVerifyGeneric; } } } @@ -220,12 +218,12 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path, ((OBJ_cmp(policy_mapping->issuerDomainPolicy, any_policy) == 0) || (OBJ_cmp(policy_mapping->subjectDomainPolicy, any_policy) == 0)); if (either_matches) { - error = CastCertError::kErrCertsVerifyGeneric; + error = Error::Code::kErrCertsVerifyGeneric; break; } } sk_POLICY_MAPPING_free(policy_mappings); - if (error != CastCertError::kNone) { + if (error != Error::Code::kNone) { return error; } } @@ -238,7 +236,7 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path, const int nid = OBJ_obj2nid(extension->object); if (nid != NID_name_constraints && nid != NID_basic_constraints && nid != NID_key_usage) { - return CastCertError::kErrCertsVerifyGeneric; + return Error::Code::kErrCertsVerifyGeneric; } } } @@ -259,7 +257,7 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path, digest = EVP_sha512(); break; default: - return CastCertError::kErrCertsVerifyGeneric; + return Error::Code::kErrCertsVerifyGeneric; } if (!VerifySignedData( digest, public_key.get(), @@ -267,14 +265,14 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path, static_cast<uint32_t>(subject->cert_info->enc.len)}, {subject->signature->data, static_cast<uint32_t>(subject->signature->length)})) { - return CastCertError::kErrCertsVerifyGeneric; + return Error::Code::kErrCertsVerifyGeneric; } } // NOTE: Other half of ((!self-issued || target) -> check name constraints). for (NAME_CONSTRAINTS* name_constraints : path_name_constraints) { if (NAME_CONSTRAINTS_check(path.back().cert, name_constraints) != X509_V_OK) { - return CastCertError::kErrCertsVerifyGeneric; + return Error::Code::kErrCertsVerifyGeneric; } } return error; @@ -370,12 +368,12 @@ bool VerifySignedData(const EVP_MD* digest, data.data, data.length) == 1); } -openscreen::Error FindCertificatePath(const std::vector<std::string>& der_certs, - const DateTime& time, - CertificatePathResult* result_path, - TrustStore* trust_store) { +Error FindCertificatePath(const std::vector<std::string>& der_certs, + const DateTime& time, + CertificatePathResult* result_path, + TrustStore* trust_store) { if (der_certs.empty()) { - return CastCertError::kErrCertsMissing; + return Error::Code::kErrCertsMissing; } bssl::UniquePtr<X509>& target_cert = result_path->target_cert; @@ -383,36 +381,36 @@ openscreen::Error FindCertificatePath(const std::vector<std::string>& der_certs, result_path->intermediate_certs; target_cert.reset(ParseX509Der(der_certs[0])); if (!target_cert) { - return CastCertError::kErrCertsParse; + return Error::Code::kErrCertsParse; } for (size_t i = 1; i < der_certs.size(); ++i) { intermediate_certs.emplace_back(ParseX509Der(der_certs[i])); if (!intermediate_certs.back()) { - return CastCertError::kErrCertsParse; + return Error::Code::kErrCertsParse; } } // Basic checks on the target certificate. - CastCertError error = VerifyCertTime(target_cert.get(), time); - if (error != CastCertError::kNone) { + Error::Code error = VerifyCertTime(target_cert.get(), time); + if (error != Error::Code::kNone) { return error; } bssl::UniquePtr<EVP_PKEY> public_key{X509_get_pubkey(target_cert.get())}; if (!VerifyPublicKeyLength(public_key.get())) { - return CastCertError::kErrCertsVerifyGeneric; + return Error::Code::kErrCertsVerifyGeneric; } if (X509_ALGOR_cmp(target_cert.get()->sig_alg, target_cert.get()->cert_info->signature) != 0) { - return CastCertError::kErrCertsVerifyGeneric; + return Error::Code::kErrCertsVerifyGeneric; } bssl::UniquePtr<ASN1_BIT_STRING> key_usage = GetKeyUsage(target_cert.get()); if (!key_usage) { - return CastCertError::kErrCertsRestrictions; + return Error::Code::kErrCertsRestrictions; } int bit = ASN1_BIT_STRING_get_bit(key_usage.get(), KeyUsageBits::kDigitalSignature); if (bit == 0) { - return CastCertError::kErrCertsRestrictions; + return Error::Code::kErrCertsRestrictions; } X509* path_head = target_cert.get(); @@ -442,7 +440,7 @@ openscreen::Error FindCertificatePath(const std::vector<std::string>& der_certs, // returned is whatever the last error was from the last path tried. uint32_t trust_store_index = 0; uint32_t intermediate_cert_index = 0; - CastCertError last_error = CastCertError::kNone; + Error::Code last_error = Error::Code::kNone; for (;;) { X509_NAME* target_issuer_name = X509_get_issuer_name(path_head); @@ -486,8 +484,8 @@ openscreen::Error FindCertificatePath(const std::vector<std::string>& der_certs, if (!next_issuer) { if (path_index == first_index) { // There are no more paths to try. Ensure an error is returned. - if (last_error == CastCertError::kNone) { - return CastCertError::kErrCertsVerifyGeneric; + if (last_error == Error::Code::kNone) { + return Error::Code::kErrCertsVerifyGeneric; } return last_error; } else { @@ -500,7 +498,7 @@ openscreen::Error FindCertificatePath(const std::vector<std::string>& der_certs, if (path_cert_in_trust_store) { last_error = VerifyCertificateChain(path, path_index, time); - if (last_error != CastCertError::kNone) { + if (last_error != Error::Code::kNone) { CertPathStep& last_step = path[path_index++]; trust_store_index = last_step.trust_store_index; intermediate_cert_index = last_step.intermediate_cert_index; @@ -517,8 +515,8 @@ openscreen::Error FindCertificatePath(const std::vector<std::string>& der_certs, result_path->path.push_back(path[i].cert); } - return CastCertError::kNone; + return Error::Code::kNone; } -} // namespace certificate } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.h b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.h index 1c127e72948..f8424b6d1c0 100644 --- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.h +++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.h @@ -11,8 +11,8 @@ #include "platform/base/error.h" +namespace openscreen { namespace cast { -namespace certificate { struct TrustStore { std::vector<bssl::UniquePtr<X509>> certs; @@ -26,6 +26,11 @@ bssl::UniquePtr<X509> MakeTrustAnchor(const uint8_t (&data)[N]) { return bssl::UniquePtr<X509>{d2i_X509(nullptr, &dptr, N)}; } +inline bssl::UniquePtr<X509> MakeTrustAnchor(const std::vector<uint8_t>& data) { + const uint8_t* dptr = data.data(); + return bssl::UniquePtr<X509>{d2i_X509(nullptr, &dptr, data.size())}; +} + struct ConstDataSpan; struct DateTime; @@ -47,12 +52,12 @@ struct CertificatePathResult { std::vector<X509*> path; }; -openscreen::Error FindCertificatePath(const std::vector<std::string>& der_certs, - const DateTime& time, - CertificatePathResult* result_path, - TrustStore* trust_store); +Error FindCertificatePath(const std::vector<std::string>& der_certs, + const DateTime& time, + CertificatePathResult* result_path, + TrustStore* trust_store); -} // namespace certificate } // namespace cast +} // namespace openscreen #endif // CAST_COMMON_CERTIFICATE_CAST_CERT_VALIDATOR_INTERNAL_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_unittest.cc b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_unittest.cc index c29de9cc6b7..41700a507cf 100644 --- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_unittest.cc @@ -8,16 +8,14 @@ #include <string.h> #include "cast/common/certificate/cast_cert_validator_internal.h" -#include "cast/common/certificate/test_helpers.h" +#include "cast/common/certificate/testing/test_helpers.h" #include "gtest/gtest.h" #include "openssl/pem.h" +namespace openscreen { namespace cast { -namespace certificate { namespace { -using CastCertError = openscreen::Error::Code; - enum TrustStoreDependency { // Uses the built-in trust store for Cast. This is how certificates are // verified in production. @@ -45,7 +43,7 @@ enum TrustStoreDependency { // * |optional_signed_data_file_name| - optional path to a PEM file containing // a valid signature generated by the device certificate. // -void RunTest(CastCertError expected_result, +void RunTest(Error::Code expected_result, const std::string& expected_common_name, CastDeviceCertPolicy expected_policy, const std::string& certs_file_name, @@ -82,12 +80,11 @@ void RunTest(CastCertError expected_result, std::unique_ptr<CertVerificationContext> context; CastDeviceCertPolicy policy; - openscreen::Error result = - VerifyDeviceCert(certs, time, &context, &policy, nullptr, - CRLPolicy::kCrlOptional, trust_store); + Error result = VerifyDeviceCert(certs, time, &context, &policy, nullptr, + CRLPolicy::kCrlOptional, trust_store); ASSERT_EQ(expected_result, result.code()); - if (expected_result != CastCertError::kNone) + if (expected_result != Error::Code::kNone) return; EXPECT_EQ(expected_policy, policy); @@ -166,7 +163,7 @@ DateTime MarchFirst2037() { // Chains to trust anchor: // Eureka Root CA (built-in trust store) TEST(VerifyCastDeviceCertTest, ChromecastGen1) { - RunTest(CastCertError::kNone, "2ZZBG9 FA8FCA3EF91A", + RunTest(Error::Code::kNone, "2ZZBG9 FA8FCA3EF91A", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/chromecast_gen1.pem", AprilFirst2016(), TRUST_STORE_BUILTIN, @@ -181,7 +178,7 @@ TEST(VerifyCastDeviceCertTest, ChromecastGen1) { // Chains to trust anchor: // Cast Root CA (built-in trust store) TEST(VerifyCastDeviceCertTest, ChromecastGen1Reissue) { - RunTest(CastCertError::kNone, "2ZZBG9 FA8FCA3EF91A", + RunTest(Error::Code::kNone, "2ZZBG9 FA8FCA3EF91A", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/chromecast_gen1_reissue.pem", AprilFirst2016(), TRUST_STORE_BUILTIN, @@ -196,7 +193,7 @@ TEST(VerifyCastDeviceCertTest, ChromecastGen1Reissue) { // Chains to trust anchor: // Cast Root CA (built-in trust store) TEST(VerifyCastDeviceCertTest, ChromecastGen2) { - RunTest(CastCertError::kNone, "3ZZAK6 FA8FCA3F0D35", + RunTest(Error::Code::kNone, "3ZZAK6 FA8FCA3F0D35", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/chromecast_gen2.pem", AprilFirst2016(), TRUST_STORE_BUILTIN, ""); @@ -211,7 +208,7 @@ TEST(VerifyCastDeviceCertTest, ChromecastGen2) { // Chains to trust anchor: // Cast Root CA (built-in trust store) TEST(VerifyCastDeviceCertTest, Fugu) { - RunTest(CastCertError::kNone, "-6394818897508095075", + RunTest(Error::Code::kNone, "-6394818897508095075", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/fugu.pem", AprilFirst2016(), TRUST_STORE_BUILTIN, ""); @@ -226,7 +223,7 @@ TEST(VerifyCastDeviceCertTest, Fugu) { // // This is invalid because it does not chain to a trust anchor. TEST(VerifyCastDeviceCertTest, Unchained) { - RunTest(CastCertError::kErrCertsVerifyGeneric, "", + RunTest(Error::Code::kErrCertsVerifyGeneric, "", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/unchained.pem", AprilFirst2016(), TRUST_STORE_BUILTIN, ""); @@ -243,7 +240,7 @@ TEST(VerifyCastDeviceCertTest, Unchained) { // trust anchors after all) it fails the test as it is not a *device // certificate*. TEST(VerifyCastDeviceCertTest, CastRootCa) { - RunTest(CastCertError::kErrCertsRestrictions, "", + RunTest(Error::Code::kErrCertsRestrictions, "", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/cast_root_ca.pem", AprilFirst2016(), TRUST_STORE_BUILTIN, ""); @@ -260,7 +257,7 @@ TEST(VerifyCastDeviceCertTest, CastRootCa) { // This device certificate has a policy that means it is valid only for audio // devices. TEST(VerifyCastDeviceCertTest, ChromecastAudio) { - RunTest(CastCertError::kNone, "4ZZDZJ FA8FCA7EFE3C", + RunTest(Error::Code::kNone, "4ZZDZJ FA8FCA7EFE3C", CastDeviceCertPolicy::kAudioOnly, TEST_DATA_PREFIX "certificates/chromecast_audio.pem", AprilFirst2016(), TRUST_STORE_BUILTIN, ""); @@ -278,7 +275,7 @@ TEST(VerifyCastDeviceCertTest, ChromecastAudio) { // This device certificate has a policy that means it is valid only for audio // devices. TEST(VerifyCastDeviceCertTest, MtkAudioDev) { - RunTest(CastCertError::kNone, "MediaTek Audio Dev Test", + RunTest(Error::Code::kNone, "MediaTek Audio Dev Test", CastDeviceCertPolicy::kAudioOnly, TEST_DATA_PREFIX "certificates/mtk_audio_dev.pem", JanuaryFirst2015(), TRUST_STORE_BUILTIN, ""); @@ -292,7 +289,7 @@ TEST(VerifyCastDeviceCertTest, MtkAudioDev) { // Chains to trust anchor: // Cast Root CA (built-in trust store) TEST(VerifyCastDeviceCertTest, Vizio) { - RunTest(CastCertError::kNone, "9V0000VB FA8FCA784D01", + RunTest(Error::Code::kNone, "9V0000VB FA8FCA784D01", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/vizio.pem", AprilFirst2016(), TRUST_STORE_BUILTIN, ""); @@ -305,17 +302,17 @@ TEST(VerifyCastDeviceCertTest, ChromecastGen2InvalidTime) { // Control test - certificate should be valid at some time otherwise // this test is pointless. - RunTest(CastCertError::kNone, "3ZZAK6 FA8FCA3F0D35", + RunTest(Error::Code::kNone, "3ZZAK6 FA8FCA3F0D35", CastDeviceCertPolicy::kUnrestricted, kCertsFile, AprilFirst2016(), TRUST_STORE_BUILTIN, ""); // Use a time before notBefore. - RunTest(CastCertError::kErrCertsDateInvalid, "", + RunTest(Error::Code::kErrCertsDateInvalid, "", CastDeviceCertPolicy::kUnrestricted, kCertsFile, JanuaryFirst2015(), TRUST_STORE_BUILTIN, ""); // Use a time after notAfter. - RunTest(CastCertError::kErrCertsDateInvalid, "", + RunTest(Error::Code::kErrCertsDateInvalid, "", CastDeviceCertPolicy::kUnrestricted, kCertsFile, MarchFirst2037(), TRUST_STORE_BUILTIN, ""); } @@ -332,7 +329,7 @@ TEST(VerifyCastDeviceCertTest, ChromecastGen2InvalidTime) { // This device certificate has a policy that means it is valid only for audio // devices. TEST(VerifyCastDeviceCertTest, AudioRefDevTestChain3) { - RunTest(CastCertError::kNone, "Audio Reference Dev Test", + RunTest(Error::Code::kNone, "Audio Reference Dev Test", CastDeviceCertPolicy::kAudioOnly, TEST_DATA_PREFIX "certificates/audio_ref_dev_test_chain_3.pem", AprilFirst2016(), TRUST_STORE_BUILTIN, @@ -359,7 +356,7 @@ TEST(VerifyCastDeviceCertTest, AudioRefDevTestChain3) { // This device certificate has a policy that means it is valid only for audio // devices. TEST(VerifyCastDeviceCertTest, IntermediateSerialNumberTooLong) { - RunTest(CastCertError::kNone, "8C579B806FFC8A9DFFFF F8:8F:CA:6B:E6:DA", + RunTest(Error::Code::kNone, "8C579B806FFC8A9DFFFF F8:8F:CA:6B:E6:DA", CastDeviceCertPolicy::AUDIO_ONLY, "certificates/intermediate_serialnumber_toolong.pem", AprilFirst2016(), TRUST_STORE_BUILTIN, ""); @@ -378,8 +375,7 @@ TEST(VerifyCastDeviceCertTest, IntermediateSerialNumberTooLong) { TEST(VerifyCastDeviceCertTest, ExpiredTrustAnchor) { // The root certificate is only valid in 2015, so validating with a time in // 2016 means it is expired. - RunTest(CastCertError::kNone, "CastDevice", - CastDeviceCertPolicy::kUnrestricted, + RunTest(Error::Code::kNone, "CastDevice", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/expired_root.pem", AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); } @@ -399,7 +395,7 @@ TEST(VerifyCastDeviceCertTest, ExpiredTrustAnchor) { // Root (provided by test data; has pathlen=1 constraint) TEST(VerifyCastDeviceCertTest, ViolatesPathlenTrustAnchorConstraint) { // Test that the chain verification fails due to the pathlen constraint. - RunTest(CastCertError::kErrCertsPathlen, "Target", + RunTest(Error::Code::kErrCertsPathlen, "Target", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/violates_root_pathlen_constraint.pem", AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); @@ -411,7 +407,7 @@ TEST(VerifyCastDeviceCertTest, ViolatesPathlenTrustAnchorConstraint) { // Intermediate: policies={anyPolicy} // Leaf: policies={anyPolicy} TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafAnypolicy) { - RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, + RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/policies_ica_anypolicy_leaf_anypolicy.pem", AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); @@ -423,7 +419,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafAnypolicy) { // Intermediate: policies={anyPolicy} // Leaf: policies={audioOnly} TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafAudioonly) { - RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, + RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, TEST_DATA_PREFIX "certificates/policies_ica_anypolicy_leaf_audioonly.pem", AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); @@ -435,7 +431,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafAudioonly) { // Intermediate: policies={anyPolicy} // Leaf: policies={foo} TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafFoo) { - RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, + RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/policies_ica_anypolicy_leaf_foo.pem", AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); } @@ -446,7 +442,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafFoo) { // Intermediate: policies={anyPolicy} // Leaf: policies={} TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafNone) { - RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, + RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/policies_ica_anypolicy_leaf_none.pem", AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); } @@ -457,7 +453,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafNone) { // Intermediate: policies={audioOnly} // Leaf: policies={anyPolicy} TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafAnypolicy) { - RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, + RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, TEST_DATA_PREFIX "certificates/policies_ica_audioonly_leaf_anypolicy.pem", AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); @@ -469,7 +465,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafAnypolicy) { // Intermediate: policies={audioOnly} // Leaf: policies={audioOnly} TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafAudioonly) { - RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, + RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, TEST_DATA_PREFIX "certificates/policies_ica_audioonly_leaf_audioonly.pem", AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); @@ -481,7 +477,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafAudioonly) { // Intermediate: policies={audioOnly} // Leaf: policies={foo} TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafFoo) { - RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, + RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, TEST_DATA_PREFIX "certificates/policies_ica_audioonly_leaf_foo.pem", AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); } @@ -492,7 +488,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafFoo) { // Intermediate: policies={audioOnly} // Leaf: policies={} TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafNone) { - RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, + RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, TEST_DATA_PREFIX "certificates/policies_ica_audioonly_leaf_none.pem", AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); } @@ -503,7 +499,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafNone) { // Intermediate: policies={} // Leaf: policies={anyPolicy} TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafAnypolicy) { - RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, + RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/policies_ica_none_leaf_anypolicy.pem", AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); } @@ -514,7 +510,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafAnypolicy) { // Intermediate: policies={} // Leaf: policies={audioOnly} TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafAudioonly) { - RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, + RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, TEST_DATA_PREFIX "certificates/policies_ica_none_leaf_audioonly.pem", AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); } @@ -525,7 +521,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafAudioonly) { // Intermediate: policies={} // Leaf: policies={foo} TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafFoo) { - RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, + RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/policies_ica_none_leaf_foo.pem", AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); } @@ -536,7 +532,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafFoo) { // Intermediate: policies={} // Leaf: policies={} TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafNone) { - RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, + RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/policies_ica_none_leaf_none.pem", AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); } @@ -545,7 +541,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafNone) { // 1024-bit RSA key. Verification should fail since the target's key is // too weak. TEST(VerifyCastDeviceCertTest, DeviceCertHas1024BitRsaKey) { - RunTest(CastCertError::kErrCertsVerifyGeneric, "RSA 1024 Device Cert", + RunTest(Error::Code::kErrCertsVerifyGeneric, "RSA 1024 Device Cert", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/rsa1024_device_cert.pem", AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); @@ -555,7 +551,7 @@ TEST(VerifyCastDeviceCertTest, DeviceCertHas1024BitRsaKey) { // 2048-bit RSA key, and then verifying signed data (both SHA1 and SHA256) // for it. TEST(VerifyCastDeviceCertTest, DeviceCertHas2048BitRsaKey) { - RunTest(CastCertError::kNone, "RSA 2048 Device Cert", + RunTest(Error::Code::kNone, "RSA 2048 Device Cert", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/rsa2048_device_cert.pem", AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, @@ -566,7 +562,7 @@ TEST(VerifyCastDeviceCertTest, DeviceCertHas2048BitRsaKey) { // nameConstraints extension but the leaf certificate is still permitted under // these constraints. TEST(VerifyCastDeviceCertTest, NameConstraintsObeyed) { - RunTest(CastCertError::kNone, "Device", CastDeviceCertPolicy::kUnrestricted, + RunTest(Error::Code::kNone, "Device", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/nc.pem", AprilFirst2020(), TRUST_STORE_FROM_TEST_FILE, ""); } @@ -575,12 +571,12 @@ TEST(VerifyCastDeviceCertTest, NameConstraintsObeyed) { // nameConstraints extension and the leaf certificate is not permitted under // these constraints. TEST(VerifyCastDeviceCertTest, NameConstraintsViolated) { - RunTest(CastCertError::kErrCertsVerifyGeneric, "Device", + RunTest(Error::Code::kErrCertsVerifyGeneric, "Device", CastDeviceCertPolicy::kUnrestricted, TEST_DATA_PREFIX "certificates/nc_fail.pem", AprilFirst2020(), TRUST_STORE_FROM_TEST_FILE, ""); } } // namespace -} // namespace certificate } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.cc b/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.cc index 5935e1e8e75..41f05cad4fb 100644 --- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.cc +++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.cc @@ -11,13 +11,13 @@ #include "absl/strings/string_view.h" #include "cast/common/certificate/cast_cert_validator_internal.h" -#include "cast/common/certificate/proto/revocation.pb.h" #include "platform/base/macros.h" +#include "util/crypto/certificate_utils.h" #include "util/crypto/sha2.h" #include "util/logging.h" +namespace openscreen { namespace cast { -namespace certificate { namespace { enum CrlVersion { @@ -76,7 +76,7 @@ bool VerifyCRL(const Crl& crl, TrustStore* trust_store, DateTime* overall_not_after) { CertificatePathResult result_path = {}; - openscreen::Error error = + Error error = FindCertificatePath({crl.signer_cert()}, time, &result_path, trust_store); if (!error.ok()) { return false; @@ -134,39 +134,12 @@ bool VerifyCRL(const Crl& crl, return true; } -std::string GetSpkiTlv(X509* cert) { - int len = i2d_X509_PUBKEY(cert->cert_info->key, nullptr); - if (len <= 0) { - return {}; - } - std::string x(len, 0); - uint8_t* data = reinterpret_cast<uint8_t*>(&x[0]); - if (!i2d_X509_PUBKEY(cert->cert_info->key, &data)) { - return {}; - } - size_t actual_size = data - reinterpret_cast<uint8_t*>(&x[0]); - OSP_DCHECK_EQ(actual_size, x.size()); - x.resize(actual_size); - return x; -} - -bool ParseDerUint64(ASN1_INTEGER* asn1int, uint64_t* result) { - if (asn1int->length > 8 || asn1int->length == 0) { - return false; - } - *result = 0; - for (int i = 0; i < asn1int->length; ++i) { - *result = (*result << 8) | asn1int->data[i]; - } - return true; -} - } // namespace CastCRL::CastCRL(const TbsCrl& tbs_crl, const DateTime& overall_not_after) { // Parse the validity information. - // Assume ConvertTimeSeconds will succeed. Successful call to VerifyCRL - // means that these calls were successful. + // Assume DateTimeFromSeconds will succeed. Successful call to VerifyCRL means + // that these calls were successful. DateTimeFromSeconds(tbs_crl.not_before_seconds(), ¬_before_); DateTimeFromSeconds(tbs_crl.not_after_seconds(), ¬_after_); if (overall_not_after < not_after_) { @@ -210,8 +183,7 @@ bool CastCRL::CheckRevocation(const std::vector<X509*>& trusted_chain, return false; } - openscreen::ErrorOr<std::string> spki_hash = - openscreen::SHA256HashString(spki_tlv); + ErrorOr<std::string> spki_hash = SHA256HashString(spki_tlv); if (spki_hash.is_error() || (revoked_hashes_.find(spki_hash.value()) != revoked_hashes_.end())) { return false; @@ -226,10 +198,12 @@ bool CastCRL::CheckRevocation(const std::vector<X509*>& trusted_chain, // Only Google generated device certificates will be revoked by range. // These will always be less than 64 bits in length. - if (!ParseDerUint64(subordinate->cert_info->serialNumber, - &serial_number)) { + ErrorOr<uint64_t> maybe_serial = + ParseDerUint64(subordinate->cert_info->serialNumber); + if (!maybe_serial) { continue; } + serial_number = maybe_serial.value(); for (const auto& revoked_serial : issuer_iter->second) { if (revoked_serial.first_serial <= serial_number && revoked_serial.last_serial >= serial_number) { @@ -273,5 +247,5 @@ std::unique_ptr<CastCRL> ParseAndVerifyCRL(const std::string& crl_proto, return nullptr; } -} // namespace certificate } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.h b/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.h index 51b8fa5169c..420aa38e14d 100644 --- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.h +++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.h @@ -17,8 +17,15 @@ #include "cast/common/certificate/proto/revocation.pb.h" #include "platform/base/macros.h" +namespace openscreen { namespace cast { -namespace certificate { + +// TODO(crbug.com/openscreen/90): Remove these after Chromium is migrated to +// openscreen::cast +using CrlBundle = ::cast::certificate::CrlBundle; +using Crl = ::cast::certificate::Crl; +using TbsCrl = ::cast::certificate::TbsCrl; +using SerialNumberRange = ::cast::certificate::SerialNumberRange; // This class represents the certificate revocation list information parsed from // the binary in a protobuf message. @@ -81,7 +88,7 @@ std::unique_ptr<CastCRL> ParseAndVerifyCRL(const std::string& crl_proto, const DateTime& time, TrustStore* trust_store = nullptr); -} // namespace certificate } // namespace cast +} // namespace openscreen #endif // CAST_COMMON_CERTIFICATE_CAST_CRL_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl_unittest.cc b/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl_unittest.cc index 4e3a94d8d8c..81c0030f854 100644 --- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl_unittest.cc @@ -7,15 +7,21 @@ #include "cast/common/certificate/cast_cert_validator.h" #include "cast/common/certificate/cast_cert_validator_internal.h" #include "cast/common/certificate/proto/test_suite.pb.h" -#include "cast/common/certificate/test_helpers.h" +#include "cast/common/certificate/testing/test_helpers.h" #include "gtest/gtest.h" +#include "testing/util/read_file.h" #include "util/logging.h" +namespace openscreen { namespace cast { -namespace certificate { -namespace { -using CastCertError = openscreen::Error::Code; +// TODO(crbug.com/openscreen/90): Remove these after Chromium is migrated to +// openscreen::cast +using DeviceCertTestSuite = ::cast::certificate::DeviceCertTestSuite; +using VerificationResult = ::cast::certificate::VerificationResult; +using DeviceCertTest = ::cast::certificate::DeviceCertTest; + +namespace { // Indicates the expected result of test step's verification. enum TestStepResult { @@ -31,10 +37,9 @@ bool TestVerifyCertificate(TestStepResult expected_result, TrustStore* cast_trust_store) { std::unique_ptr<CertVerificationContext> context; CastDeviceCertPolicy policy; - openscreen::Error result = - VerifyDeviceCert(der_certs, time, &context, &policy, nullptr, - CRLPolicy::kCrlOptional, cast_trust_store); - bool success = (result.code() == CastCertError::kNone) == + Error result = VerifyDeviceCert(der_certs, time, &context, &policy, nullptr, + CRLPolicy::kCrlOptional, cast_trust_store); + bool success = (result.code() == Error::Code::kNone) == (expected_result == kResultSuccess); EXPECT_TRUE(success); return success; @@ -60,7 +65,7 @@ bool TestVerifyCRL(TestStepResult expected_result, // The provided CRL is verified at |crl_time|. // If |crl_required| is set, then a valid Cast CRL must be provided. // Otherwise, a missing CRL is be ignored. -bool TestVerifyRevocation(CastCertError expected_result, +bool TestVerifyRevocation(Error::Code expected_result, const std::vector<std::string>& der_certs, const std::string& crl_bundle, const DateTime& crl_time, @@ -78,9 +83,8 @@ bool TestVerifyRevocation(CastCertError expected_result, CastDeviceCertPolicy policy; CRLPolicy crl_policy = crl_required ? CRLPolicy::kCrlRequired : CRLPolicy::kCrlOptional; - openscreen::Error result = - VerifyDeviceCert(der_certs, cert_time, &context, &policy, crl.get(), - crl_policy, cast_trust_store); + Error result = VerifyDeviceCert(der_certs, cert_time, &context, &policy, + crl.get(), crl_policy, cast_trust_store); EXPECT_EQ(expected_result, result.code()); return expected_result == result.code(); } @@ -118,47 +122,47 @@ bool RunTest(const DeviceCertTest& test_case) { std::string crl_bundle = test_case.crl_bundle(); switch (test_case.expected_result()) { - case PATH_VERIFICATION_FAILED: + case ::cast::certificate::PATH_VERIFICATION_FAILED: return TestVerifyCertificate(kResultFail, der_cert_path, cert_verification_time, cast_trust_store.get()); - case CRL_VERIFICATION_FAILED: + case ::cast::certificate::CRL_VERIFICATION_FAILED: return TestVerifyCRL(kResultFail, crl_bundle, crl_verification_time, crl_trust_store.get()); - case REVOCATION_CHECK_FAILED_WITHOUT_CRL: + case ::cast::certificate::REVOCATION_CHECK_FAILED_WITHOUT_CRL: return TestVerifyCertificate(kResultSuccess, der_cert_path, cert_verification_time, cast_trust_store.get()) && TestVerifyCRL(kResultFail, crl_bundle, crl_verification_time, crl_trust_store.get()) && TestVerifyRevocation( - CastCertError::kErrCrlInvalid, der_cert_path, crl_bundle, + Error::Code::kErrCrlInvalid, der_cert_path, crl_bundle, crl_verification_time, cert_verification_time, true, cast_trust_store.get(), crl_trust_store.get()); - case CRL_EXPIRED_AFTER_INITIAL_VERIFICATION: // fallthrough - case REVOCATION_CHECK_FAILED: + case ::cast::certificate:: + CRL_EXPIRED_AFTER_INITIAL_VERIFICATION: // fallthrough + case ::cast::certificate::REVOCATION_CHECK_FAILED: return TestVerifyCertificate(kResultSuccess, der_cert_path, cert_verification_time, cast_trust_store.get()) && TestVerifyCRL(kResultSuccess, crl_bundle, crl_verification_time, crl_trust_store.get()) && TestVerifyRevocation( - CastCertError::kErrCertsRevoked, der_cert_path, crl_bundle, + Error::Code::kErrCertsRevoked, der_cert_path, crl_bundle, crl_verification_time, cert_verification_time, true, cast_trust_store.get(), crl_trust_store.get()); - case SUCCESS: + case ::cast::certificate::SUCCESS: return (crl_bundle.empty() || TestVerifyCRL(kResultSuccess, crl_bundle, crl_verification_time, crl_trust_store.get())) && TestVerifyCertificate(kResultSuccess, der_cert_path, cert_verification_time, cast_trust_store.get()) && - TestVerifyRevocation(CastCertError::kNone, der_cert_path, - crl_bundle, crl_verification_time, - cert_verification_time, !crl_bundle.empty(), - cast_trust_store.get(), + TestVerifyRevocation(Error::Code::kNone, der_cert_path, crl_bundle, + crl_verification_time, cert_verification_time, + !crl_bundle.empty(), cast_trust_store.get(), crl_trust_store.get()); - case UNSPECIFIED: + case ::cast::certificate::UNSPECIFIED: return false; } return false; @@ -169,8 +173,7 @@ bool RunTest(const DeviceCertTest& test_case) { // To see the description of the test, execute the test. // These tests are generated by a test generator in google3. void RunTestSuite(const std::string& test_suite_file_name) { - std::string testsuite_raw = - testing::ReadEntireFileToString(test_suite_file_name); + std::string testsuite_raw = ReadEntireFileToString(test_suite_file_name); ASSERT_FALSE(testsuite_raw.empty()); DeviceCertTestSuite test_suite; ASSERT_TRUE(test_suite.ParseFromString(testsuite_raw)); @@ -191,5 +194,5 @@ TEST(CastCertificateTest, TestSuite1) { } } // namespace -} // namespace certificate } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.cc b/chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.cc new file mode 100644 index 00000000000..8c9e5e24b63 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.cc @@ -0,0 +1,66 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/certificate/cast_trust_store.h" + +#include "util/logging.h" + +namespace openscreen { +namespace cast { +namespace { + +// ------------------------------------------------------------------------- +// Cast trust anchors. +// ------------------------------------------------------------------------- + +// There are two trusted roots for Cast certificate chains: +// +// (1) CN=Cast Root CA (kCastRootCaDer) +// (2) CN=Eureka Root CA (kEurekaRootCaDer) +// +// These constants are defined by the files included next: + +#include "cast/common/certificate/cast_root_ca_cert_der-inc.h" +#include "cast/common/certificate/eureka_root_ca_der-inc.h" + +} // namespace + +// static +CastTrustStore* CastTrustStore::GetInstance() { + if (!store_) { + store_ = new CastTrustStore(); + } + return store_; +} + +// static +void CastTrustStore::ResetInstance() { + delete store_; + store_ = nullptr; +} + +// static +CastTrustStore* CastTrustStore::CreateInstanceForTest( + const std::vector<uint8_t>& trust_anchor_der) { + OSP_DCHECK(!store_); + store_ = new CastTrustStore(trust_anchor_der); + return store_; +} + +CastTrustStore::CastTrustStore() { + trust_store_.certs.emplace_back(MakeTrustAnchor(kCastRootCaDer)); + trust_store_.certs.emplace_back(MakeTrustAnchor(kEurekaRootCaDer)); +} + +CastTrustStore::CastTrustStore(const std::vector<uint8_t>& trust_anchor_der) { + trust_store_.certs.emplace_back(MakeTrustAnchor(trust_anchor_der)); +} + +CastTrustStore::~CastTrustStore() = default; + +// static +CastTrustStore* CastTrustStore::store_ = nullptr; + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.h b/chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.h new file mode 100644 index 00000000000..8aac9d3905b --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.h @@ -0,0 +1,39 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_COMMON_CERTIFICATE_CAST_TRUST_STORE_H_ +#define CAST_COMMON_CERTIFICATE_CAST_TRUST_STORE_H_ + +#include <vector> + +#include "cast/common/certificate/cast_cert_validator_internal.h" + +namespace openscreen { +namespace cast { + +class CastTrustStore { + public: + static CastTrustStore* GetInstance(); + static void ResetInstance(); + + static CastTrustStore* CreateInstanceForTest( + const std::vector<uint8_t>& trust_anchor_der); + + CastTrustStore(); + CastTrustStore(const std::vector<uint8_t>& trust_anchor_der); + CastTrustStore(const CastTrustStore&) = delete; + ~CastTrustStore(); + CastTrustStore& operator=(const CastTrustStore&) = delete; + + TrustStore* trust_store() { return &trust_store_; } + + private: + static CastTrustStore* store_; + TrustStore trust_store_; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_COMMON_CERTIFICATE_CAST_TRUST_STORE_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/test_helpers.cc b/chromium/third_party/openscreen/src/cast/common/certificate/test_helpers.cc deleted file mode 100644 index 41aef74feef..00000000000 --- a/chromium/third_party/openscreen/src/cast/common/certificate/test_helpers.cc +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright 2019 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "cast/common/certificate/test_helpers.h" - -#include <openssl/pem.h> -#include <stdio.h> -#include <string.h> - -#include "util/logging.h" - -namespace cast { -namespace certificate { -namespace testing { - -std::string ReadEntireFileToString(const std::string& filename) { - FILE* file = fopen(filename.c_str(), "r"); - if (file == nullptr) { - return {}; - } - fseek(file, 0, SEEK_END); - long file_size = ftell(file); - fseek(file, 0, SEEK_SET); - std::string contents(file_size, 0); - int bytes_read = 0; - while (bytes_read < file_size) { - size_t ret = fread(&contents[bytes_read], 1, file_size - bytes_read, file); - if (ret == 0 && ferror(file)) { - return {}; - } else { - bytes_read += ret; - } - } - fclose(file); - - return contents; -} - -std::vector<std::string> ReadCertificatesFromPemFile( - const std::string& filename) { - FILE* fp = fopen(filename.c_str(), "r"); - if (!fp) { - return {}; - } - std::vector<std::string> certs; -#define STRCMP_LITERAL(s, l) strncmp(s, l, sizeof(l)) - for (;;) { - char* name; - char* header; - unsigned char* data; - long length; - if (PEM_read(fp, &name, &header, &data, &length) == 1) { - if (STRCMP_LITERAL(name, "CERTIFICATE") == 0) { - certs.emplace_back((char*)data, length); - } - OPENSSL_free(name); - OPENSSL_free(header); - OPENSSL_free(data); - } else { - break; - } - } - fclose(fp); - return certs; -} - -SignatureTestData::SignatureTestData() - : message{nullptr, 0}, sha1{nullptr, 0}, sha256{nullptr, 0} {} - -SignatureTestData::~SignatureTestData() { - OPENSSL_free(const_cast<uint8_t*>(message.data)); - OPENSSL_free(const_cast<uint8_t*>(sha1.data)); - OPENSSL_free(const_cast<uint8_t*>(sha256.data)); -} - -SignatureTestData ReadSignatureTestData(const std::string& filename) { - FILE* fp = fopen(filename.c_str(), "r"); - OSP_DCHECK(fp); - SignatureTestData result = {}; - for (;;) { - char* name; - char* header; - unsigned char* data; - long length; - if (PEM_read(fp, &name, &header, &data, &length) == 1) { - if (strcmp(name, "MESSAGE") == 0) { - OSP_DCHECK(!result.message.data); - result.message.data = data; - result.message.length = length; - } else if (strcmp(name, "SIGNATURE SHA1") == 0) { - OSP_DCHECK(!result.sha1.data); - result.sha1.data = data; - result.sha1.length = length; - } else if (strcmp(name, "SIGNATURE SHA256") == 0) { - OSP_DCHECK(!result.sha256.data); - result.sha256.data = data; - result.sha256.length = length; - } else { - OPENSSL_free(data); - } - OPENSSL_free(name); - OPENSSL_free(header); - } else { - break; - } - } - OSP_DCHECK(result.message.data); - OSP_DCHECK(result.sha1.data); - OSP_DCHECK(result.sha256.data); - - return result; -} - -std::unique_ptr<TrustStore> CreateTrustStoreFromPemFile( - const std::string& filename) { - std::unique_ptr<TrustStore> store = std::make_unique<TrustStore>(); - - std::vector<std::string> certs = - testing::ReadCertificatesFromPemFile(filename); - for (const auto& der_cert : certs) { - const uint8_t* data = (const uint8_t*)der_cert.data(); - store->certs.emplace_back(d2i_X509(nullptr, &data, der_cert.size())); - } - return store; -} - -} // namespace testing -} // namespace certificate -} // namespace cast diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/testing/test_helpers.cc b/chromium/third_party/openscreen/src/cast/common/certificate/testing/test_helpers.cc new file mode 100644 index 00000000000..eb9bac2fb92 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/certificate/testing/test_helpers.cc @@ -0,0 +1,130 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/certificate/testing/test_helpers.h" + +#include <openssl/bytestring.h> +#include <openssl/pem.h> +#include <openssl/rsa.h> +#include <stdio.h> +#include <string.h> + +#include "absl/strings/match.h" +#include "util/logging.h" + +namespace openscreen { +namespace cast { +namespace testing { + +std::vector<std::string> ReadCertificatesFromPemFile( + absl::string_view filename) { + FILE* fp = fopen(filename.data(), "r"); + if (!fp) { + return {}; + } + std::vector<std::string> certs; + char* name; + char* header; + unsigned char* data; + long length; + while (PEM_read(fp, &name, &header, &data, &length) == 1) { + if (absl::StartsWith(name, "CERTIFICATE")) { + certs.emplace_back((char*)data, length); + } + OPENSSL_free(name); + OPENSSL_free(header); + OPENSSL_free(data); + } + fclose(fp); + return certs; +} + +bssl::UniquePtr<EVP_PKEY> ReadKeyFromPemFile(absl::string_view filename) { + FILE* fp = fopen(filename.data(), "r"); + if (!fp) { + return nullptr; + } + bssl::UniquePtr<EVP_PKEY> pkey; + char* name; + char* header; + unsigned char* data; + long length; + while (PEM_read(fp, &name, &header, &data, &length) == 1) { + if (absl::StartsWith(name, "RSA PRIVATE KEY")) { + OSP_DCHECK(!pkey); + CBS cbs; + CBS_init(&cbs, data, length); + RSA* rsa = RSA_parse_private_key(&cbs); + if (rsa) { + pkey.reset(EVP_PKEY_new()); + EVP_PKEY_assign_RSA(pkey.get(), rsa); + } + } + OPENSSL_free(name); + OPENSSL_free(header); + OPENSSL_free(data); + } + fclose(fp); + return pkey; +} + +SignatureTestData::SignatureTestData() + : message{nullptr, 0}, sha1{nullptr, 0}, sha256{nullptr, 0} {} + +SignatureTestData::~SignatureTestData() { + OPENSSL_free(const_cast<uint8_t*>(message.data)); + OPENSSL_free(const_cast<uint8_t*>(sha1.data)); + OPENSSL_free(const_cast<uint8_t*>(sha256.data)); +} + +SignatureTestData ReadSignatureTestData(absl::string_view filename) { + FILE* fp = fopen(filename.data(), "r"); + OSP_DCHECK(fp); + SignatureTestData result = {}; + char* name; + char* header; + unsigned char* data; + long length; + while (PEM_read(fp, &name, &header, &data, &length) == 1) { + if (strcmp(name, "MESSAGE") == 0) { + OSP_DCHECK(!result.message.data); + result.message.data = data; + result.message.length = length; + } else if (strcmp(name, "SIGNATURE SHA1") == 0) { + OSP_DCHECK(!result.sha1.data); + result.sha1.data = data; + result.sha1.length = length; + } else if (strcmp(name, "SIGNATURE SHA256") == 0) { + OSP_DCHECK(!result.sha256.data); + result.sha256.data = data; + result.sha256.length = length; + } else { + OPENSSL_free(data); + } + OPENSSL_free(name); + OPENSSL_free(header); + } + OSP_DCHECK(result.message.data); + OSP_DCHECK(result.sha1.data); + OSP_DCHECK(result.sha256.data); + + return result; +} + +std::unique_ptr<TrustStore> CreateTrustStoreFromPemFile( + absl::string_view filename) { + std::unique_ptr<TrustStore> store = std::make_unique<TrustStore>(); + + std::vector<std::string> certs = + testing::ReadCertificatesFromPemFile(filename); + for (const auto& der_cert : certs) { + const uint8_t* data = (const uint8_t*)der_cert.data(); + store->certs.emplace_back(d2i_X509(nullptr, &data, der_cert.size())); + } + return store; +} + +} // namespace testing +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/test_helpers.h b/chromium/third_party/openscreen/src/cast/common/certificate/testing/test_helpers.h index ac3a136d65f..c1ff9a25f78 100644 --- a/chromium/third_party/openscreen/src/cast/common/certificate/test_helpers.h +++ b/chromium/third_party/openscreen/src/cast/common/certificate/testing/test_helpers.h @@ -2,22 +2,25 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef CAST_COMMON_CERTIFICATE_TEST_HELPERS_H_ -#define CAST_COMMON_CERTIFICATE_TEST_HELPERS_H_ +#ifndef CAST_COMMON_CERTIFICATE_TESTING_TEST_HELPERS_H_ +#define CAST_COMMON_CERTIFICATE_TESTING_TEST_HELPERS_H_ + +#include <openssl/evp.h> #include <string> #include <vector> +#include "absl/strings/string_view.h" #include "cast/common/certificate/cast_cert_validator_internal.h" #include "cast/common/certificate/types.h" +namespace openscreen { namespace cast { -namespace certificate { namespace testing { -std::string ReadEntireFileToString(const std::string& filename); std::vector<std::string> ReadCertificatesFromPemFile( - const std::string& filename); + absl::string_view filename); +bssl::UniquePtr<EVP_PKEY> ReadKeyFromPemFile(absl::string_view filename); class SignatureTestData { public: @@ -29,13 +32,13 @@ class SignatureTestData { ConstDataSpan sha256; }; -SignatureTestData ReadSignatureTestData(const std::string& filename); +SignatureTestData ReadSignatureTestData(absl::string_view filename); std::unique_ptr<TrustStore> CreateTrustStoreFromPemFile( - const std::string& filename); + absl::string_view filename); } // namespace testing -} // namespace certificate } // namespace cast +} // namespace openscreen -#endif // CAST_COMMON_CERTIFICATE_TEST_HELPERS_H_ +#endif // CAST_COMMON_CERTIFICATE_TESTING_TEST_HELPERS_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/types.cc b/chromium/third_party/openscreen/src/cast/common/certificate/types.cc index 297fbffca43..507a033ec1c 100644 --- a/chromium/third_party/openscreen/src/cast/common/certificate/types.cc +++ b/chromium/third_party/openscreen/src/cast/common/certificate/types.cc @@ -6,8 +6,8 @@ #include "util/logging.h" +namespace openscreen { namespace cast { -namespace certificate { bool operator<(const DateTime& a, const DateTime& b) { if (a.year < b.year) { @@ -66,5 +66,22 @@ bool DateTimeFromSeconds(uint64_t seconds, DateTime* time) { return true; } -} // namespace certificate +static_assert(sizeof(time_t) >= 4, "Can't avoid overflow with < 32-bits"); + +std::chrono::seconds DateTimeToSeconds(const DateTime& time) { + OSP_DCHECK_GE(time.month, 1); + OSP_DCHECK_GE(time.year, 1900); + // NOTE: Guard against overflow if time_t is 32-bit. + OSP_DCHECK(sizeof(time_t) >= 8 || time.year < 2038) << time.year; + struct tm tm = {}; + tm.tm_sec = time.second; + tm.tm_min = time.minute; + tm.tm_hour = time.hour; + tm.tm_mday = time.day; + tm.tm_mon = time.month - 1; + tm.tm_year = time.year - 1900; + return std::chrono::seconds(mktime(&tm)); +} + } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/types.h b/chromium/third_party/openscreen/src/cast/common/certificate/types.h index 62fe45ed5d4..c3cf7efe273 100644 --- a/chromium/third_party/openscreen/src/cast/common/certificate/types.h +++ b/chromium/third_party/openscreen/src/cast/common/certificate/types.h @@ -7,8 +7,10 @@ #include <stdint.h> +#include <chrono> + +namespace openscreen { namespace cast { -namespace certificate { struct ConstDataSpan { const uint8_t* data; @@ -28,7 +30,10 @@ bool operator<(const DateTime& a, const DateTime& b); bool operator>(const DateTime& a, const DateTime& b); bool DateTimeFromSeconds(uint64_t seconds, DateTime* time); -} // namespace certificate +// |time| is assumed to be valid. +std::chrono::seconds DateTimeToSeconds(const DateTime& time); + } // namespace cast +} // namespace openscreen #endif // CAST_COMMON_CERTIFICATE_TYPES_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/channel/BUILD.gn b/chromium/third_party/openscreen/src/cast/common/channel/BUILD.gn deleted file mode 100644 index 4a086190a90..00000000000 --- a/chromium/third_party/openscreen/src/cast/common/channel/BUILD.gn +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2019 The Chromium Authors. All rights reserved. -# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. - -source_set("channel") { - sources = [ - "cast_message_handler.h", - "cast_socket.cc", - "cast_socket.h", - "message_framer.cc", - "message_framer.h", - "message_util.h", - "virtual_connection.h", - "virtual_connection_manager.cc", - "virtual_connection_manager.h", - ] - - public_deps = [ - "../../../platform", - "../../../third_party/abseil", - ] - - deps = [ - "../../../util", - "proto", - ] -} - -source_set("test") { - testonly = true - sources = [ - "test/fake_cast_socket.h", - "test/mock_cast_message_handler.h", - ] - - public_deps = [ - ":channel", - "../../../platform:test", - "../../../third_party/googletest:gmock", - ] -} - -source_set("unittests") { - testonly = true - sources = [ - "cast_socket_unittest.cc", - "message_framer_unittest.cc", - "virtual_connection_manager_unittest.cc", - ] - - deps = [ - ":channel", - ":test", - "../../../platform", - "../../../platform:test", - "../../../third_party/googletest:gmock", - "../../../third_party/googletest:gtest", - "../../../util", - "../../common/certificate/proto:unittest_proto", - "proto", - ] -} diff --git a/chromium/third_party/openscreen/src/cast/common/channel/cast_message_handler.h b/chromium/third_party/openscreen/src/cast/common/channel/cast_message_handler.h index 9754b0c80d8..cd0d13e690d 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/cast_message_handler.h +++ b/chromium/third_party/openscreen/src/cast/common/channel/cast_message_handler.h @@ -5,11 +5,12 @@ #ifndef CAST_COMMON_CHANNEL_CAST_MESSAGE_HANDLER_H_ #define CAST_COMMON_CHANNEL_CAST_MESSAGE_HANDLER_H_ +#include "cast/common/channel/proto/cast_channel.pb.h" + +namespace openscreen { namespace cast { -namespace channel { class CastSocket; -class CastMessage; class VirtualConnectionRouter; class CastMessageHandler { @@ -18,10 +19,10 @@ class CastMessageHandler { virtual void OnMessage(VirtualConnectionRouter* router, CastSocket* socket, - CastMessage&& message) = 0; + ::cast::channel::CastMessage message) = 0; }; -} // namespace channel } // namespace cast +} // namespace openscreen #endif // CAST_COMMON_CHANNEL_CAST_MESSAGE_HANDLER_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/channel/cast_socket.cc b/chromium/third_party/openscreen/src/cast/common/channel/cast_socket.cc index 5b1c9408be9..ba7996584ac 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/cast_socket.cc +++ b/chromium/third_party/openscreen/src/cast/common/channel/cast_socket.cc @@ -4,29 +4,20 @@ #include "cast/common/channel/cast_socket.h" -#include <atomic> - #include "cast/common/channel/message_framer.h" #include "util/logging.h" +namespace openscreen { namespace cast { -namespace channel { +using ::cast::channel::CastMessage; using message_serialization::DeserializeResult; -using openscreen::ErrorOr; -using openscreen::platform::TlsConnection; - -uint32_t GetNextSocketId() { - static std::atomic<uint32_t> id(1); - return id++; -} CastSocket::CastSocket(std::unique_ptr<TlsConnection> connection, - Client* client, - uint32_t socket_id) - : client_(client), - connection_(std::move(connection)), - socket_id_(socket_id) { + Client* client) + : connection_(std::move(connection)), + client_(client), + socket_id_(g_next_socket_id_++) { OSP_DCHECK(client); connection_->SetClient(this); } @@ -46,12 +37,9 @@ Error CastSocket::SendMessage(const CastMessage& message) { return out.error(); } - if (state_ == State::kBlocked) { - message_queue_.emplace_back(std::move(out.value())); - return Error::Code::kNone; + if (!connection_->Send(out.value().data(), out.value().size())) { + return Error::Code::kAgain; } - - connection_->Write(out.value().data(), out.value().size()); return Error::Code::kNone; } @@ -60,27 +48,20 @@ void CastSocket::SetClient(Client* client) { client_ = client; } -void CastSocket::OnWriteBlocked(TlsConnection* connection) { - if (state_ == State::kOpen) { - state_ = State::kBlocked; +std::array<uint8_t, 2> CastSocket::GetSanitizedIpAddress() { + IPEndpoint remote = connection_->GetRemoteEndpoint(); + std::array<uint8_t, 2> result; + uint8_t bytes[16]; + if (remote.address.IsV4()) { + remote.address.CopyToV4(bytes); + result[0] = bytes[2]; + result[1] = bytes[3]; + } else { + remote.address.CopyToV6(bytes); + result[0] = bytes[14]; + result[1] = bytes[15]; } -} - -void CastSocket::OnWriteUnblocked(TlsConnection* connection) { - if (state_ != State::kBlocked) { - return; - } - state_ = State::kOpen; - - // Attempt to write all messages that have been queued-up while the socket was - // blocked. Stop if the socket becomes blocked again, or an error occurs. - auto it = message_queue_.begin(); - for (const auto end = message_queue_.end(); - it != end && state_ == State::kOpen; ++it) { - // The following Write() could transition |state_| to kBlocked or kError. - connection_->Write(it->data(), it->size()); - } - message_queue_.erase(message_queue_.begin(), it); + return result; } void CastSocket::OnError(TlsConnection* connection, Error error) { @@ -101,5 +82,7 @@ void CastSocket::OnRead(TlsConnection* connection, std::vector<uint8_t> block) { client_->OnMessage(this, std::move(message_or_error.value().message)); } -} // namespace channel +int CastSocket::g_next_socket_id_ = 1; + } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/channel/cast_socket.h b/chromium/third_party/openscreen/src/cast/common/channel/cast_socket.h index 6bd099b2f85..550aa76ed21 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/cast_socket.h +++ b/chromium/third_party/openscreen/src/cast/common/channel/cast_socket.h @@ -5,19 +5,15 @@ #ifndef CAST_COMMON_CHANNEL_CAST_SOCKET_H_ #define CAST_COMMON_CHANNEL_CAST_SOCKET_H_ +#include <array> +#include <memory> #include <vector> +#include "cast/common/channel/proto/cast_channel.pb.h" #include "platform/api/tls_connection.h" +namespace openscreen { namespace cast { -namespace channel { - -using openscreen::Error; -using TlsConnection = openscreen::platform::TlsConnection; - -class CastMessage; - -uint32_t GetNextSocketId(); // Represents a simple message-oriented socket for communicating with the Cast // V2 protocol. It isn't thread-safe, so it should only be used on the same @@ -31,46 +27,49 @@ class CastSocket : public TlsConnection::Client { // Called when a terminal error on |socket| has occurred. virtual void OnError(CastSocket* socket, Error error) = 0; - virtual void OnMessage(CastSocket* socket, CastMessage message) = 0; + virtual void OnMessage(CastSocket* socket, + ::cast::channel::CastMessage message) = 0; }; - CastSocket(std::unique_ptr<TlsConnection> connection, - Client* client, - uint32_t socket_id); + CastSocket(std::unique_ptr<TlsConnection> connection, Client* client); ~CastSocket(); // Sends |message| immediately unless the underlying TLS connection is // write-blocked, in which case |message| will be queued. An error will be // returned if |message| cannot be serialized for any reason, even while // write-blocked. - Error SendMessage(const CastMessage& message); + [[nodiscard]] Error SendMessage(const ::cast::channel::CastMessage& message); void SetClient(Client* client); - uint32_t socket_id() const { return socket_id_; } + std::array<uint8_t, 2> GetSanitizedIpAddress(); + + int socket_id() const { return socket_id_; } + + void set_audio_only(bool audio_only) { audio_only_ = audio_only; } + bool audio_only() const { return audio_only_; } // TlsConnection::Client overrides. - void OnWriteBlocked(TlsConnection* connection) override; - void OnWriteUnblocked(TlsConnection* connection) override; void OnError(TlsConnection* connection, Error error) override; void OnRead(TlsConnection* connection, std::vector<uint8_t> block) override; private: - enum class State { - kOpen, - kBlocked, - kError, + enum class State : bool { + kOpen = true, + kError = false, }; - Client* client_; // May never be null. + static int g_next_socket_id_; + const std::unique_ptr<TlsConnection> connection_; + Client* client_; // May never be null. + const int socket_id_; + bool audio_only_ = false; std::vector<uint8_t> read_buffer_; - const uint32_t socket_id_; State state_ = State::kOpen; - std::vector<std::vector<uint8_t>> message_queue_; }; -} // namespace channel } // namespace cast +} // namespace openscreen #endif // CAST_COMMON_CHANNEL_CAST_SOCKET_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/channel/cast_socket_unittest.cc b/chromium/third_party/openscreen/src/cast/common/channel/cast_socket_unittest.cc index 35b157422e9..44776384305 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/cast_socket_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/common/channel/cast_socket_unittest.cc @@ -6,18 +6,20 @@ #include "cast/common/channel/message_framer.h" #include "cast/common/channel/proto/cast_channel.pb.h" -#include "cast/common/channel/test/fake_cast_socket.h" +#include "cast/common/channel/testing/fake_cast_socket.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +namespace openscreen { namespace cast { -namespace channel { + +using ::cast::channel::CastMessage; + namespace { using ::testing::_; using ::testing::Invoke; - -using openscreen::ErrorOr; +using ::testing::Return; class CastSocketTest : public ::testing::Test { public: @@ -47,16 +49,36 @@ class CastSocketTest : public ::testing::Test { } // namespace TEST_F(CastSocketTest, SendMessage) { - EXPECT_CALL(connection(), Write(_, _)) + EXPECT_CALL(connection(), Send(_, _)) .WillOnce(Invoke([this](const void* data, size_t len) { EXPECT_EQ( frame_serial_, std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data), reinterpret_cast<const uint8_t*>(data) + len)); + return true; })); ASSERT_TRUE(socket().SendMessage(message_).ok()); } +TEST_F(CastSocketTest, SendMessageEventuallyBlocks) { + EXPECT_CALL(connection(), Send(_, _)) + .Times(3) + .WillRepeatedly(Invoke([this](const void* data, size_t len) { + EXPECT_EQ( + frame_serial_, + std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data), + reinterpret_cast<const uint8_t*>(data) + len)); + return true; + })) + .RetiresOnSaturation(); + ASSERT_TRUE(socket().SendMessage(message_).ok()); + ASSERT_TRUE(socket().SendMessage(message_).ok()); + ASSERT_TRUE(socket().SendMessage(message_).ok()); + + EXPECT_CALL(connection(), Send(_, _)).WillOnce(Return(false)); + ASSERT_EQ(socket().SendMessage(message_).code(), Error::Code::kAgain); +} + TEST_F(CastSocketTest, ReadCompleteMessage) { const uint8_t* data = frame_serial_.data(); EXPECT_CALL(mock_client(), OnMessage(_, _)) @@ -99,43 +121,19 @@ TEST_F(CastSocketTest, ReadChunkedMessage) { data + double_message.size())); } -TEST_F(CastSocketTest, SendMessageWhileBlocked) { - connection().OnWriteBlocked(); - EXPECT_CALL(connection(), Write(_, _)).Times(0); - ASSERT_TRUE(socket().SendMessage(message_).ok()); - - EXPECT_CALL(connection(), Write(_, _)) - .WillOnce(Invoke([this](const void* data, size_t len) { - EXPECT_EQ( - frame_serial_, - std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data), - reinterpret_cast<const uint8_t*>(data) + len)); - })); - connection().OnWriteUnblocked(); - - EXPECT_CALL(connection(), Write(_, _)).Times(0); - connection().OnWriteBlocked(); - connection().OnWriteUnblocked(); -} - -TEST_F(CastSocketTest, ErrorWhileEmptyingQueue) { - connection().OnWriteBlocked(); - EXPECT_CALL(connection(), Write(_, _)).Times(0); - ASSERT_TRUE(socket().SendMessage(message_).ok()); - - EXPECT_CALL(connection(), Write(_, _)) - .WillOnce(Invoke([this](const void* data, size_t len) { - EXPECT_EQ( - frame_serial_, - std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data), - reinterpret_cast<const uint8_t*>(data) + len)); - connection().OnError(Error::Code::kUnknownError); - })); - connection().OnWriteUnblocked(); - - EXPECT_CALL(connection(), Write(_, _)).Times(0); - ASSERT_FALSE(socket().SendMessage(message_).ok()); +TEST_F(CastSocketTest, SanitizedAddress) { + std::array<uint8_t, 2> result1 = socket().GetSanitizedIpAddress(); + EXPECT_EQ(result1[0], 1u); + EXPECT_EQ(result1[1], 9u); + + FakeCastSocket v6_socket(IPEndpoint{{1, 2, 3, 4}, 1025}, + IPEndpoint{{0x1819, 0x1a1b, 0x1c1d, 0x1e1f, 0x207b, + 0x7c7d, 0x7e7f, 0x8081}, + 4321}); + std::array<uint8_t, 2> result2 = v6_socket.socket.GetSanitizedIpAddress(); + EXPECT_EQ(result2[0], 128); + EXPECT_EQ(result2[1], 129); } -} // namespace channel } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler.cc b/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler.cc new file mode 100644 index 00000000000..40b2b84038b --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler.cc @@ -0,0 +1,259 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/channel/connection_namespace_handler.h" + +#include <type_traits> + +#include "absl/types/optional.h" +#include "cast/common/channel/cast_socket.h" +#include "cast/common/channel/message_util.h" +#include "cast/common/channel/proto/cast_channel.pb.h" +#include "cast/common/channel/virtual_connection.h" +#include "cast/common/channel/virtual_connection_manager.h" +#include "cast/common/channel/virtual_connection_router.h" +#include "util/json/json_serialization.h" +#include "util/json/json_value.h" +#include "util/logging.h" + +namespace openscreen { +namespace cast { + +using ::cast::channel::CastMessage; +using ::cast::channel::CastMessage_PayloadType; + +namespace { + +bool IsValidProtocolVersion(int version) { + return ::cast::channel::CastMessage_ProtocolVersion_IsValid(version); +} + +absl::optional<int> FindMaxProtocolVersion(const Json::Value* version, + const Json::Value* version_list) { + using ArrayIndex = Json::Value::ArrayIndex; + static_assert(std::is_integral<ArrayIndex>::value, + "Assuming ArrayIndex is integral"); + absl::optional<int> max_version; + if (version_list && version_list->isArray()) { + max_version = ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0; + for (auto it = version_list->begin(), end = version_list->end(); it != end; + ++it) { + if (it->isInt()) { + int version_int = it->asInt(); + if (IsValidProtocolVersion(version_int) && version_int > *max_version) { + max_version = version_int; + } + } + } + } + if (version && version->isInt()) { + int version_int = version->asInt(); + if (IsValidProtocolVersion(version_int)) { + if (!max_version) { + max_version = ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0; + } + if (version_int > max_version) { + max_version = version_int; + } + } + } + return max_version; +} + +VirtualConnection::CloseReason GetCloseReason( + const Json::Value& parsed_message) { + VirtualConnection::CloseReason reason = + VirtualConnection::CloseReason::kClosedByPeer; + absl::optional<int> reason_code = MaybeGetInt( + parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyReasonCode)); + if (reason_code) { + int code = reason_code.value(); + if (code >= VirtualConnection::CloseReason::kFirstReason && + code <= VirtualConnection::CloseReason::kLastReason) { + reason = static_cast<VirtualConnection::CloseReason>(code); + } + } + return reason; +} + +} // namespace + +ConnectionNamespaceHandler::ConnectionNamespaceHandler( + VirtualConnectionManager* vc_manager, + VirtualConnectionPolicy* vc_policy) + : vc_manager_(vc_manager), vc_policy_(vc_policy) { + OSP_DCHECK(vc_manager); + OSP_DCHECK(vc_policy); +} + +ConnectionNamespaceHandler::~ConnectionNamespaceHandler() = default; + +void ConnectionNamespaceHandler::OnMessage(VirtualConnectionRouter* router, + CastSocket* socket, + CastMessage message) { + if (message.payload_type() != + CastMessage_PayloadType::CastMessage_PayloadType_STRING) { + return; + } + ErrorOr<Json::Value> result = json::Parse(message.payload_utf8()); + if (result.is_error()) { + return; + } + + Json::Value& value = result.value(); + if (!value.isObject()) { + return; + } + + absl::optional<absl::string_view> type = + MaybeGetString(value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType)); + if (!type) { + // TODO(btolsch): Some of these paths should have error reporting. One + // possibility is to pass errors back through |router| so higher-level code + // can decide whether to show an error to the user, stop talking to a + // particular device, etc. + return; + } + + absl::string_view type_str = type.value(); + if (type_str == kMessageTypeConnect) { + HandleConnect(router, socket, std::move(message), std::move(value)); + } else if (type_str == kMessageTypeClose) { + HandleClose(router, socket, std::move(message), std::move(value)); + } else { + // NOTE: Unknown message type so ignore it. + // TODO(btolsch): Should be included in future error reporting. + } +} + +void ConnectionNamespaceHandler::HandleConnect(VirtualConnectionRouter* router, + CastSocket* socket, + CastMessage message, + Json::Value parsed_message) { + if (message.destination_id() == kBroadcastId || + message.source_id() == kBroadcastId) { + return; + } + + VirtualConnection virtual_conn{std::move(message.destination_id()), + std::move(message.source_id()), + socket->socket_id()}; + if (!vc_policy_->IsConnectionAllowed(virtual_conn)) { + SendClose(router, std::move(virtual_conn)); + return; + } + + absl::optional<int> maybe_conn_type = MaybeGetInt( + parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyConnType)); + VirtualConnection::Type conn_type = VirtualConnection::Type::kStrong; + if (maybe_conn_type) { + int int_type = maybe_conn_type.value(); + if (int_type < static_cast<int>(VirtualConnection::Type::kMinValue) || + int_type > static_cast<int>(VirtualConnection::Type::kMaxValue)) { + SendClose(router, std::move(virtual_conn)); + return; + } + conn_type = static_cast<VirtualConnection::Type>(int_type); + } + + VirtualConnection::AssociatedData data; + + data.type = conn_type; + + absl::optional<absl::string_view> user_agent = MaybeGetString( + parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyUserAgent)); + if (user_agent) { + data.user_agent = std::string(user_agent.value()); + } + + const Json::Value* sender_info_value = parsed_message.find( + JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeySenderInfo)); + if (!sender_info_value || !sender_info_value->isObject()) { + // TODO(btolsch): Should this be guessed from user agent? + OSP_DVLOG << "No sender info from protocol."; + } + + const Json::Value* version_value = parsed_message.find( + JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyProtocolVersion)); + const Json::Value* version_list_value = parsed_message.find( + JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyProtocolVersionList)); + absl::optional<int> negotiated_version = + FindMaxProtocolVersion(version_value, version_list_value); + if (negotiated_version) { + data.max_protocol_version = static_cast<VirtualConnection::ProtocolVersion>( + negotiated_version.value()); + } else { + data.max_protocol_version = VirtualConnection::ProtocolVersion::kV2_1_0; + } + + data.ip_fragment = socket->GetSanitizedIpAddress(); + + OSP_DVLOG << "Connection opened: " << virtual_conn.local_id << ", " + << virtual_conn.peer_id << ", " << virtual_conn.socket_id; + + // NOTE: Only send a response for senders that actually sent a version. This + // maintains compatibility with older senders that don't send a version and + // don't expect a response. + if (negotiated_version) { + SendConnectedResponse(router, virtual_conn, negotiated_version.value()); + } + + vc_manager_->AddConnection(std::move(virtual_conn), std::move(data)); +} + +void ConnectionNamespaceHandler::HandleClose(VirtualConnectionRouter* router, + CastSocket* socket, + CastMessage message, + Json::Value parsed_message) { + VirtualConnection virtual_conn{std::move(message.destination_id()), + std::move(message.source_id()), + socket->socket_id()}; + if (!vc_manager_->GetConnectionData(virtual_conn)) { + return; + } + + VirtualConnection::CloseReason reason = GetCloseReason(parsed_message); + + OSP_DVLOG << "Connection closed (reason: " << reason + << "): " << virtual_conn.local_id << ", " << virtual_conn.peer_id + << ", " << virtual_conn.socket_id; + vc_manager_->RemoveConnection(virtual_conn, reason); +} + +void ConnectionNamespaceHandler::SendClose(VirtualConnectionRouter* router, + VirtualConnection virtual_conn) { + Json::Value close_message(Json::ValueType::objectValue); + close_message[kMessageKeyType] = kMessageTypeClose; + + ErrorOr<std::string> result = json::Stringify(close_message); + if (result.is_error()) { + return; + } + + router->SendMessage( + std::move(virtual_conn), + MakeSimpleUTF8Message(kConnectionNamespace, std::move(result.value()))); +} + +void ConnectionNamespaceHandler::SendConnectedResponse( + VirtualConnectionRouter* router, + const VirtualConnection& virtual_conn, + int max_protocol_version) { + Json::Value connected_message(Json::ValueType::objectValue); + connected_message[kMessageKeyType] = kMessageTypeConnected; + connected_message[kMessageKeyProtocolVersion] = + static_cast<int>(max_protocol_version); + + ErrorOr<std::string> result = json::Stringify(connected_message); + if (result.is_error()) { + return; + } + + router->SendMessage( + virtual_conn, + MakeSimpleUTF8Message(kConnectionNamespace, std::move(result.value()))); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler.h b/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler.h new file mode 100644 index 00000000000..5307e896893 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler.h @@ -0,0 +1,64 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_COMMON_CHANNEL_CONNECTION_NAMESPACE_HANDLER_H_ +#define CAST_COMMON_CHANNEL_CONNECTION_NAMESPACE_HANDLER_H_ + +#include "cast/common/channel/cast_message_handler.h" +#include "cast/common/channel/proto/cast_channel.pb.h" +#include "util/json/json_serialization.h" + +namespace openscreen { +namespace cast { + +struct VirtualConnection; +class VirtualConnectionManager; +class VirtualConnectionRouter; + +// Handles CastMessages in the connection namespace by opening and closing +// VirtualConnections on the socket on which the messages were received. +class ConnectionNamespaceHandler final : public CastMessageHandler { + public: + class VirtualConnectionPolicy { + public: + virtual ~VirtualConnectionPolicy() = default; + + virtual bool IsConnectionAllowed( + const VirtualConnection& virtual_conn) const = 0; + }; + + // Both |vc_manager| and |vc_policy| should outlive this object. + ConnectionNamespaceHandler(VirtualConnectionManager* vc_manager, + VirtualConnectionPolicy* vc_policy); + ~ConnectionNamespaceHandler() override; + + // CastMessageHandler overrides. + void OnMessage(VirtualConnectionRouter* router, + CastSocket* socket, + ::cast::channel::CastMessage message) override; + + private: + void HandleConnect(VirtualConnectionRouter* router, + CastSocket* socket, + ::cast::channel::CastMessage message, + Json::Value parsed_message); + void HandleClose(VirtualConnectionRouter* router, + CastSocket* socket, + ::cast::channel::CastMessage message, + Json::Value parsed_message); + + void SendClose(VirtualConnectionRouter* router, + VirtualConnection virtual_conn); + void SendConnectedResponse(VirtualConnectionRouter* router, + const VirtualConnection& virtual_conn, + int max_protocol_version); + + VirtualConnectionManager* const vc_manager_; + VirtualConnectionPolicy* const vc_policy_; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_COMMON_CHANNEL_CONNECTION_NAMESPACE_HANDLER_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler_unittest.cc b/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler_unittest.cc new file mode 100644 index 00000000000..0ff0763b093 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler_unittest.cc @@ -0,0 +1,226 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/channel/connection_namespace_handler.h" + +#include "cast/common/channel/cast_socket.h" +#include "cast/common/channel/message_util.h" +#include "cast/common/channel/testing/fake_cast_socket.h" +#include "cast/common/channel/testing/mock_socket_error_handler.h" +#include "cast/common/channel/virtual_connection.h" +#include "cast/common/channel/virtual_connection_manager.h" +#include "cast/common/channel/virtual_connection_router.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "util/json/json_serialization.h" +#include "util/json/json_value.h" +#include "util/logging.h" + +namespace openscreen { +namespace cast { +namespace { + +using ::testing::_; +using ::testing::Invoke; +using ::testing::NiceMock; + +using ::cast::channel::CastMessage; +using ::cast::channel::CastMessage_ProtocolVersion; + +class MockVirtualConnectionPolicy + : public ConnectionNamespaceHandler::VirtualConnectionPolicy { + public: + ~MockVirtualConnectionPolicy() override = default; + + MOCK_METHOD(bool, + IsConnectionAllowed, + (const VirtualConnection& virtual_conn), + (const, override)); +}; + +CastMessage MakeVersionedConnectMessage( + const std::string& source_id, + const std::string& destination_id, + absl::optional<CastMessage_ProtocolVersion> version, + std::vector<CastMessage_ProtocolVersion> version_list) { + CastMessage connect_message = MakeConnectMessage(source_id, destination_id); + Json::Value message(Json::ValueType::objectValue); + message[kMessageKeyType] = kMessageTypeConnect; + if (version) { + message[kMessageKeyProtocolVersion] = version.value(); + } + if (!version_list.empty()) { + Json::Value list(Json::ValueType::arrayValue); + for (CastMessage_ProtocolVersion v : version_list) { + list.append(v); + } + message[kMessageKeyProtocolVersionList] = std::move(list); + } + ErrorOr<std::string> result = json::Stringify(message); + OSP_DCHECK(result); + connect_message.set_payload_utf8(std::move(result.value())); + return connect_message; +} + +void VerifyConnectionMessage(const CastMessage& message, + const std::string& source_id, + const std::string& destination_id) { + EXPECT_EQ(message.source_id(), source_id); + EXPECT_EQ(message.destination_id(), destination_id); + EXPECT_EQ(message.namespace_(), kConnectionNamespace); + ASSERT_EQ(message.payload_type(), + ::cast::channel::CastMessage_PayloadType_STRING); +} + +Json::Value ParseConnectionMessage(const CastMessage& message) { + ErrorOr<Json::Value> result = json::Parse(message.payload_utf8()); + OSP_CHECK(result) << message.payload_utf8(); + return result.value(); +} + +} // namespace + +class ConnectionNamespaceHandlerTest : public ::testing::Test { + public: + void SetUp() override { + socket_ = fake_cast_socket_pair_.socket.get(); + router_.TakeSocket(&mock_error_handler_, + std::move(fake_cast_socket_pair_.socket)); + + ON_CALL(vc_policy_, IsConnectionAllowed(_)) + .WillByDefault( + Invoke([](const VirtualConnection& virtual_conn) { return true; })); + } + + protected: + void ExpectCloseMessage(MockCastSocketClient* mock_client, + const std::string& source_id, + const std::string& destination_id) { + EXPECT_CALL(*mock_client, OnMessage(_, _)) + .WillOnce(Invoke([&source_id, &destination_id](CastSocket* socket, + CastMessage message) { + VerifyConnectionMessage(message, source_id, destination_id); + Json::Value value = ParseConnectionMessage(message); + absl::optional<absl::string_view> type = MaybeGetString( + value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType)); + ASSERT_TRUE(type) << message.payload_utf8(); + EXPECT_EQ(type.value(), kMessageTypeClose) << message.payload_utf8(); + })); + } + + void ExpectConnectedMessage( + MockCastSocketClient* mock_client, + const std::string& source_id, + const std::string& destination_id, + absl::optional<CastMessage_ProtocolVersion> version = absl::nullopt) { + EXPECT_CALL(*mock_client, OnMessage(_, _)) + .WillOnce(Invoke([&source_id, &destination_id, version]( + CastSocket* socket, CastMessage message) { + VerifyConnectionMessage(message, source_id, destination_id); + Json::Value value = ParseConnectionMessage(message); + absl::optional<absl::string_view> type = MaybeGetString( + value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType)); + ASSERT_TRUE(type) << message.payload_utf8(); + EXPECT_EQ(type.value(), kMessageTypeConnected) + << message.payload_utf8(); + if (version) { + absl::optional<int> message_version = MaybeGetInt( + value, + JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyProtocolVersion)); + ASSERT_TRUE(message_version) << message.payload_utf8(); + EXPECT_EQ(message_version.value(), version.value()); + } + })); + } + + FakeCastSocketPair fake_cast_socket_pair_; + MockSocketErrorHandler mock_error_handler_; + CastSocket* socket_; + + NiceMock<MockVirtualConnectionPolicy> vc_policy_; + VirtualConnectionManager vc_manager_; + VirtualConnectionRouter router_{&vc_manager_}; + ConnectionNamespaceHandler connection_namespace_handler_{&vc_manager_, + &vc_policy_}; + + const std::string sender_id_{"sender-5678"}; + const std::string receiver_id_{"receiver-3245"}; +}; + +TEST_F(ConnectionNamespaceHandlerTest, Connect) { + connection_namespace_handler_.OnMessage( + &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_)); + EXPECT_TRUE(vc_manager_.GetConnectionData( + VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); + + EXPECT_CALL(fake_cast_socket_pair_.mock_peer_client, OnMessage(_, _)) + .Times(0); +} + +TEST_F(ConnectionNamespaceHandlerTest, PolicyDeniesConnection) { + EXPECT_CALL(vc_policy_, IsConnectionAllowed(_)) + .WillOnce( + Invoke([](const VirtualConnection& virtual_conn) { return false; })); + ExpectCloseMessage(&fake_cast_socket_pair_.mock_peer_client, receiver_id_, + sender_id_); + connection_namespace_handler_.OnMessage( + &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_)); + EXPECT_FALSE(vc_manager_.GetConnectionData( + VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); +} + +TEST_F(ConnectionNamespaceHandlerTest, ConnectWithVersion) { + ExpectConnectedMessage( + &fake_cast_socket_pair_.mock_peer_client, receiver_id_, sender_id_, + ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2); + connection_namespace_handler_.OnMessage( + &router_, socket_, + MakeVersionedConnectMessage( + sender_id_, receiver_id_, + ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2, {})); + EXPECT_TRUE(vc_manager_.GetConnectionData( + VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); +} + +TEST_F(ConnectionNamespaceHandlerTest, ConnectWithVersionList) { + ExpectConnectedMessage( + &fake_cast_socket_pair_.mock_peer_client, receiver_id_, sender_id_, + ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3); + connection_namespace_handler_.OnMessage( + &router_, socket_, + MakeVersionedConnectMessage( + sender_id_, receiver_id_, + ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2, + {::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3, + ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0})); + EXPECT_TRUE(vc_manager_.GetConnectionData( + VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); +} + +TEST_F(ConnectionNamespaceHandlerTest, Close) { + connection_namespace_handler_.OnMessage( + &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_)); + EXPECT_TRUE(vc_manager_.GetConnectionData( + VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); + + connection_namespace_handler_.OnMessage( + &router_, socket_, MakeCloseMessage(sender_id_, receiver_id_)); + EXPECT_FALSE(vc_manager_.GetConnectionData( + VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); +} + +TEST_F(ConnectionNamespaceHandlerTest, CloseUnknown) { + connection_namespace_handler_.OnMessage( + &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_)); + EXPECT_TRUE(vc_manager_.GetConnectionData( + VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); + + connection_namespace_handler_.OnMessage( + &router_, socket_, MakeCloseMessage(sender_id_ + "098", receiver_id_)); + EXPECT_TRUE(vc_manager_.GetConnectionData( + VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer.cc b/chromium/third_party/openscreen/src/cast/common/channel/message_framer.cc index 2494a6ea67c..a8406539d3b 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/message_framer.cc +++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer.cc @@ -13,12 +13,10 @@ #include "util/big_endian.h" #include "util/logging.h" +namespace openscreen { namespace cast { -namespace channel { namespace message_serialization { -using openscreen::Error; - namespace { static constexpr size_t kHeaderSize = sizeof(uint32_t); @@ -28,26 +26,26 @@ static constexpr size_t kMaxBodySize = 65536; } // namespace -ErrorOr<std::vector<uint8_t>> Serialize(const CastMessage& message) { +ErrorOr<std::vector<uint8_t>> Serialize( + const ::cast::channel::CastMessage& message) { const size_t message_size = message.ByteSizeLong(); if (message_size > kMaxBodySize || message_size == 0) { return Error::Code::kCastV2InvalidMessage; } std::vector<uint8_t> out(message_size + kHeaderSize, 0); - openscreen::WriteBigEndian<uint32_t>(message_size, out.data()); + WriteBigEndian<uint32_t>(message_size, out.data()); if (!message.SerializeToArray(&out[kHeaderSize], message_size)) { return Error::Code::kCastV2InvalidMessage; } return out; } -ErrorOr<DeserializeResult> TryDeserialize(absl::Span<uint8_t> input) { +ErrorOr<DeserializeResult> TryDeserialize(absl::Span<const uint8_t> input) { if (input.size() < kHeaderSize) { return Error::Code::kInsufficientBuffer; } - const uint32_t message_size = - openscreen::ReadBigEndian<uint32_t>(input.data()); + const uint32_t message_size = ReadBigEndian<uint32_t>(input.data()); if (message_size > kMaxBodySize) { return Error::Code::kCastV2InvalidMessage; } @@ -67,5 +65,5 @@ ErrorOr<DeserializeResult> TryDeserialize(absl::Span<uint8_t> input) { } } // namespace message_serialization -} // namespace channel } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer.h b/chromium/third_party/openscreen/src/cast/common/channel/message_framer.h index c092487cb1d..3de1cb7594d 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/message_framer.h +++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer.h @@ -15,18 +15,17 @@ #include "cast/common/channel/proto/cast_channel.pb.h" #include "platform/base/error.h" +namespace openscreen { namespace cast { -namespace channel { namespace message_serialization { -using openscreen::ErrorOr; - // Serializes |message_proto| into |message_data|. // Returns true if the message was serialized successfully, false otherwise. -ErrorOr<std::vector<uint8_t>> Serialize(const CastMessage& message); +ErrorOr<std::vector<uint8_t>> Serialize( + const ::cast::channel::CastMessage& message); struct DeserializeResult { - CastMessage message; + ::cast::channel::CastMessage message; size_t length; }; @@ -34,10 +33,10 @@ struct DeserializeResult { // read. Returns a parsed CastMessage if a message was received in its // entirety, and an error otherwise. The result also contains the number of // bytes consumed from |input| when a parse succeeds. -ErrorOr<DeserializeResult> TryDeserialize(absl::Span<uint8_t> input); +ErrorOr<DeserializeResult> TryDeserialize(absl::Span<const uint8_t> input); } // namespace message_serialization -} // namespace channel } // namespace cast +} // namespace openscreen #endif // CAST_COMMON_CHANNEL_MESSAGE_FRAMER_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer.cc b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer.cc new file mode 100644 index 00000000000..53678e748a0 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer.cc @@ -0,0 +1,14 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <cstddef> +#include <cstdint> + +#include "cast/common/channel/message_framer.h" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + openscreen::cast::message_serialization::TryDeserialize( + absl::Span<const uint8_t>(data, size)); + return 0; +} diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/03d4b4028b559489768e2cccd6015c907f70a2c0 b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/03d4b4028b559489768e2cccd6015c907f70a2c0 Binary files differnew file mode 100644 index 00000000000..41fd475902d --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/03d4b4028b559489768e2cccd6015c907f70a2c0 diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/333be5dfffb2c6eeadf31be2dc219ef841c99ea0 b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/333be5dfffb2c6eeadf31be2dc219ef841c99ea0 Binary files differnew file mode 100644 index 00000000000..ab09cd27a30 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/333be5dfffb2c6eeadf31be2dc219ef841c99ea0 diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/b03aaebaa88ca4f4b8d63c7a63fc55ba402cfbb4 b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/b03aaebaa88ca4f4b8d63c7a63fc55ba402cfbb4 Binary files differnew file mode 100644 index 00000000000..fc53faf34ba --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/b03aaebaa88ca4f4b8d63c7a63fc55ba402cfbb4 diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_len1 b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_len1 Binary files differnew file mode 100644 index 00000000000..05f1e12bd3e --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_len1 diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_len2 b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_len2 Binary files differnew file mode 100644 index 00000000000..6f745b25c75 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_len2 diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_proto b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_proto Binary files differnew file mode 100644 index 00000000000..7dd4315ed6a --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_proto diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/cf93596ce5bbb0d4c91f3ee493e01f0674d36c0c b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/cf93596ce5bbb0d4c91f3ee493e01f0674d36c0c Binary files differnew file mode 100644 index 00000000000..445ded6c8b4 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/cf93596ce5bbb0d4c91f3ee493e01f0674d36c0c diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/e14c401475d86e0f279691c168c7122ceb77c2c6 b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/e14c401475d86e0f279691c168c7122ceb77c2c6 Binary files differnew file mode 100644 index 00000000000..4e15a3d7d4f --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/e14c401475d86e0f279691c168c7122ceb77c2c6 diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/e9b451d1575019d52e0e072ce5b22a2418d237c7 b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/e9b451d1575019d52e0e072ce5b22a2418d237c7 Binary files differnew file mode 100644 index 00000000000..5dc9591749f --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/e9b451d1575019d52e0e072ce5b22a2418d237c7 diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_unittest.cc b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_unittest.cc index ae2ff33e698..e70459e67ec 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_unittest.cc @@ -14,11 +14,11 @@ #include "util/big_endian.h" #include "util/std_util.h" +namespace openscreen { namespace cast { -namespace channel { namespace message_serialization { -using openscreen::Error; +using ::cast::channel::CastMessage; namespace { @@ -149,5 +149,5 @@ TEST_F(CastFramerTest, TestUnparsableBodyProto) { } } // namespace message_serialization -} // namespace channel } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_util.cc b/chromium/third_party/openscreen/src/cast/common/channel/message_util.cc new file mode 100644 index 00000000000..ee9a91f1309 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/message_util.cc @@ -0,0 +1,71 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/channel/message_util.h" + +#include "util/logging.h" + +namespace openscreen { +namespace cast { +namespace { + +using ::cast::channel::CastMessage; + +CastMessage MakeConnectionMessage(const std::string& source_id, + const std::string& destination_id) { + CastMessage connect_message; + connect_message.set_protocol_version(kDefaultOutgoingMessageVersion); + connect_message.set_source_id(source_id); + connect_message.set_destination_id(destination_id); + connect_message.set_namespace_(kConnectionNamespace); + return connect_message; +} + +} // namespace + +std::string ToString(AppAvailabilityResult availability) { + switch (availability) { + case AppAvailabilityResult::kAvailable: + return "Available"; + case AppAvailabilityResult::kUnavailable: + return "Unavailable"; + case AppAvailabilityResult::kUnknown: + return "Unknown"; + default: + OSP_NOTREACHED(); + return "bad value"; + } +} + +CastMessage MakeSimpleUTF8Message(const std::string& namespace_, + std::string payload) { + CastMessage message; + message.set_protocol_version(kDefaultOutgoingMessageVersion); + message.set_namespace_(namespace_); + message.set_payload_type(::cast::channel::CastMessage_PayloadType_STRING); + message.set_payload_utf8(std::move(payload)); + return message; +} + +CastMessage MakeConnectMessage(const std::string& source_id, + const std::string& destination_id) { + CastMessage connect_message = + MakeConnectionMessage(source_id, destination_id); + connect_message.set_payload_type( + ::cast::channel::CastMessage_PayloadType_STRING); + connect_message.set_payload_utf8(R"!({"type": "CONNECT"})!"); + return connect_message; +} + +CastMessage MakeCloseMessage(const std::string& source_id, + const std::string& destination_id) { + CastMessage close_message = MakeConnectionMessage(source_id, destination_id); + close_message.set_payload_type( + ::cast::channel::CastMessage_PayloadType_STRING); + close_message.set_payload_utf8(R"!({"type": "CLOSE"})!"); + return close_message; +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_util.h b/chromium/third_party/openscreen/src/cast/common/channel/message_util.h index 990fd55c5f8..f1ba2ead412 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/message_util.h +++ b/chromium/third_party/openscreen/src/cast/common/channel/message_util.h @@ -8,8 +8,8 @@ #include "absl/strings/string_view.h" #include "cast/common/channel/proto/cast_channel.pb.h" +namespace openscreen { namespace cast { -namespace channel { // Reserved message namespaces for internal messages. static constexpr char kCastInternalNamespacePrefix[] = @@ -32,7 +32,153 @@ static constexpr char kMediaNamespace[] = "urn:x-cast:com.google.cast.media"; static constexpr char kPlatformSenderId[] = "sender-0"; static constexpr char kPlatformReceiverId[] = "receiver-0"; -inline bool IsAuthMessage(const CastMessage& message) { +static constexpr char kBroadcastId[] = "*"; + +static constexpr ::cast::channel::CastMessage_ProtocolVersion + kDefaultOutgoingMessageVersion = + ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0; + +// JSON message key strings. +static constexpr char kMessageKeyType[] = "type"; +static constexpr char kMessageKeyConnType[] = "connType"; +static constexpr char kMessageKeyUserAgent[] = "userAgent"; +static constexpr char kMessageKeySenderInfo[] = "senderInfo"; +static constexpr char kMessageKeyProtocolVersion[] = "protocolVersion"; +static constexpr char kMessageKeyProtocolVersionList[] = "protocolVersionList"; +static constexpr char kMessageKeyReasonCode[] = "reasonCode"; +static constexpr char kMessageKeyAppId[] = "appId"; +static constexpr char kMessageKeyRequestId[] = "requestId"; +static constexpr char kMessageKeyAvailability[] = "availability"; + +// JSON message field values. +static constexpr char kMessageTypeConnect[] = "CONNECT"; +static constexpr char kMessageTypeClose[] = "CLOSE"; +static constexpr char kMessageTypeConnected[] = "CONNECTED"; +static constexpr char kMessageValueAppAvailable[] = "APP_AVAILABLE"; +static constexpr char kMessageValueAppUnavailable[] = "APP_UNAVAILABLE"; + +// TODO(crbug.com/openscreen/111): Add validation that each message type is +// received on the correct namespace. This will probably involve creating a +// data structure for mapping between type and namespace. +enum class CastMessageType { + // Heartbeat messages. + kPing, + kPong, + + // RPC control/status messages used by Media Remoting. These occur at high + // frequency, up to dozens per second at times, and should not be logged. + kRpc, + + kGetAppAvailability, + kGetStatus, + + // Virtual connection request. + kConnect, + + // Close virtual connection. + kCloseConnection, + + // Application broadcast / precache. + kBroadcast, + + // Session launch request. + kLaunch, + + // Session stop request. + kStop, + + kReceiverStatus, + kMediaStatus, + + // Error from receiver. + kLaunchError, + + kOffer, + kAnswer, + kCapabilitiesResponse, + kStatusResponse, + + // The following values are part of the protocol but are not currently used. + kMultizoneStatus, + kInvalidPlayerState, + kLoadFailed, + kLoadCancelled, + kInvalidRequest, + kPresentation, + kGetCapabilities, + + kOther, // Add new types above |kOther|. + kMaxValue = kOther, +}; + +enum class AppAvailabilityResult { + kAvailable, + kUnavailable, + kUnknown, +}; + +std::string ToString(AppAvailabilityResult availability); + +// TODO(crbug.com/openscreen/111): When this and/or other enums need the +// string->enum mapping, import EnumTable from Chromium's +// //components/cast_channel/enum_table.h. +inline constexpr const char* CastMessageTypeToString(CastMessageType type) { + switch (type) { + case CastMessageType::kPing: + return "PING"; + case CastMessageType::kPong: + return "PONG"; + case CastMessageType::kRpc: + return "RPC"; + case CastMessageType::kGetAppAvailability: + return "GET_APP_AVAILABILITY"; + case CastMessageType::kGetStatus: + return "GET_STATUS"; + case CastMessageType::kConnect: + return "CONNECT"; + case CastMessageType::kCloseConnection: + return "CLOSE"; + case CastMessageType::kBroadcast: + return "APPLICATION_BROADCAST"; + case CastMessageType::kLaunch: + return "LAUNCH"; + case CastMessageType::kStop: + return "STOP"; + case CastMessageType::kReceiverStatus: + return "RECEIVER_STATUS"; + case CastMessageType::kMediaStatus: + return "MEDIA_STATUS"; + case CastMessageType::kLaunchError: + return "LAUNCH_ERROR"; + case CastMessageType::kOffer: + return "OFFER"; + case CastMessageType::kAnswer: + return "ANSWER"; + case CastMessageType::kCapabilitiesResponse: + return "CAPABILITIES_RESPONSE"; + case CastMessageType::kStatusResponse: + return "STATUS_RESPONSE"; + case CastMessageType::kMultizoneStatus: + return "MULTIZONE_STATUS"; + case CastMessageType::kInvalidPlayerState: + return "INVALID_PLAYER_STATE"; + case CastMessageType::kLoadFailed: + return "LOAD_FAILED"; + case CastMessageType::kLoadCancelled: + return "LOAD_CANCELLED"; + case CastMessageType::kInvalidRequest: + return "INVALID_REQUEST"; + case CastMessageType::kPresentation: + return "PRESENTATION"; + case CastMessageType::kGetCapabilities: + return "GET_CAPABILITIES"; + case CastMessageType::kOther: + default: + return "OTHER"; + } +} + +inline bool IsAuthMessage(const ::cast::channel::CastMessage& message) { return message.namespace_() == kAuthNamespace; } @@ -41,7 +187,18 @@ inline bool IsTransportNamespace(absl::string_view namespace_) { (namespace_.find_first_of(kTransportNamespacePrefix) == 0); } -} // namespace channel +::cast::channel::CastMessage MakeSimpleUTF8Message( + const std::string& namespace_, + std::string payload); + +::cast::channel::CastMessage MakeConnectMessage( + const std::string& source_id, + const std::string& destination_id); +::cast::channel::CastMessage MakeCloseMessage( + const std::string& source_id, + const std::string& destination_id); + } // namespace cast +} // namespace openscreen #endif // CAST_COMMON_CHANNEL_MESSAGE_UTIL_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/channel/namespace_router.cc b/chromium/third_party/openscreen/src/cast/common/channel/namespace_router.cc new file mode 100644 index 00000000000..041d8f3f0df --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/namespace_router.cc @@ -0,0 +1,35 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/channel/namespace_router.h" + +#include "cast/common/channel/proto/cast_channel.pb.h" + +namespace openscreen { +namespace cast { + +NamespaceRouter::NamespaceRouter() = default; +NamespaceRouter::~NamespaceRouter() = default; + +void NamespaceRouter::AddNamespaceHandler(std::string namespace_, + CastMessageHandler* handler) { + handlers_.emplace(std::move(namespace_), handler); +} + +void NamespaceRouter::RemoveNamespaceHandler(const std::string& namespace_) { + handlers_.erase(namespace_); +} + +void NamespaceRouter::OnMessage(VirtualConnectionRouter* router, + CastSocket* socket, + ::cast::channel::CastMessage message) { + const std::string& ns = message.namespace_(); + auto it = handlers_.find(ns); + if (it != handlers_.end()) { + it->second->OnMessage(router, socket, std::move(message)); + } +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/channel/namespace_router.h b/chromium/third_party/openscreen/src/cast/common/channel/namespace_router.h new file mode 100644 index 00000000000..0b6b581c0e1 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/namespace_router.h @@ -0,0 +1,37 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_COMMON_CHANNEL_NAMESPACE_ROUTER_H_ +#define CAST_COMMON_CHANNEL_NAMESPACE_ROUTER_H_ + +#include <map> +#include <string> + +#include "cast/common/channel/cast_message_handler.h" +#include "cast/common/channel/proto/cast_channel.pb.h" + +namespace openscreen { +namespace cast { + +class NamespaceRouter final : public CastMessageHandler { + public: + NamespaceRouter(); + ~NamespaceRouter() override; + + void AddNamespaceHandler(std::string namespace_, CastMessageHandler* handler); + void RemoveNamespaceHandler(const std::string& namespace_); + + // CastMessageHandler overrides. + void OnMessage(VirtualConnectionRouter* router, + CastSocket* socket, + ::cast::channel::CastMessage message) override; + + private: + std::map<std::string /* namespace */, CastMessageHandler*> handlers_; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_COMMON_CHANNEL_NAMESPACE_ROUTER_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/channel/namespace_router_unittest.cc b/chromium/third_party/openscreen/src/cast/common/channel/namespace_router_unittest.cc new file mode 100644 index 00000000000..96907c081c0 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/namespace_router_unittest.cc @@ -0,0 +1,98 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/channel/namespace_router.h" + +#include "cast/common/channel/cast_message_handler.h" +#include "cast/common/channel/proto/cast_channel.pb.h" +#include "cast/common/channel/testing/fake_cast_socket.h" +#include "cast/common/channel/testing/mock_cast_message_handler.h" +#include "cast/common/channel/virtual_connection_manager.h" +#include "cast/common/channel/virtual_connection_router.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace openscreen { +namespace cast { +namespace { + +using ::cast::channel::CastMessage; +using ::testing::_; +using ::testing::Invoke; + +class NamespaceRouterTest : public ::testing::Test { + public: + protected: + CastSocket* socket() { return &fake_socket_.socket; } + + FakeCastSocket fake_socket_; + VirtualConnectionManager vc_manager_; + VirtualConnectionRouter vc_router_{&vc_manager_}; + NamespaceRouter router_; +}; + +} // namespace + +TEST_F(NamespaceRouterTest, NoHandlersNoop) { + CastMessage message; + message.set_namespace_("anzrfcnpr"); + router_.OnMessage(&vc_router_, socket(), std::move(message)); +} + +TEST_F(NamespaceRouterTest, MultipleHandlers) { + MockCastMessageHandler media_handler; + MockCastMessageHandler auth_handler; + MockCastMessageHandler connection_handler; + + router_.AddNamespaceHandler("media", &media_handler); + router_.AddNamespaceHandler("auth", &auth_handler); + router_.AddNamespaceHandler("connection", &connection_handler); + + EXPECT_CALL(media_handler, OnMessage(_, _, _)).Times(0); + EXPECT_CALL(auth_handler, OnMessage(_, _, _)) + .WillOnce(Invoke([](VirtualConnectionRouter* router, CastSocket*, + CastMessage message) { + EXPECT_EQ(message.namespace_(), "auth"); + })); + EXPECT_CALL(connection_handler, OnMessage(_, _, _)) + .WillOnce(Invoke([](VirtualConnectionRouter* router, CastSocket*, + CastMessage message) { + EXPECT_EQ(message.namespace_(), "connection"); + })); + + CastMessage auth_message; + auth_message.set_namespace_("auth"); + router_.OnMessage(&vc_router_, socket(), std::move(auth_message)); + + CastMessage connection_message; + connection_message.set_namespace_("connection"); + router_.OnMessage(&vc_router_, socket(), std::move(connection_message)); +} + +TEST_F(NamespaceRouterTest, RemoveHandler) { + MockCastMessageHandler handler1; + MockCastMessageHandler handler2; + + router_.AddNamespaceHandler("one", &handler1); + router_.AddNamespaceHandler("two", &handler2); + + router_.RemoveNamespaceHandler("one"); + + EXPECT_CALL(handler1, OnMessage(_, _, _)).Times(0); + EXPECT_CALL(handler2, OnMessage(_, _, _)) + .WillOnce(Invoke( + [](VirtualConnectionRouter* router, CastSocket* socket, + CastMessage message) { EXPECT_EQ("two", message.namespace_()); })); + + CastMessage message1; + message1.set_namespace_("one"); + router_.OnMessage(&vc_router_, socket(), std::move(message1)); + + CastMessage message2; + message2.set_namespace_("two"); + router_.OnMessage(&vc_router_, socket(), std::move(message2)); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/channel/proto/authority_keys.proto b/chromium/third_party/openscreen/src/cast/common/channel/proto/authority_keys.proto index 5689e364a1f..1b99b8ee604 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/proto/authority_keys.proto +++ b/chromium/third_party/openscreen/src/cast/common/channel/proto/authority_keys.proto @@ -6,7 +6,11 @@ syntax = "proto2"; option optimize_for = LITE_RUNTIME; -package cast_channel.proto; +// TODO(crbug.com/openscreen/90): Rename to openscreen.cast, to update to the +// current namespacing of the library. Also, this file should probably be moved +// to the public directory. And, all of this will have to be coordinated with a +// DEPS roll in Chromium (since Chromium code depends on this). +package cast.channel; message AuthorityKeys { message Key { diff --git a/chromium/third_party/openscreen/src/cast/common/channel/proto/cast_channel.proto b/chromium/third_party/openscreen/src/cast/common/channel/proto/cast_channel.proto index 57c7b3f3fb7..58fb6c4dc6a 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/proto/cast_channel.proto +++ b/chromium/third_party/openscreen/src/cast/common/channel/proto/cast_channel.proto @@ -6,12 +6,21 @@ syntax = "proto2"; option optimize_for = LITE_RUNTIME; +// TODO(crbug.com/openscreen/90): Rename to openscreen.cast, to update to the +// current namespacing of the library. Also, this file should probably be moved +// to the public directory. And, all of this will have to be coordinated with a +// DEPS roll in Chromium (since Chromium code depends on this). package cast.channel; message CastMessage { // Always pass a version of the protocol for future compatibility // requirements. - enum ProtocolVersion { CASTV2_1_0 = 0; } + enum ProtocolVersion { + CASTV2_1_0 = 0; + CASTV2_1_1 = 1; // message chunking support (deprecated). + CASTV2_1_2 = 2; // reworked message chunking. + CASTV2_1_3 = 3; // binary payload over utf8. + } required ProtocolVersion protocol_version = 1; // source and destination ids identify the origin and destination of the @@ -49,6 +58,20 @@ message CastMessage { // will always be set. optional string payload_utf8 = 6; optional bytes payload_binary = 7; + + // --- Begin new 1.1 fields. + + // Flag indicating whether there are more chunks to follow for this message. + // If the flag is false or is not present, then this is the last (or only) + // chunk of the message. + optional bool continued = 8; + + // If this is a chunk of a larger message, and the remaining length of the + // message payload (the sum of the lengths of the payloads of the remaining + // chunks) is known, this field will indicate that length. For a given + // chunked message, this field should either be present in all of the chunks, + // or in none of them. + optional uint32 remaining_length = 9; } enum SignatureAlgorithm { diff --git a/chromium/third_party/openscreen/src/cast/common/channel/testing/fake_cast_socket.h b/chromium/third_party/openscreen/src/cast/common/channel/testing/fake_cast_socket.h new file mode 100644 index 00000000000..5373620bf03 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/testing/fake_cast_socket.h @@ -0,0 +1,108 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_COMMON_CHANNEL_TESTING_FAKE_CAST_SOCKET_H_ +#define CAST_COMMON_CHANNEL_TESTING_FAKE_CAST_SOCKET_H_ + +#include <memory> + +#include "cast/common/channel/cast_socket.h" +#include "cast/common/channel/proto/cast_channel.pb.h" +#include "gmock/gmock.h" +#include "platform/test/mock_tls_connection.h" + +namespace openscreen { +namespace cast { + +class MockCastSocketClient final : public CastSocket::Client { + public: + ~MockCastSocketClient() override = default; + + MOCK_METHOD(void, OnError, (CastSocket * socket, Error error), (override)); + MOCK_METHOD(void, + OnMessage, + (CastSocket * socket, ::cast::channel::CastMessage message), + (override)); +}; + +struct FakeCastSocket { + FakeCastSocket() + : FakeCastSocket({{10, 0, 1, 7}, 1234}, {{10, 0, 1, 9}, 4321}) {} + FakeCastSocket(const IPEndpoint& local_endpoint, + const IPEndpoint& remote_endpoint) + : local_endpoint(local_endpoint), + remote_endpoint(remote_endpoint), + moved_connection(std::make_unique<MockTlsConnection>(local_endpoint, + remote_endpoint)), + connection(moved_connection.get()), + socket(std::move(moved_connection), &mock_client) {} + + IPEndpoint local_endpoint; + IPEndpoint remote_endpoint; + std::unique_ptr<MockTlsConnection> moved_connection; + MockTlsConnection* connection; + MockCastSocketClient mock_client; + CastSocket socket; +}; + +// Two FakeCastSockets that are piped together via their MockTlsConnection +// read/write methods. Calling SendMessage on |socket| will result in an +// OnMessage callback on |mock_peer_client| and vice versa for |peer_socket| and +// |mock_client|. +struct FakeCastSocketPair { + FakeCastSocketPair() + : FakeCastSocketPair({{10, 0, 1, 7}, 1234}, {{10, 0, 1, 9}, 4321}) {} + + FakeCastSocketPair(const IPEndpoint& local_endpoint, + const IPEndpoint& remote_endpoint) + : local_endpoint(local_endpoint), remote_endpoint(remote_endpoint) { + using ::testing::_; + using ::testing::Invoke; + + auto moved_connection = + std::make_unique<::testing::NiceMock<MockTlsConnection>>( + local_endpoint, remote_endpoint); + connection = moved_connection.get(); + socket = + std::make_unique<CastSocket>(std::move(moved_connection), &mock_client); + + auto moved_peer = std::make_unique<::testing::NiceMock<MockTlsConnection>>( + remote_endpoint, local_endpoint); + peer_connection = moved_peer.get(); + peer_socket = + std::make_unique<CastSocket>(std::move(moved_peer), &mock_peer_client); + + ON_CALL(*connection, Send(_, _)) + .WillByDefault(Invoke([this](const void* data, size_t len) { + peer_connection->OnRead(std::vector<uint8_t>( + reinterpret_cast<const uint8_t*>(data), + reinterpret_cast<const uint8_t*>(data) + len)); + return true; + })); + ON_CALL(*peer_connection, Send(_, _)) + .WillByDefault(Invoke([this](const void* data, size_t len) { + connection->OnRead(std::vector<uint8_t>( + reinterpret_cast<const uint8_t*>(data), + reinterpret_cast<const uint8_t*>(data) + len)); + return true; + })); + } + ~FakeCastSocketPair() = default; + + IPEndpoint local_endpoint; + IPEndpoint remote_endpoint; + + ::testing::NiceMock<MockTlsConnection>* connection; + MockCastSocketClient mock_client; + std::unique_ptr<CastSocket> socket; + + ::testing::NiceMock<MockTlsConnection>* peer_connection; + MockCastSocketClient mock_peer_client; + std::unique_ptr<CastSocket> peer_socket; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_COMMON_CHANNEL_TESTING_FAKE_CAST_SOCKET_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/channel/testing/mock_cast_message_handler.h b/chromium/third_party/openscreen/src/cast/common/channel/testing/mock_cast_message_handler.h new file mode 100644 index 00000000000..48abefdc90f --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/testing/mock_cast_message_handler.h @@ -0,0 +1,28 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_COMMON_CHANNEL_TESTING_MOCK_CAST_MESSAGE_HANDLER_H_ +#define CAST_COMMON_CHANNEL_TESTING_MOCK_CAST_MESSAGE_HANDLER_H_ + +#include "cast/common/channel/cast_message_handler.h" +#include "cast/common/channel/proto/cast_channel.pb.h" +#include "gmock/gmock.h" + +namespace openscreen { +namespace cast { + +class MockCastMessageHandler final : public CastMessageHandler { + public: + MOCK_METHOD(void, + OnMessage, + (VirtualConnectionRouter * router, + CastSocket* socket, + ::cast::channel::CastMessage message), + (override)); +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_COMMON_CHANNEL_TESTING_MOCK_CAST_MESSAGE_HANDLER_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/channel/testing/mock_socket_error_handler.h b/chromium/third_party/openscreen/src/cast/common/channel/testing/mock_socket_error_handler.h new file mode 100644 index 00000000000..60544b75192 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/channel/testing/mock_socket_error_handler.h @@ -0,0 +1,25 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_COMMON_CHANNEL_TESTING_MOCK_SOCKET_ERROR_HANDLER_H_ +#define CAST_COMMON_CHANNEL_TESTING_MOCK_SOCKET_ERROR_HANDLER_H_ + +#include "cast/common/channel/virtual_connection_router.h" +#include "gmock/gmock.h" +#include "platform/base/error.h" + +namespace openscreen { +namespace cast { + +class MockSocketErrorHandler + : public VirtualConnectionRouter::SocketErrorHandler { + public: + MOCK_METHOD(void, OnClose, (CastSocket * socket), (override)); + MOCK_METHOD(void, OnError, (CastSocket * socket, Error error), (override)); +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_COMMON_CHANNEL_TESTING_MOCK_SOCKET_ERROR_HANDLER_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection.h b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection.h index 0117bf51598..04f3ba06ab4 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection.h +++ b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection.h @@ -9,8 +9,8 @@ #include <cstdint> #include <string> +namespace openscreen { namespace cast { -namespace channel { // Transport system on top of CastSocket that allows routing messages over a // single socket to different virtual endpoints (e.g. system messages vs. @@ -35,16 +35,23 @@ struct VirtualConnection { // - Receiver app can only send broadcast messages over an invisible // connection. kInvisible, + + kMinValue = kStrong, + kMaxValue = kInvisible, }; // Cast V2 protocol version constants. Must be in sync with // proto/cast_channel.proto. enum class ProtocolVersion { kV2_1_0, + kV2_1_1, + kV2_1_2, + kV2_1_3, }; enum CloseReason { kUnknown, + kFirstReason = kUnknown, // Underlying socket has been closed by peer. This happens when Cast sender // closed transport connection normally without graceful virtual connection @@ -68,6 +75,7 @@ struct VirtualConnection { // The virtual connection has been closed by the peer gracefully. kClosedByPeer, + kLastReason = kClosedByPeer, }; struct AssociatedData { @@ -91,7 +99,7 @@ struct VirtualConnection { // app on the device. std::string local_id; std::string peer_id; - uint32_t socket_id; + int socket_id; }; inline bool operator==(const VirtualConnection& a, const VirtualConnection& b) { @@ -103,7 +111,7 @@ inline bool operator!=(const VirtualConnection& a, const VirtualConnection& b) { return !(a == b); } -} // namespace channel } // namespace cast +} // namespace openscreen #endif // CAST_COMMON_CHANNEL_VIRTUAL_CONNECTION_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.cc b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.cc index 8bac25e203a..86e06706a55 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.cc +++ b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.cc @@ -6,8 +6,8 @@ #include <type_traits> +namespace openscreen { namespace cast { -namespace channel { VirtualConnectionManager::VirtualConnectionManager() = default; @@ -15,7 +15,7 @@ VirtualConnectionManager::~VirtualConnectionManager() = default; void VirtualConnectionManager::AddConnection( VirtualConnection virtual_connection, - VirtualConnection::AssociatedData&& associated_data) { + VirtualConnection::AssociatedData associated_data) { auto& socket_map = connections_[virtual_connection.socket_id]; auto local_entries = socket_map.equal_range(virtual_connection.local_id); auto it = std::find_if( @@ -81,7 +81,7 @@ size_t VirtualConnectionManager::RemoveConnectionsByLocalId( } size_t VirtualConnectionManager::RemoveConnectionsBySocketId( - uint32_t socket_id, + int socket_id, VirtualConnection::CloseReason reason) { auto entry = connections_.find(socket_id); if (entry == connections_.end()) { @@ -115,5 +115,5 @@ VirtualConnectionManager::GetConnectionData( return absl::nullopt; } -} // namespace channel } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.h b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.h index 5bb5fdb1a39..902d2a969b0 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.h +++ b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.h @@ -12,8 +12,8 @@ #include "absl/types/optional.h" #include "cast/common/channel/virtual_connection.h" +namespace openscreen { namespace cast { -namespace channel { // Maintains a collection of open VirtualConnections and associated data. class VirtualConnectionManager { @@ -22,7 +22,7 @@ class VirtualConnectionManager { ~VirtualConnectionManager(); void AddConnection(VirtualConnection virtual_connection, - VirtualConnection::AssociatedData&& associated_data); + VirtualConnection::AssociatedData associated_data); // Returns true if a connection matching |virtual_connection| was found and // removed. @@ -32,7 +32,7 @@ class VirtualConnectionManager { // Returns the number of connections removed. size_t RemoveConnectionsByLocalId(const std::string& local_id, VirtualConnection::CloseReason reason); - size_t RemoveConnectionsBySocketId(uint32_t socket_id, + size_t RemoveConnectionsBySocketId(int socket_id, VirtualConnection::CloseReason reason); // Returns the AssociatedData for |virtual_connection| if a connection exists, @@ -50,12 +50,12 @@ class VirtualConnectionManager { VirtualConnection::AssociatedData data; }; - std::map<uint32_t /* socket_id */, + std::map<int /* socket_id */, std::multimap<std::string /* local_id */, VCTail>> connections_; }; -} // namespace channel } // namespace cast +} // namespace openscreen #endif // CAST_COMMON_CHANNEL_VIRTUAL_CONNECTION_MANAGER_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager_unittest.cc b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager_unittest.cc index 8edf105adac..963fcac728f 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager_unittest.cc @@ -8,13 +8,22 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +namespace openscreen { namespace cast { -namespace channel { namespace { -static_assert(CastMessage_ProtocolVersion_CASTV2_1_0 == +static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0 == static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_0), "V2 1.0 constants must be equal"); +static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_1 == + static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_1), + "V2 1.1 constants must be equal"); +static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2 == + static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_2), + "V2 1.2 constants must be equal"); +static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3 == + static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_3), + "V2 1.3 constants must be equal"); using ::testing::_; using ::testing::Invoke; @@ -129,5 +138,5 @@ TEST_F(VirtualConnectionManagerTest, RemoveConnectionsByIds) { EXPECT_FALSE(manager_.GetConnectionData(vc3_)); } -} // namespace channel } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.cc b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.cc index 9a5d12832f0..7ceace603d9 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.cc +++ b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.cc @@ -11,8 +11,10 @@ #include "cast/common/channel/virtual_connection_manager.h" #include "util/logging.h" +namespace openscreen { namespace cast { -namespace channel { + +using ::cast::channel::CastMessage; VirtualConnectionRouter::VirtualConnectionRouter( VirtualConnectionManager* vc_manager) @@ -35,14 +37,16 @@ bool VirtualConnectionRouter::RemoveHandlerForLocalId( void VirtualConnectionRouter::TakeSocket(SocketErrorHandler* error_handler, std::unique_ptr<CastSocket> socket) { - uint32_t id = socket->socket_id(); + int id = socket->socket_id(); socket->SetClient(this); sockets_.emplace(id, SocketWithHandler{std::move(socket), error_handler}); } -void VirtualConnectionRouter::CloseSocket(uint32_t id) { +void VirtualConnectionRouter::CloseSocket(int id) { auto it = sockets_.find(id); if (it != sockets_.end()) { + vc_manager_->RemoveConnectionsBySocketId( + id, VirtualConnection::kTransportClosed); std::unique_ptr<CastSocket> socket = std::move(it->second.socket); SocketErrorHandler* error_handler = it->second.error_handler; sockets_.erase(it); @@ -50,26 +54,27 @@ void VirtualConnectionRouter::CloseSocket(uint32_t id) { } } -Error VirtualConnectionRouter::SendMessage(VirtualConnection vconn, - CastMessage&& message) { +Error VirtualConnectionRouter::SendMessage(VirtualConnection virtual_conn, + CastMessage message) { // TODO(btolsch): Check for broadcast message. if (!IsTransportNamespace(message.namespace_()) && - !vc_manager_->GetConnectionData(vconn)) { - return Error::Code::kUnknownError; + !vc_manager_->GetConnectionData(virtual_conn)) { + return Error::Code::kNoActiveConnection; } - auto it = sockets_.find(vconn.socket_id); + auto it = sockets_.find(virtual_conn.socket_id); if (it == sockets_.end()) { - return Error::Code::kUnknownError; + return Error::Code::kItemNotFound; } - message.set_source_id(std::move(vconn.local_id)); - message.set_destination_id(std::move(vconn.peer_id)); + message.set_source_id(std::move(virtual_conn.local_id)); + message.set_destination_id(std::move(virtual_conn.peer_id)); return it->second.socket->SendMessage(message); } void VirtualConnectionRouter::OnError(CastSocket* socket, Error error) { - uint32_t id = socket->socket_id(); + int id = socket->socket_id(); auto it = sockets_.find(id); if (it != sockets_.end()) { + vc_manager_->RemoveConnectionsBySocketId(id, VirtualConnection::kUnknown); std::unique_ptr<CastSocket> socket_owned = std::move(it->second.socket); SocketErrorHandler* error_handler = it->second.error_handler; sockets_.erase(it); @@ -80,10 +85,10 @@ void VirtualConnectionRouter::OnError(CastSocket* socket, Error error) { void VirtualConnectionRouter::OnMessage(CastSocket* socket, CastMessage message) { // TODO(btolsch): Check for broadcast message. - VirtualConnection vconn{message.destination_id(), message.source_id(), - socket->socket_id()}; + VirtualConnection virtual_conn{message.destination_id(), message.source_id(), + socket->socket_id()}; if (!IsTransportNamespace(message.namespace_()) && - !vc_manager_->GetConnectionData(vconn)) { + !vc_manager_->GetConnectionData(virtual_conn)) { return; } const std::string& local_id = message.destination_id(); @@ -93,5 +98,5 @@ void VirtualConnectionRouter::OnMessage(CastSocket* socket, } } -} // namespace channel } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.h b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.h index 8ec643bd779..375c24ff50b 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.h +++ b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.h @@ -11,9 +11,10 @@ #include <string> #include "cast/common/channel/cast_socket.h" +#include "cast/common/channel/proto/cast_channel.pb.h" +namespace openscreen { namespace cast { -namespace channel { class CastMessageHandler; struct VirtualConnection; @@ -56,13 +57,15 @@ class VirtualConnectionRouter final : public CastSocket::Client { // |error_handler| must live until either its OnError or OnClose is called. void TakeSocket(SocketErrorHandler* error_handler, std::unique_ptr<CastSocket> socket); - void CloseSocket(uint32_t id); + void CloseSocket(int id); - Error SendMessage(VirtualConnection vconn, CastMessage&& message); + Error SendMessage(VirtualConnection virtual_conn, + ::cast::channel::CastMessage message); // CastSocket::Client overrides. void OnError(CastSocket* socket, Error error) override; - void OnMessage(CastSocket* socket, CastMessage message) override; + void OnMessage(CastSocket* socket, + ::cast::channel::CastMessage message) override; private: struct SocketWithHandler { @@ -71,11 +74,11 @@ class VirtualConnectionRouter final : public CastSocket::Client { }; VirtualConnectionManager* const vc_manager_; - std::map<uint32_t, SocketWithHandler> sockets_; + std::map<int, SocketWithHandler> sockets_; std::map<std::string /* local_id */, CastMessageHandler*> endpoints_; }; -} // namespace channel } // namespace cast +} // namespace openscreen #endif // CAST_COMMON_CHANNEL_VIRTUAL_CONNECTION_ROUTER_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router_unittest.cc b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router_unittest.cc index 12d82f9096d..57e9fa1b883 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router_unittest.cc @@ -6,30 +6,24 @@ #include "cast/common/channel/cast_socket.h" #include "cast/common/channel/proto/cast_channel.pb.h" -#include "cast/common/channel/test/fake_cast_socket.h" -#include "cast/common/channel/test/mock_cast_message_handler.h" +#include "cast/common/channel/testing/fake_cast_socket.h" +#include "cast/common/channel/testing/mock_cast_message_handler.h" +#include "cast/common/channel/testing/mock_socket_error_handler.h" #include "cast/common/channel/virtual_connection_manager.h" -#include "gmock/gmock.h" #include "gtest/gtest.h" +namespace openscreen { namespace cast { -namespace channel { namespace { +using ::cast::channel::CastMessage; using ::testing::_; using ::testing::Invoke; -class MockSocketErrorHandler - : public VirtualConnectionRouter::SocketErrorHandler { - public: - MOCK_METHOD(void, OnClose, (CastSocket * socket), (override)); - MOCK_METHOD(void, OnError, (CastSocket * socket, Error error), (override)); -}; - class VirtualConnectionRouterTest : public ::testing::Test { public: void SetUp() override { - socket_id_ = fake_cast_socket_pair_.socket->socket_id(); + socket_ = fake_cast_socket_pair_.socket.get(); router_.TakeSocket(&mock_error_handler_, std::move(fake_cast_socket_pair_.socket)); } @@ -38,7 +32,7 @@ class VirtualConnectionRouterTest : public ::testing::Test { CastSocket& peer_socket() { return *fake_cast_socket_pair_.peer_socket; } FakeCastSocketPair fake_cast_socket_pair_; - uint32_t socket_id_; + CastSocket* socket_; MockSocketErrorHandler mock_error_handler_; MockCastMessageHandler mock_message_handler_; @@ -52,53 +46,59 @@ class VirtualConnectionRouterTest : public ::testing::Test { TEST_F(VirtualConnectionRouterTest, LocalIdHandler) { router_.AddHandlerForLocalId("receiver-1234", &mock_message_handler_); manager_.AddConnection( - VirtualConnection{"receiver-1234", "sender-9873", socket_id_}, {}); + VirtualConnection{"receiver-1234", "sender-9873", socket_->socket_id()}, + {}); CastMessage message; - message.set_protocol_version(CastMessage_ProtocolVersion_CASTV2_1_0); + message.set_protocol_version( + ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0); message.set_namespace_("zrqvn"); message.set_source_id("sender-9873"); message.set_destination_id("receiver-1234"); message.set_payload_type(CastMessage::STRING); message.set_payload_utf8("cnlybnq"); - EXPECT_CALL(mock_message_handler_, OnMessage(_, _, _)); - peer_socket().SendMessage(message); + EXPECT_CALL(mock_message_handler_, OnMessage(_, socket_, _)); + EXPECT_TRUE(peer_socket().SendMessage(message).ok()); - EXPECT_CALL(mock_message_handler_, OnMessage(_, _, _)); - peer_socket().SendMessage(message); + EXPECT_CALL(mock_message_handler_, OnMessage(_, socket_, _)); + EXPECT_TRUE(peer_socket().SendMessage(message).ok()); message.set_destination_id("receiver-4321"); EXPECT_CALL(mock_message_handler_, OnMessage(_, _, _)).Times(0); - peer_socket().SendMessage(message); + EXPECT_TRUE(peer_socket().SendMessage(message).ok()); } TEST_F(VirtualConnectionRouterTest, RemoveLocalIdHandler) { router_.AddHandlerForLocalId("receiver-1234", &mock_message_handler_); manager_.AddConnection( - VirtualConnection{"receiver-1234", "sender-9873", socket_id_}, {}); + VirtualConnection{"receiver-1234", "sender-9873", socket_->socket_id()}, + {}); CastMessage message; - message.set_protocol_version(CastMessage_ProtocolVersion_CASTV2_1_0); + message.set_protocol_version( + ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0); message.set_namespace_("zrqvn"); message.set_source_id("sender-9873"); message.set_destination_id("receiver-1234"); message.set_payload_type(CastMessage::STRING); message.set_payload_utf8("cnlybnq"); - EXPECT_CALL(mock_message_handler_, OnMessage(_, _, _)); - peer_socket().SendMessage(message); + EXPECT_CALL(mock_message_handler_, OnMessage(_, socket_, _)); + EXPECT_TRUE(peer_socket().SendMessage(message).ok()); router_.RemoveHandlerForLocalId("receiver-1234"); - EXPECT_CALL(mock_message_handler_, OnMessage(_, _, _)).Times(0); - peer_socket().SendMessage(message); + EXPECT_CALL(mock_message_handler_, OnMessage(_, socket_, _)).Times(0); + EXPECT_TRUE(peer_socket().SendMessage(message).ok()); } TEST_F(VirtualConnectionRouterTest, SendMessage) { manager_.AddConnection( - VirtualConnection{"receiver-1234", "sender-4321", socket_id_}, {}); + VirtualConnection{"receiver-1234", "sender-4321", socket_->socket_id()}, + {}); CastMessage message; - message.set_protocol_version(CastMessage_ProtocolVersion_CASTV2_1_0); + message.set_protocol_version( + ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0); message.set_namespace_("zrqvn"); message.set_source_id("receiver-1234"); message.set_destination_id("sender-4321"); @@ -109,13 +109,25 @@ TEST_F(VirtualConnectionRouterTest, SendMessage) { EXPECT_EQ(message.namespace_(), "zrqvn"); EXPECT_EQ(message.source_id(), "receiver-1234"); EXPECT_EQ(message.destination_id(), "sender-4321"); - ASSERT_EQ(message.payload_type(), CastMessage_PayloadType_STRING); + ASSERT_EQ(message.payload_type(), + ::cast::channel::CastMessage_PayloadType_STRING); EXPECT_EQ(message.payload_utf8(), "cnlybnq"); })); router_.SendMessage( - VirtualConnection{"receiver-1234", "sender-4321", socket_id_}, + VirtualConnection{"receiver-1234", "sender-4321", socket_->socket_id()}, std::move(message)); } -} // namespace channel +TEST_F(VirtualConnectionRouterTest, CloseSocketRemovesVirtualConnections) { + manager_.AddConnection( + VirtualConnection{"receiver-1234", "sender-4321", socket_->socket_id()}, + {}); + + int id = socket_->socket_id(); + router_.CloseSocket(id); + EXPECT_FALSE(manager_.GetConnectionData( + VirtualConnection{"receiver-1234", "sender-4321", id})); +} + } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/DEPS b/chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/DEPS new file mode 100644 index 00000000000..75ecfdce80a --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/DEPS @@ -0,0 +1,3 @@ +include_rules = [ + '+platform/impl', +] diff --git a/chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/tests.cc b/chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/tests.cc new file mode 100644 index 00000000000..2a2cb62ad54 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/tests.cc @@ -0,0 +1,578 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <atomic> +#include <functional> +#include <map> + +#include "cast/common/public/service_info.h" +#include "discovery/common/config.h" +#include "discovery/common/reporting_client.h" +#include "discovery/public/dns_sd_service_factory.h" +#include "discovery/public/dns_sd_service_publisher.h" +#include "discovery/public/dns_sd_service_watcher.h" +#include "gtest/gtest.h" +#include "platform/api/logging.h" +#include "platform/api/udp_socket.h" +#include "platform/base/interface_info.h" +#include "platform/impl/logging.h" +#include "platform/impl/network_interface.h" +#include "platform/impl/platform_client_posix.h" +#include "platform/impl/task_runner.h" + +namespace openscreen { +namespace cast { +namespace { + +// Total wait time = 4 seconds. +constexpr std::chrono::milliseconds kWaitLoopSleepTime = + std::chrono::milliseconds(500); +constexpr int kMaxWaitLoopIterations = 8; + +// Total wait time = 2.5 seconds. +// NOTE: This must be less than the above wait time. +constexpr std::chrono::milliseconds kCheckLoopSleepTime = + std::chrono::milliseconds(100); +constexpr int kMaxCheckLoopIterations = 25; + +} // namespace + +// Publishes new service instances. +class Publisher : public discovery::DnsSdServicePublisher<ServiceInfo> { + public: + Publisher(discovery::DnsSdService* service) + : DnsSdServicePublisher<ServiceInfo>(service, + kCastV2ServiceId, + ServiceInfoToDnsSdRecord) { + OSP_LOG << "Initializing Publisher...\n"; + } + + ~Publisher() override = default; + + bool IsInstanceIdClaimed(const std::string& requested_id) { + auto it = + std::find(instance_ids_.begin(), instance_ids_.end(), requested_id); + return it != instance_ids_.end(); + } + + private: + // DnsSdPublisher::Client overrides. + void OnInstanceClaimed(const std::string& requested_id) override { + instance_ids_.push_back(requested_id); + } + + std::vector<std::string> instance_ids_; +}; + +// Receives incoming services and outputs their results to stdout. +class Receiver : public discovery::DnsSdServiceWatcher<ServiceInfo> { + public: + Receiver(discovery::DnsSdService* service) + : discovery::DnsSdServiceWatcher<ServiceInfo>( + service, + kCastV2ServiceId, + DnsSdRecordToServiceInfo, + [this]( + std::vector<std::reference_wrapper<const ServiceInfo>> infos) { + ProcessResults(std::move(infos)); + }) { + OSP_LOG << "Initializing Receiver..."; + } + + bool IsServiceFound(const ServiceInfo& check_service) { + return std::find_if(service_infos_.begin(), service_infos_.end(), + [&check_service](const ServiceInfo& info) { + return info.friendly_name == + check_service.friendly_name; + }) != service_infos_.end(); + } + + void EraseReceivedServices() { service_infos_.clear(); } + + private: + void ProcessResults( + std::vector<std::reference_wrapper<const ServiceInfo>> infos) { + service_infos_.clear(); + for (const ServiceInfo& info : infos) { + service_infos_.push_back(info); + } + } + + std::vector<ServiceInfo> service_infos_; +}; + +class FailOnErrorReporting : public discovery::ReportingClient { + void OnFatalError(Error error) override { + // TODO(rwkeane): Change this to OSP_NOTREACHED() pending resolution of + // socket initialization issue. + OSP_LOG << "Fatal error received: '" << error << "'"; + } + + void OnRecoverableError(Error error) override { + // Pending resolution of openscreen:105, logging recoverable errors is + // disabled, as this will end up polluting the output with logs related to + // mDNS messages received from non-loopback network interfaces over which + // we have no control. + } +}; + +discovery::Config GetConfigSettings() { + discovery::Config config; + + // Get the loopback interface to run on. + absl::optional<InterfaceInfo> loopback = GetLoopbackInterfaceForTesting(); + OSP_DCHECK(loopback.has_value()); + config.interface = loopback.value(); + + return config; +} + +class DiscoveryE2ETest : public testing::Test { + public: + DiscoveryE2ETest() { + // Sleep to let any packets clear off the network before further tests. + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + + // Set log level so info logs go to stdout. + SetLogLevel(LogLevel::kInfo); + + PlatformClientPosix::Create(Clock::duration{50}, Clock::duration{50}); + task_runner_ = PlatformClientPosix::GetInstance()->GetTaskRunner(); + } + + ~DiscoveryE2ETest() { + OSP_LOG << "TEST COMPLETE!"; + dnssd_service_.reset(); + PlatformClientPosix::ShutDown(); + } + + protected: + ServiceInfo GetInfoV4() { + ServiceInfo hosted_service; + hosted_service.v4_address = IPAddress{10, 0, 0, 1}; + hosted_service.port = 25252; + hosted_service.unique_id = "id1"; + hosted_service.model_name = "openscreen-ModelV4"; + hosted_service.friendly_name = "DemoV4!"; + return hosted_service; + } + + ServiceInfo GetInfoV6() { + ServiceInfo hosted_service; + hosted_service.v6_address = IPAddress{1, 2, 3, 4, 5, 6, 7, 8}; + hosted_service.port = 25252; + hosted_service.unique_id = "id2"; + hosted_service.model_name = "openscreen-ModelV6"; + hosted_service.friendly_name = "DemoV6!"; + return hosted_service; + } + + ServiceInfo GetInfoV4V6() { + ServiceInfo hosted_service; + hosted_service.v4_address = IPAddress{10, 0, 0, 2}; + hosted_service.v6_address = IPAddress{1, 2, 3, 4, 5, 6, 7, 9}; + hosted_service.port = 25254; + hosted_service.unique_id = "id3"; + hosted_service.model_name = "openscreen-ModelV4andV6"; + hosted_service.friendly_name = "DemoV4andV6!"; + return hosted_service; + } + + void SetUpService(const discovery::Config& config) { + OSP_DCHECK(!dnssd_service_.get()); + dnssd_service_ = + discovery::CreateDnsSdService(task_runner_, &reporting_client_, config); + receiver_ = std::make_unique<Receiver>(dnssd_service_.get()); + publisher_ = std::make_unique<Publisher>(dnssd_service_.get()); + } + + void StartDiscovery() { + OSP_DCHECK(dnssd_service_.get()); + task_runner_->PostTask([this]() { receiver_->StartDiscovery(); }); + } + + template <typename... RecordTypes> + void UpdateRecords(RecordTypes... records) { + OSP_DCHECK(dnssd_service_.get()); + OSP_DCHECK(publisher_.get()); + + std::vector<ServiceInfo> record_set{std::move(records)...}; + for (ServiceInfo& record : record_set) { + task_runner_->PostTask([this, r = std::move(record)]() { + auto error = publisher_->UpdateRegistration(r); + OSP_DCHECK(error.ok()) << "\tFailed to update service instance '" + << r.friendly_name << "': " << error << "!"; + }); + } + } + + template <typename... RecordTypes> + void PublishRecords(RecordTypes... records) { + OSP_DCHECK(dnssd_service_.get()); + OSP_DCHECK(publisher_.get()); + + std::vector<ServiceInfo> record_set{std::move(records)...}; + for (ServiceInfo& record : record_set) { + task_runner_->PostTask([this, r = std::move(record)]() { + auto error = publisher_->Register(r); + OSP_DCHECK(error.ok()) << "\tFailed to publish service instance '" + << r.friendly_name << "': " << error << "!"; + }); + } + } + + template <typename... AtomicBoolPtrs> + void WaitUntilSeen(bool should_be_seen, AtomicBoolPtrs... bools) { + OSP_DCHECK(dnssd_service_.get()); + std::vector<std::atomic_bool*> atomic_bools{bools...}; + + int waiting_on = atomic_bools.size(); + for (int i = 0; i < kMaxWaitLoopIterations; i++) { + waiting_on = atomic_bools.size(); + for (std::atomic_bool* atomic : atomic_bools) { + if (*atomic) { + OSP_DCHECK(should_be_seen) << "Found service instance!"; + waiting_on--; + } + } + + if (waiting_on) { + OSP_LOG << "\tWaiting on " << waiting_on << "..."; + std::this_thread::sleep_for(kWaitLoopSleepTime); + continue; + } + return; + } + OSP_DCHECK(!should_be_seen) + << "Could not find " << waiting_on << " service instances!"; + } + + void CheckForClaimedIds(ServiceInfo service_info, + std::atomic_bool* has_been_seen) { + OSP_DCHECK(dnssd_service_.get()); + task_runner_->PostTask( + [this, info = std::move(service_info), has_been_seen]() mutable { + CheckForClaimedIds(std::move(info), has_been_seen, 0); + }); + } + + void CheckForPublishedService(ServiceInfo service_info, + std::atomic_bool* has_been_seen) { + OSP_DCHECK(dnssd_service_.get()); + task_runner_->PostTask( + [this, info = std::move(service_info), has_been_seen]() mutable { + CheckForPublishedService(std::move(info), has_been_seen, 0, true); + }); + } + + void CheckNotPublishedService(ServiceInfo service_info, + std::atomic_bool* has_been_seen) { + OSP_DCHECK(dnssd_service_.get()); + task_runner_->PostTask( + [this, info = std::move(service_info), has_been_seen]() mutable { + CheckForPublishedService(std::move(info), has_been_seen, 0, false); + }); + } + TaskRunner* task_runner_; + FailOnErrorReporting reporting_client_; + SerialDeletePtr<discovery::DnsSdService> dnssd_service_; + std::unique_ptr<Receiver> receiver_; + std::unique_ptr<Publisher> publisher_; + + private: + void CheckForClaimedIds(ServiceInfo service_info, + std::atomic_bool* has_been_seen, + int attempts) { + if (publisher_->IsInstanceIdClaimed(service_info.GetInstanceId())) { + // TODO(crbug.com/openscreen/110): Log the published service instance. + *has_been_seen = true; + return; + } + + if (attempts++ > kMaxCheckLoopIterations) { + OSP_NOTREACHED() << "Service " << service_info.friendly_name + << " publication failed."; + } + task_runner_->PostTaskWithDelay( + [this, info = std::move(service_info), has_been_seen, + attempts]() mutable { + CheckForClaimedIds(std::move(info), has_been_seen, attempts); + }, + kCheckLoopSleepTime); + } + + void CheckForPublishedService(ServiceInfo service_info, + std::atomic_bool* has_been_seen, + int attempts, + bool expect_to_be_present) { + if (!receiver_->IsServiceFound(service_info)) { + if (attempts++ > kMaxCheckLoopIterations) { + OSP_DCHECK(!expect_to_be_present) + << "Service " << service_info.friendly_name << " discovery failed."; + return; + } + task_runner_->PostTaskWithDelay( + [this, info = std::move(service_info), has_been_seen, attempts, + expect_to_be_present]() mutable { + CheckForPublishedService(std::move(info), has_been_seen, attempts, + expect_to_be_present); + }, + kCheckLoopSleepTime); + } else if (expect_to_be_present) { + // TODO(crbug.com/openscreen/110): Log the discovered service instance. + *has_been_seen = true; + } else { + OSP_NOTREACHED() << "Found instance '" << service_info.friendly_name + << "'!"; + } + } +}; + +// The below runs an E2E tests. These test requires no user interaction and is +// intended to perform a set series of actions to validate that discovery is +// functioning as intended. +// +// Known issues: +// - The ipv6 socket in discovery/mdns/service_impl.cc fails to bind to an ipv6 +// address on the loopback interface. Investigating this issue is pending +// resolution of bug +// https://bugs.chromium.org/p/openscreen/issues/detail?id=105. +// +// In this test, the following operations are performed: +// 1) Start up the Cast platform for a posix system. +// 2) Publish 3 CastV2 service instances to the loopback interface using mDNS, +// with record announcement disabled. +// 3) Wait for the probing phase to successfully complete. +// 4) Query for records published over the loopback interface, and validate that +// all 3 previously published services are discovered. +TEST_F(DiscoveryE2ETest, ValidateQueryFlow) { + // Set up demo infra. + auto discovery_config = GetConfigSettings(); + discovery_config.new_record_announcement_count = 0; + SetUpService(discovery_config); + + auto v4 = GetInfoV4(); + auto v6 = GetInfoV6(); + auto multi_address = GetInfoV4V6(); + + // Start discovery and publication. + StartDiscovery(); + PublishRecords(v4, v6, multi_address); + + // Wait until all probe phases complete and all instance ids are claimed. At + // this point, all records should be published. + OSP_LOG << "Service publication in progress..."; + std::atomic_bool v4_found{false}; + std::atomic_bool v6_found{false}; + std::atomic_bool multi_address_found{false}; + CheckForClaimedIds(v4, &v4_found); + CheckForClaimedIds(v6, &v6_found); + CheckForClaimedIds(multi_address, &multi_address_found); + WaitUntilSeen(true, &v4_found, &v6_found, &multi_address_found); + OSP_LOG << "\tAll services successfully published!\n"; + + // Make sure all services are found through discovery. + OSP_LOG << "Service discovery in progress..."; + v4_found = false; + v6_found = false; + multi_address_found = false; + CheckForPublishedService(v4, &v4_found); + CheckForPublishedService(v6, &v6_found); + CheckForPublishedService(multi_address, &multi_address_found); + WaitUntilSeen(true, &v4_found, &v6_found, &multi_address_found); +} + +// In this test, the following operations are performed: +// 1) Start up the Cast platform for a posix system. +// 2) Start service discovery and new queries, with no query messages being +// sent. +// 3) Publish 3 CastV2 service instances to the loopback interface using mDNS, +// with record announcement enabled. +// 4) Ensure the correct records were published over the loopback interface. +// 5) De-register all services. +// 6) Ensure that goodbye records are received for all service instances. +TEST_F(DiscoveryE2ETest, ValidateAnnouncementFlow) { + // Set up demo infra. + auto discovery_config = GetConfigSettings(); + discovery_config.new_query_announcement_count = 0; + SetUpService(discovery_config); + + auto v4 = GetInfoV4(); + auto v6 = GetInfoV6(); + auto multi_address = GetInfoV4V6(); + + // Start discovery and publication. + StartDiscovery(); + PublishRecords(v4, v6, multi_address); + + // Wait until all probe phases complete and all instance ids are claimed. At + // this point, all records should be published. + OSP_LOG << "Service publication in progress..."; + std::atomic_bool v4_found{false}; + std::atomic_bool v6_found{false}; + std::atomic_bool multi_address_found{false}; + CheckForClaimedIds(v4, &v4_found); + CheckForClaimedIds(v6, &v6_found); + CheckForClaimedIds(multi_address, &multi_address_found); + WaitUntilSeen(true, &v4_found, &v6_found, &multi_address_found); + OSP_LOG << "\tAll services successfully published and announced!\n"; + + // Make sure all services are found through discovery. + OSP_LOG << "Service discovery in progress..."; + v4_found = false; + v6_found = false; + multi_address_found = false; + CheckForPublishedService(v4, &v4_found); + CheckForPublishedService(v6, &v6_found); + CheckForPublishedService(multi_address, &multi_address_found); + WaitUntilSeen(true, &v4_found, &v6_found, &multi_address_found); +} + +// In this test, the following operations are performed: +// 1) Start up the Cast platform for a posix system. +// 2) Publish one service and ensure it is NOT received. +// 3) Start service discovery and new queries. +// 4) Ensure above published service IS received. +// 5) Stop the started query. +// 6) Update a service, and ensure that no callback is received. +// 7) Restart the query and ensure that only the expected callbacks are +// received. +TEST_F(DiscoveryE2ETest, ValidateRecordsOnlyReceivedWhenQueryRunning) { + // Set up demo infra. + auto discovery_config = GetConfigSettings(); + discovery_config.new_record_announcement_count = 1; + SetUpService(discovery_config); + + auto v4 = GetInfoV4(); + + // Start discovery and publication. + PublishRecords(v4); + + // Wait until all probe phases complete and all instance ids are claimed. At + // this point, all records should be published. + OSP_LOG << "Service publication in progress..."; + std::atomic_bool v4_found{false}; + CheckForClaimedIds(v4, &v4_found); + WaitUntilSeen(true, &v4_found); + + // And ensure stopped discovery does not find the records. + OSP_LOG << "Validating no service discovery occurs when discovery stopped..."; + v4_found = false; + CheckNotPublishedService(v4, &v4_found); + WaitUntilSeen(false, &v4_found); + + // Make sure all services are found through discovery. + StartDiscovery(); + OSP_LOG << "Service discovery in progress..."; + v4_found = false; + CheckForPublishedService(v4, &v4_found); + WaitUntilSeen(true, &v4_found); + + // Update discovery and ensure that the updated service is seen. + OSP_LOG << "Updating service and waiting for discovery..."; + auto updated_v4 = v4; + updated_v4.friendly_name = "OtherName"; + v4_found = false; + UpdateRecords(updated_v4); + CheckForPublishedService(updated_v4, &v4_found); + WaitUntilSeen(true, &v4_found); + + // And ensure the old service has been removed. + v4_found = false; + CheckNotPublishedService(v4, &v4_found); + WaitUntilSeen(false, &v4_found); + + // Stop discovery. + OSP_LOG << "Stopping discovery..."; + task_runner_->PostTask([this]() { receiver_->StopDiscovery(); }); + + // Update discovery and ensure that the updated service is NOT seen. + OSP_LOG << "Updating service and validating the change isn't received..."; + v4_found = false; + v4.friendly_name = "ThirdName"; + UpdateRecords(v4); + CheckNotPublishedService(v4, &v4_found); + WaitUntilSeen(false, &v4_found); + + // Restart discovery and ensure that only the updated record is returned. + StartDiscovery(); + OSP_LOG << "Service discovery in progress..."; + v4_found = false; + CheckNotPublishedService(updated_v4, &v4_found); + WaitUntilSeen(false, &v4_found); + + v4_found = false; + CheckForPublishedService(v4, &v4_found); + WaitUntilSeen(true, &v4_found); +} + +// In this test, the following operations are performed: +// 1) Start up the Cast platform for a posix system. +// 2) Start service discovery and new queries. +// 3) Publish one service and ensure it is received. +// 4) Hard reset discovery +// 5) Ensure the same service is discovered +// 6) Soft reset the service, and ensure that a callback is received. +TEST_F(DiscoveryE2ETest, ValidateRefreshFlow) { + // Set up demo infra. + // NOTE: This configuration assumes that packets cannot be lost over the + // loopback interface. + auto discovery_config = GetConfigSettings(); + discovery_config.new_record_announcement_count = 0; + discovery_config.new_query_announcement_count = 2; + constexpr std::chrono::seconds kMaxQueryDuration{3}; + SetUpService(discovery_config); + + auto v4 = GetInfoV4(); + + // Start discovery and publication. + StartDiscovery(); + PublishRecords(v4); + + // Wait until all probe phases complete and all instance ids are claimed. At + // this point, all records should be published. + OSP_LOG << "Service publication in progress..."; + std::atomic_bool v4_found{false}; + CheckForClaimedIds(v4, &v4_found); + WaitUntilSeen(true, &v4_found); + + // Make sure all services are found through discovery. + OSP_LOG << "Service discovery in progress..."; + v4_found = false; + CheckForPublishedService(v4, &v4_found); + WaitUntilSeen(true, &v4_found); + + // Force refresh discovery, then ensure that the published service is + // re-discovered. + OSP_LOG << "Force refresh discovery..."; + task_runner_->PostTask([this]() { receiver_->EraseReceivedServices(); }); + std::this_thread::sleep_for(kMaxQueryDuration); + v4_found = false; + CheckNotPublishedService(v4, &v4_found); + WaitUntilSeen(false, &v4_found); + task_runner_->PostTask([this]() { receiver_->ForceRefresh(); }); + + OSP_LOG << "Ensure that the published service is re-discovered..."; + v4_found = false; + CheckForPublishedService(v4, &v4_found); + WaitUntilSeen(true, &v4_found); + + // Soft refresh discovery, then ensure that the published service is NOT + // re-discovered. + OSP_LOG << "Call DiscoverNow on discovery..."; + task_runner_->PostTask([this]() { receiver_->EraseReceivedServices(); }); + std::this_thread::sleep_for(kMaxQueryDuration); + v4_found = false; + CheckNotPublishedService(v4, &v4_found); + WaitUntilSeen(false, &v4_found); + task_runner_->PostTask([this]() { receiver_->DiscoverNow(); }); + + OSP_LOG << "Ensure that the published service is re-discovered..."; + v4_found = false; + CheckForPublishedService(v4, &v4_found); + WaitUntilSeen(true, &v4_found); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/public/DEPS b/chromium/third_party/openscreen/src/cast/common/public/DEPS index 7060c48007f..c098d4db36a 100644 --- a/chromium/third_party/openscreen/src/cast/common/public/DEPS +++ b/chromium/third_party/openscreen/src/cast/common/public/DEPS @@ -1,12 +1,8 @@ # -*- Mode: Python; -*- include_rules = [ - # By default, openscreen implementation libraries should not be exposed - # through public APIs. - '-base', - '-platform', - # Dependencies on the implementation are not allowed in public/. '-cast/common', - '+cast/common/public' + '+cast/common/public', + '+discovery/dnssd/public' ] diff --git a/chromium/third_party/openscreen/src/cast/common/public/service_info.cc b/chromium/third_party/openscreen/src/cast/common/public/service_info.cc new file mode 100644 index 00000000000..a886f7d2c6b --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/public/service_info.cc @@ -0,0 +1,243 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/public/service_info.h" + +#include <cctype> +#include <memory> +#include <string> +#include <vector> + +#include "absl/strings/numbers.h" +#include "absl/strings/str_replace.h" +#include "util/logging.h" + +namespace openscreen { +namespace cast { +namespace { + +// The mask for the set of supported capabilities in v2 of the Cast protocol. +const uint16_t kCapabilitiesMask = 0x1F; + +// Maximum size for registered MDNS service instance names. +const size_t kMaxDeviceNameSize = 63; + +// Maximum size for the device model prefix at start of MDNS service instance +// names. Any model names that are larger than this size will be truncated. +const size_t kMaxDeviceModelSize = 20; + +// Build the MDNS instance name for service. This will be the device model (up +// to 20 bytes) appended with the virtual device ID (device UUID) and optionally +// appended with extension at the end to resolve name conflicts. The total MDNS +// service instance name is kept below 64 bytes so it can easily fit into a +// single domain name label. +// +// NOTE: This value is based on what is currently done by Eureka, not what is +// called out in the CastV2 spec. Eureka uses |model|-|uuid|, so the same +// convention will be followed here. That being said, the Eureka receiver does +// not use the instance ID in any way, so the specific calculation used should +// not be important. +std::string CalculateInstanceId(const ServiceInfo& info) { + // First set the device model, truncated to 20 bytes at most. Replace any + // whitespace characters (" ") with hyphens ("-") in the device model before + // truncation. + std::string instance_name = + absl::StrReplaceAll(info.model_name, {{" ", "-"}}); + instance_name = std::string(instance_name, 0, kMaxDeviceModelSize); + + // Append the virtual device ID to the instance name separated by a single + // '-' character if not empty. Strip all hyphens from the device ID prior + // to appending it. + std::string device_id = absl::StrReplaceAll(info.unique_id, {{"-", ""}}); + + if (!instance_name.empty()) { + instance_name.push_back('-'); + } + instance_name.append(device_id); + + return std::string(instance_name, 0, kMaxDeviceNameSize); +} + +// NOTE: Eureka uses base::StringToUint64 which takes in a string and reads it +// left to right, converts it to a number sequence, and uses the sequence to +// calculate the resulting integer. This process assumes that the input is in +// base 10. For example, ['1', '2', '3'] converts to 123. +// +// The below 2 functions re-create this logic for converting to and from this +// encoding scheme. +inline std::string EncodeIntegerString(uint64_t value) { + return std::to_string(value); +} + +ErrorOr<uint64_t> DecodeIntegerString(const std::string& value) { + uint64_t result; + if (!absl::SimpleAtoi(value, &result)) { + return Error::Code::kParameterInvalid; + } + + return result; +} + +// Attempts to parse the string present at the provided |key| in the TXT record +// |txt|, placing the result into |result| on success and error into |error| on +// failure. +bool TryParseString(const discovery::DnsSdTxtRecord& txt, + const std::string& key, + Error* error, + std::string* result) { + const ErrorOr<discovery::DnsSdTxtRecord::ValueRef> value = txt.GetValue(key); + if (value.is_error()) { + *error = value.error(); + return false; + } + + const std::vector<uint8_t>& txt_value = value.value().get(); + *result = std::string(txt_value.begin(), txt_value.end()); + return true; +} +// Attempts to parse the uint8_t present at the provided |key| in the TXT record +// |txt|, placing the result into |result| on success and error into |error| on +// failure. +bool TryParseInt(const discovery::DnsSdTxtRecord& txt, + const std::string& key, + Error* error, + uint8_t* result) { + const ErrorOr<discovery::DnsSdTxtRecord::ValueRef> value = txt.GetValue(key); + if (value.is_error()) { + *error = value.error(); + return false; + } + + const std::vector<uint8_t>& txt_value = value.value().get(); + if (txt_value.size() != 1) { + *error = Error::Code::kParameterInvalid; + return false; + } + + *result = txt_value[0]; + return true; +} + +// Simplifies logic below by changing error into an output parameter instead of +// a return value. +bool IsError(Error error, Error* result) { + if (error.ok()) { + return false; + } else { + *result = error; + return true; + } +} + +} // namespace + +const std::string& ServiceInfo::GetInstanceId() const { + if (instance_id_ == std::string("")) { + instance_id_ = CalculateInstanceId(*this); + } + + return instance_id_; +} + +bool ServiceInfo::IsValid() const { + std::string instance_id = GetInstanceId(); + if (!discovery::IsInstanceValid(instance_id)) { + return false; + } + + const std::string capabilities_str = EncodeIntegerString(capabilities); + if (!discovery::DnsSdTxtRecord::IsValidTxtValue(kUniqueIdKey, unique_id) || + !discovery::DnsSdTxtRecord::IsValidTxtValue(kVersionId, + protocol_version) || + !discovery::DnsSdTxtRecord::IsValidTxtValue(kCapabilitiesId, + capabilities_str) || + !discovery::DnsSdTxtRecord::IsValidTxtValue(kStatusId, status) || + !discovery::DnsSdTxtRecord::IsValidTxtValue(kFriendlyNameId, + friendly_name) || + !discovery::DnsSdTxtRecord::IsValidTxtValue(kModelNameId, model_name)) { + return false; + } + + return port && (v4_address || v6_address); +} + +discovery::DnsSdInstanceRecord ServiceInfoToDnsSdRecord( + const ServiceInfo& service) { + OSP_DCHECK(discovery::IsServiceValid(kCastV2ServiceId)); + OSP_DCHECK(discovery::IsDomainValid(kCastV2DomainId)); + + std::string instance_id = service.GetInstanceId(); + OSP_DCHECK(discovery::IsInstanceValid(instance_id)); + + const std::string capabilities_str = + EncodeIntegerString(service.capabilities); + + discovery::DnsSdTxtRecord txt; + Error error; + const bool set_txt = + !IsError(txt.SetValue(kUniqueIdKey, service.unique_id), &error) && + !IsError(txt.SetValue(kVersionId, service.protocol_version), &error) && + !IsError(txt.SetValue(kCapabilitiesId, capabilities_str), &error) && + !IsError(txt.SetValue(kStatusId, service.status), &error) && + !IsError(txt.SetValue(kFriendlyNameId, service.friendly_name), &error) && + !IsError(txt.SetValue(kModelNameId, service.model_name), &error); + OSP_DCHECK(set_txt); + + OSP_DCHECK(service.port); + OSP_DCHECK(service.v4_address || service.v6_address); + if (service.v4_address && service.v6_address) { + return discovery::DnsSdInstanceRecord( + instance_id, kCastV2ServiceId, kCastV2DomainId, + {service.v4_address, service.port}, {service.v6_address, service.port}, + std::move(txt)); + } else { + const IPAddress& address = + service.v4_address ? service.v4_address : service.v6_address; + return discovery::DnsSdInstanceRecord( + instance_id, kCastV2ServiceId, kCastV2DomainId, {address, service.port}, + std::move(txt)); + } +} + +ErrorOr<ServiceInfo> DnsSdRecordToServiceInfo( + const discovery::DnsSdInstanceRecord& instance) { + if (instance.service_id() != kCastV2ServiceId) { + return Error::Code::kParameterInvalid; + } + + ServiceInfo record; + record.v4_address = instance.address_v4().address; + record.v6_address = instance.address_v6().address; + record.port = instance.address_v4().port ? instance.address_v4().port + : instance.address_v6().port; + + const auto& txt = instance.txt(); + std::string capabilities_base64; + std::string unique_id; + uint8_t status; + Error error; + if (!TryParseInt(txt, kVersionId, &error, &record.protocol_version) || + !TryParseInt(txt, kStatusId, &error, &status) || + !TryParseString(txt, kFriendlyNameId, &error, &record.friendly_name) || + !TryParseString(txt, kModelNameId, &error, &record.model_name) || + !TryParseString(txt, kCapabilitiesId, &error, &capabilities_base64) || + !TryParseString(txt, kUniqueIdKey, &error, &record.unique_id)) { + return error; + } + + record.status = static_cast<ReceiverStatus>(status); + + const ErrorOr<uint64_t> capabilities_flags = + DecodeIntegerString(capabilities_base64); + if (capabilities_flags.is_error()) { + return capabilities_flags.error(); + } + record.capabilities = static_cast<ReceiverCapabilities>( + capabilities_flags.value() & kCapabilitiesMask); + + return record; +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/public/service_info.h b/chromium/third_party/openscreen/src/cast/common/public/service_info.h new file mode 100644 index 00000000000..a2532440df0 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/public/service_info.h @@ -0,0 +1,124 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_COMMON_PUBLIC_SERVICE_INFO_H_ +#define CAST_COMMON_PUBLIC_SERVICE_INFO_H_ + +#include <memory> +#include <string> + +#include "discovery/dnssd/public/dns_sd_instance_record.h" +#include "platform/base/ip_address.h" + +namespace openscreen { +namespace cast { + +// Constants to identify a CastV2 instance with DNS-SD. +static constexpr char kCastV2ServiceId[] = "_googlecast._tcp"; +static constexpr char kCastV2DomainId[] = "local"; + +// Constants to be used as keys when storing data inside of a DNS-SD TXT record. +static constexpr char kUniqueIdKey[] = "id"; +static constexpr char kVersionId[] = "ve"; +static constexpr char kCapabilitiesId[] = "ca"; +static constexpr char kStatusId[] = "st"; +static constexpr char kFriendlyNameId[] = "fn"; +static constexpr char kModelNameId[] = "mn"; + +// This represents the ‘st’ flag in the CastV2 TXT record. +enum ReceiverStatus { + // The receiver is idle and does not need to be connected now. + kIdle = 0, + + // The receiver is hosting an activity and invites the sender to join. The + // receiver should connect to the running activity using the channel + // establishment protocol, and then query the activity to determine the next + // step, such as showing a description of the activity and prompting the user + // to launch the corresponding app. + kBusy = 1, + kJoin = kBusy +}; + +// This represents the ‘ca’ field in the CastV2 spec. +enum ReceiverCapabilities : uint64_t { + kNone = 0x00, + kHasVideoOutput = 0x01 << 0, + kHasVideoInput = 0x01 << 1, + kHasAudioOutput = 0x01 << 2, + kHasAudioInput = 0x01 << 3, + kIsDevModeEnabled = 0x01 << 4, +}; + +static constexpr uint8_t kCurrentCastVersion = 2; +static constexpr ReceiverCapabilities kDefaultCapabilities = + ReceiverCapabilities::kNone; + +// This is the top-level service info class for CastV2. It describes a specific +// service instance. +// TODO(crbug.com/openscreen/112): Rename this to CastReceiverInfo or similar. +struct ServiceInfo { + // returns the instance id associated with this ServiceInfo instance. + const std::string& GetInstanceId() const; + + // Returns whether all fields of this ServiceInfo are valid. + bool IsValid() const; + + // Addresses for the service. Present if an address of this address type + // exists and empty otherwise. + IPAddress v4_address; + IPAddress v6_address; + + // Port at which this service can be reached. + uint16_t port; + + // A UUID for the Cast receiver. This should be a universally unique + // identifier for the receiver, and should (but does not have to be) be stable + // across factory resets. + std::string unique_id; + + // Cast protocol version supported. Begins at 2 and is incremented by 1 with + // each version. + uint8_t protocol_version = kCurrentCastVersion; + + // Capabilities supported by this service instance. + ReceiverCapabilities capabilities = kDefaultCapabilities; + + // Status of the service instance. + ReceiverStatus status = ReceiverStatus::kIdle; + + // The model name of the device, e.g. “Eureka v1”, “Mollie”. + std::string model_name; + + // The friendly name of the device, e.g. “Living Room TV". + std::string friendly_name; + + private: + mutable std::string instance_id_ = ""; +}; + +inline bool operator==(const ServiceInfo& lhs, const ServiceInfo& rhs) { + return lhs.v4_address == rhs.v4_address && lhs.v6_address == rhs.v6_address && + lhs.port == rhs.port && lhs.unique_id == rhs.unique_id && + lhs.protocol_version == rhs.protocol_version && + lhs.capabilities == rhs.capabilities && lhs.status == rhs.status && + lhs.model_name == rhs.model_name && + lhs.friendly_name == rhs.friendly_name; +} + +inline bool operator!=(const ServiceInfo& lhs, const ServiceInfo& rhs) { + return !(lhs == rhs); +} + +// Functions responsible for converting between CastV2 and DNS-SD +// representations of a service instance. +discovery::DnsSdInstanceRecord ServiceInfoToDnsSdRecord( + const ServiceInfo& service); + +ErrorOr<ServiceInfo> DnsSdRecordToServiceInfo( + const discovery::DnsSdInstanceRecord& service); + +} // namespace cast +} // namespace openscreen + +#endif // CAST_COMMON_PUBLIC_SERVICE_INFO_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/public/service_info_unittest.cc b/chromium/third_party/openscreen/src/cast/common/public/service_info_unittest.cc new file mode 100644 index 00000000000..291f3526138 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/public/service_info_unittest.cc @@ -0,0 +1,209 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/public/service_info.h" + +#include "cast/common/public/testing/discovery_utils.h" +#include "discovery/dnssd/public/dns_sd_instance_record.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace openscreen { +namespace cast { + +TEST(ServiceInfoTests, ConvertValidFromDnsSd) { + std::string instance = "InstanceId"; + discovery::DnsSdTxtRecord txt = CreateValidTxt(); + discovery::DnsSdInstanceRecord record(instance, kCastV2ServiceId, + kCastV2DomainId, kEndpointV4, + kEndpointV6, txt); + ErrorOr<ServiceInfo> info = DnsSdRecordToServiceInfo(record); + ASSERT_TRUE(info.is_value()); + EXPECT_EQ(info.value().unique_id, kTestUniqueId); + EXPECT_TRUE(info.value().v4_address); + EXPECT_EQ(info.value().v4_address, kAddressV4); + EXPECT_TRUE(info.value().v6_address); + EXPECT_EQ(info.value().v6_address, kAddressV6); + EXPECT_EQ(info.value().port, kPort); + EXPECT_EQ(info.value().unique_id, kTestUniqueId); + EXPECT_EQ(info.value().protocol_version, kTestVersion); + EXPECT_EQ(info.value().capabilities, kCapabilitiesParsed); + EXPECT_EQ(info.value().status, kStatusParsed); + EXPECT_EQ(info.value().model_name, kModelName); + EXPECT_EQ(info.value().friendly_name, kFriendlyName); + + record = discovery::DnsSdInstanceRecord(instance, kCastV2ServiceId, + kCastV2DomainId, kEndpointV4, txt); + ASSERT_FALSE(record.address_v6()); + info = DnsSdRecordToServiceInfo(record); + ASSERT_TRUE(info.is_value()); + EXPECT_EQ(info.value().unique_id, kTestUniqueId); + EXPECT_TRUE(info.value().v4_address); + EXPECT_EQ(info.value().v4_address, kAddressV4); + EXPECT_FALSE(info.value().v6_address); + EXPECT_EQ(info.value().port, kPort); + EXPECT_EQ(info.value().unique_id, kTestUniqueId); + EXPECT_EQ(info.value().protocol_version, kTestVersion); + EXPECT_EQ(info.value().capabilities, kCapabilitiesParsed); + EXPECT_EQ(info.value().status, kStatusParsed); + EXPECT_EQ(info.value().model_name, kModelName); + EXPECT_EQ(info.value().friendly_name, kFriendlyName); + + record = discovery::DnsSdInstanceRecord(instance, kCastV2ServiceId, + kCastV2DomainId, kEndpointV6, txt); + ASSERT_FALSE(record.address_v4()); + info = DnsSdRecordToServiceInfo(record); + ASSERT_TRUE(info.is_value()); + EXPECT_EQ(info.value().unique_id, kTestUniqueId); + EXPECT_FALSE(info.value().v4_address); + EXPECT_TRUE(info.value().v6_address); + EXPECT_EQ(info.value().v6_address, kAddressV6); + EXPECT_EQ(info.value().unique_id, kTestUniqueId); + EXPECT_EQ(info.value().protocol_version, kTestVersion); + EXPECT_EQ(info.value().capabilities, kCapabilitiesParsed); + EXPECT_EQ(info.value().status, kStatusParsed); + EXPECT_EQ(info.value().model_name, kModelName); + EXPECT_EQ(info.value().friendly_name, kFriendlyName); +} + +TEST(ServiceInfoTests, ConvertInvalidFromDnsSd) { + std::string instance = "InstanceId"; + discovery::DnsSdTxtRecord txt = CreateValidTxt(); + txt.ClearValue(kUniqueIdKey); + discovery::DnsSdInstanceRecord record(instance, kCastV2ServiceId, + kCastV2DomainId, kEndpointV4, + kEndpointV6, txt); + EXPECT_TRUE(DnsSdRecordToServiceInfo(record).is_error()); + + txt = CreateValidTxt(); + txt.ClearValue(kVersionId); + record = discovery::DnsSdInstanceRecord(instance, kCastV2ServiceId, + kCastV2DomainId, kEndpointV4, + kEndpointV6, txt); + EXPECT_TRUE(DnsSdRecordToServiceInfo(record).is_error()); + + txt = CreateValidTxt(); + txt.ClearValue(kCapabilitiesId); + record = discovery::DnsSdInstanceRecord(instance, kCastV2ServiceId, + kCastV2DomainId, kEndpointV4, + kEndpointV6, txt); + EXPECT_TRUE(DnsSdRecordToServiceInfo(record).is_error()); + + txt = CreateValidTxt(); + txt.ClearValue(kStatusId); + record = discovery::DnsSdInstanceRecord(instance, kCastV2ServiceId, + kCastV2DomainId, kEndpointV4, + kEndpointV6, txt); + EXPECT_TRUE(DnsSdRecordToServiceInfo(record).is_error()); + + txt = CreateValidTxt(); + txt.ClearValue(kFriendlyNameId); + record = discovery::DnsSdInstanceRecord(instance, kCastV2ServiceId, + kCastV2DomainId, kEndpointV4, + kEndpointV6, txt); + EXPECT_TRUE(DnsSdRecordToServiceInfo(record).is_error()); + + txt = CreateValidTxt(); + txt.ClearValue(kModelNameId); + record = discovery::DnsSdInstanceRecord(instance, kCastV2ServiceId, + kCastV2DomainId, kEndpointV4, + kEndpointV6, txt); + EXPECT_TRUE(DnsSdRecordToServiceInfo(record).is_error()); +} + +TEST(ServiceInfoTests, ConvertValidToDnsSd) { + ServiceInfo info; + info.v4_address = kAddressV4; + info.v6_address = kAddressV6; + info.port = kPort; + info.unique_id = kTestUniqueId; + info.protocol_version = kTestVersion; + info.capabilities = kCapabilitiesParsed; + info.status = kStatusParsed; + info.model_name = kModelName; + info.friendly_name = kFriendlyName; + discovery::DnsSdInstanceRecord record = ServiceInfoToDnsSdRecord(info); + EXPECT_EQ(record.instance_id(), kInstanceId); + EXPECT_TRUE(record.address_v4()); + EXPECT_EQ(record.address_v4(), kEndpointV4); + EXPECT_TRUE(record.address_v6()); + EXPECT_EQ(record.address_v6(), kEndpointV6); + CompareTxtString(record.txt(), kUniqueIdKey, kTestUniqueId); + CompareTxtString(record.txt(), kCapabilitiesId, kCapabilitiesString); + CompareTxtString(record.txt(), kModelNameId, kModelName); + CompareTxtString(record.txt(), kFriendlyNameId, kFriendlyName); + CompareTxtInt(record.txt(), kVersionId, kTestVersion); + CompareTxtInt(record.txt(), kStatusId, kStatus); + + info.v6_address = IPAddress{}; + record = ServiceInfoToDnsSdRecord(info); + EXPECT_TRUE(record.address_v4()); + EXPECT_EQ(record.address_v4(), kEndpointV4); + EXPECT_FALSE(record.address_v6()); + CompareTxtString(record.txt(), kUniqueIdKey, kTestUniqueId); + CompareTxtString(record.txt(), kCapabilitiesId, kCapabilitiesString); + CompareTxtString(record.txt(), kModelNameId, kModelName); + CompareTxtString(record.txt(), kFriendlyNameId, kFriendlyName); + CompareTxtInt(record.txt(), kVersionId, kTestVersion); + CompareTxtInt(record.txt(), kStatusId, kStatus); + + info.v6_address = kAddressV6; + info.v4_address = IPAddress{}; + record = ServiceInfoToDnsSdRecord(info); + EXPECT_FALSE(record.address_v4()); + EXPECT_TRUE(record.address_v6()); + EXPECT_EQ(record.address_v6(), kEndpointV6); + CompareTxtString(record.txt(), kUniqueIdKey, kTestUniqueId); + CompareTxtString(record.txt(), kCapabilitiesId, kCapabilitiesString); + CompareTxtString(record.txt(), kModelNameId, kModelName); + CompareTxtString(record.txt(), kFriendlyNameId, kFriendlyName); + CompareTxtInt(record.txt(), kVersionId, kTestVersion); + CompareTxtInt(record.txt(), kStatusId, kStatus); +} + +TEST(ServiceInfoTests, ConvertInvalidToDnsSd) { + ServiceInfo info; + info.unique_id = kTestUniqueId; + info.protocol_version = kTestVersion; + info.capabilities = kCapabilitiesParsed; + info.status = kStatusParsed; + info.model_name = kModelName; + info.friendly_name = kFriendlyName; + EXPECT_FALSE(info.IsValid()); +} + +TEST(ServiceInfoTests, IdentityChecks) { + ServiceInfo info; + info.v4_address = kAddressV4; + info.v6_address = kAddressV6; + info.port = kPort; + info.unique_id = kTestUniqueId; + info.protocol_version = kTestVersion; + info.capabilities = kCapabilitiesParsed; + info.status = kStatusParsed; + info.model_name = kModelName; + info.friendly_name = kFriendlyName; + ASSERT_TRUE(info.IsValid()); + discovery::DnsSdInstanceRecord converted_record = + ServiceInfoToDnsSdRecord(info); + ErrorOr<ServiceInfo> identity_info = + DnsSdRecordToServiceInfo(converted_record); + ASSERT_TRUE(identity_info.is_value()); + EXPECT_EQ(identity_info.value(), info); + + discovery::DnsSdTxtRecord txt = CreateValidTxt(); + txt.SetValue(kCapabilitiesId, kCapabilitiesString); + discovery::DnsSdInstanceRecord record(kInstanceId, kCastV2ServiceId, + kCastV2DomainId, kEndpointV4, + kEndpointV6, txt); + ErrorOr<ServiceInfo> converted_info = DnsSdRecordToServiceInfo(record); + ASSERT_TRUE(converted_info.is_value()); + ASSERT_TRUE(converted_info.value().IsValid()); + discovery::DnsSdInstanceRecord identity_record = + ServiceInfoToDnsSdRecord(converted_info.value()); + EXPECT_EQ(identity_record, record); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/public/testing/discovery_utils.cc b/chromium/third_party/openscreen/src/cast/common/public/testing/discovery_utils.cc new file mode 100644 index 00000000000..4d39041cf4f --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/public/testing/discovery_utils.cc @@ -0,0 +1,56 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/public/testing/discovery_utils.h" + +#include <sstream> + +#include "discovery/dnssd/public/dns_sd_instance_record.h" +#include "util/stringprintf.h" + +namespace openscreen { +namespace cast { + +discovery::DnsSdTxtRecord CreateValidTxt() { + discovery::DnsSdTxtRecord txt; + txt.SetValue(kUniqueIdKey, kTestUniqueId); + txt.SetValue(kVersionId, kTestVersion); + txt.SetValue(kCapabilitiesId, kCapabilitiesStringLong); + txt.SetValue(kStatusId, kStatus); + txt.SetValue(kFriendlyNameId, kFriendlyName); + txt.SetValue(kModelNameId, kModelName); + return txt; +} + +void CompareTxtString(const discovery::DnsSdTxtRecord& txt, + const std::string& key, + const std::string& expected) { + ErrorOr<discovery::DnsSdTxtRecord::ValueRef> value = txt.GetValue(key); + ASSERT_FALSE(value.is_error()) + << "expected value: '" << expected << "' for key: '" << key + << "'; got error: " << value.error(); + const std::vector<uint8_t>& data = value.value().get(); + std::string parsed_value = std::string(data.begin(), data.end()); + EXPECT_EQ(parsed_value, expected) << "expected value '" + << "' for key: '" << key << "'"; +} + +void CompareTxtInt(const discovery::DnsSdTxtRecord& txt, + const std::string& key, + uint8_t expected) { + ErrorOr<discovery::DnsSdTxtRecord::ValueRef> value = txt.GetValue(key); + ASSERT_FALSE(value.is_error()) + << "key: '" << key << "'' expected: '" << expected << "'"; + const std::vector<uint8_t>& data = value.value().get(); + std::string parsed_value = HexEncode(data); + ASSERT_EQ(data.size(), size_t{1}) + << "expected one byte value for key: '" << key << "' got size: '" + << data.size() << "' bytes"; + EXPECT_EQ(data[0], expected) + << "expected :" << std::hex << expected << "for key: '" << key + << "', got value: '" << parsed_value << "'"; +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/public/testing/discovery_utils.h b/chromium/third_party/openscreen/src/cast/common/public/testing/discovery_utils.h new file mode 100644 index 00000000000..bea238500f4 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/public/testing/discovery_utils.h @@ -0,0 +1,48 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_COMMON_PUBLIC_TESTING_DISCOVERY_UTILS_H_ +#define CAST_COMMON_PUBLIC_TESTING_DISCOVERY_UTILS_H_ + +#include "cast/common/public/service_info.h" +#include "discovery/dnssd/public/dns_sd_txt_record.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "platform/base/ip_address.h" + +namespace openscreen { +namespace cast { + +// Constants used for testing. +static const IPAddress kAddressV4(192, 168, 0, 0); +static const IPAddress kAddressV6(1, 2, 3, 4, 5, 6, 7, 8); +static constexpr uint16_t kPort = 80; +static const IPEndpoint kEndpointV4{kAddressV4, kPort}; +static const IPEndpoint kEndpointV6{kAddressV6, kPort}; +static constexpr char kTestUniqueId[] = "1234"; +static constexpr char kFriendlyName[] = "Friendly Name 123"; +static constexpr char kModelName[] = "Openscreen"; +static constexpr char kInstanceId[] = "Openscreen-1234"; +static constexpr uint8_t kTestVersion = 0; +static constexpr char kCapabilitiesString[] = "3"; +static constexpr char kCapabilitiesStringLong[] = "000003"; +static constexpr ReceiverCapabilities kCapabilitiesParsed = + static_cast<ReceiverCapabilities>(0x03); +static constexpr uint8_t kStatus = 0x01; +static constexpr ReceiverStatus kStatusParsed = ReceiverStatus::kBusy; + +discovery::DnsSdTxtRecord CreateValidTxt(); + +void CompareTxtString(const discovery::DnsSdTxtRecord& txt, + const std::string& key, + const std::string& expected); + +void CompareTxtInt(const discovery::DnsSdTxtRecord& txt, + const std::string& key, + uint8_t expected); + +} // namespace cast +} // namespace openscreen + +#endif // CAST_COMMON_PUBLIC_TESTING_DISCOVERY_UTILS_H_ diff --git a/chromium/third_party/openscreen/src/cast/receiver/BUILD.gn b/chromium/third_party/openscreen/src/cast/receiver/BUILD.gn new file mode 100644 index 00000000000..f62bf78db77 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/receiver/BUILD.gn @@ -0,0 +1,66 @@ +# Copyright 2019 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +source_set("channel") { + sources = [ + "channel/device_auth_namespace_handler.cc", + "channel/device_auth_namespace_handler.h", + "channel/message_util.cc", + "channel/message_util.h", + "channel/receiver_socket_factory.cc", + "channel/receiver_socket_factory.h", + ] + + public_deps = [ + "../../platform", + "../../third_party/abseil", + "../../third_party/boringssl", + "../common:channel", + "../common/channel/proto:channel_proto", + ] + + deps = [ + "../../util", + "../common:certificate", + ] +} + +source_set("test_helpers") { + testonly = true + sources = [ + "channel/testing/device_auth_test_helpers.cc", + "channel/testing/device_auth_test_helpers.h", + ] + + public_deps = [ + ":channel", + "../../third_party/boringssl", + "../common:test_helpers", + ] + deps = [ + "../../third_party/googletest:gtest", + "../common/channel/proto:channel_proto", + ] +} + +source_set("unittests") { + testonly = true + sources = [ + "channel/device_auth_namespace_handler_unittest.cc", + ] + + deps = [ + ":channel", + ":test_helpers", + "../../testing/util", + "../../third_party/googletest:gmock", + "../../third_party/googletest:gtest", + "../common:channel", + "../common/channel/proto:channel_proto", + ] + + data = [ + "../../test/data/cast/receiver/channel", + ] +} diff --git a/chromium/third_party/openscreen/src/cast/receiver/DEPS b/chromium/third_party/openscreen/src/cast/receiver/DEPS index 7bdadde7bd7..a2def1b0c42 100644 --- a/chromium/third_party/openscreen/src/cast/receiver/DEPS +++ b/chromium/third_party/openscreen/src/cast/receiver/DEPS @@ -2,6 +2,6 @@ include_rules = [ # libcast receiver code must not depend on the sender. - '+cast/common/public', + '+cast/common', '+cast/receiver' ] diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler.cc b/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler.cc new file mode 100644 index 00000000000..9e9d5e3f170 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler.cc @@ -0,0 +1,155 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/receiver/channel/device_auth_namespace_handler.h" + +#include <openssl/evp.h> + +#include "cast/common/certificate/cast_cert_validator.h" +#include "cast/common/channel/message_util.h" +#include "cast/common/channel/proto/cast_channel.pb.h" +#include "cast/common/channel/virtual_connection.h" +#include "cast/common/channel/virtual_connection_router.h" +#include "platform/base/tls_credentials.h" +#include "util/crypto/digest_sign.h" + +using ::cast::channel::AuthChallenge; +using ::cast::channel::AuthError; +using ::cast::channel::AuthResponse; +using ::cast::channel::CastMessage; +using ::cast::channel::DeviceAuthMessage; +using ::cast::channel::HashAlgorithm; +using ::cast::channel::SignatureAlgorithm; + +namespace openscreen { +namespace cast { + +namespace { + +CastMessage GenerateErrorMessage(AuthError::ErrorType error_type) { + DeviceAuthMessage message; + AuthError* error = message.mutable_error(); + error->set_error_type(error_type); + std::string payload; + message.SerializeToString(&payload); + + CastMessage response; + response.set_protocol_version( + ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0); + response.set_namespace_(kAuthNamespace); + response.set_payload_type(::cast::channel::CastMessage_PayloadType_BINARY); + response.set_payload_binary(std::move(payload)); + return response; +} + +} // namespace + +DeviceAuthNamespaceHandler::DeviceAuthNamespaceHandler( + CredentialsProvider* creds_provider) + : creds_provider_(creds_provider) {} + +DeviceAuthNamespaceHandler::~DeviceAuthNamespaceHandler() = default; + +void DeviceAuthNamespaceHandler::OnMessage(VirtualConnectionRouter* router, + CastSocket* socket, + CastMessage message) { + if (message.payload_type() != + ::cast::channel::CastMessage_PayloadType_BINARY) { + return; + } + const std::string& payload = message.payload_binary(); + DeviceAuthMessage device_auth_message; + if (!device_auth_message.ParseFromArray(payload.data(), payload.length())) { + // TODO(btolsch): Consider all of these cases for future error reporting + // mechanism. + return; + } + + if (!device_auth_message.has_challenge()) { + return; + } + + if (device_auth_message.has_response() || device_auth_message.has_error()) { + return; + } + + const VirtualConnection virtual_conn{ + message.destination_id(), message.source_id(), socket->socket_id()}; + const AuthChallenge& challenge = device_auth_message.challenge(); + const SignatureAlgorithm sig_alg = challenge.signature_algorithm(); + HashAlgorithm hash_alg = challenge.hash_algorithm(); + // TODO(btolsch): Reconsider supporting SHA1 after further metrics + // investigation. + if ((sig_alg != ::cast::channel::UNSPECIFIED && + sig_alg != ::cast::channel::RSASSA_PKCS1v15) || + (hash_alg != ::cast::channel::SHA1 && + hash_alg != ::cast::channel::SHA256)) { + router->SendMessage( + virtual_conn, + GenerateErrorMessage(AuthError::SIGNATURE_ALGORITHM_UNAVAILABLE)); + return; + } + const EVP_MD* digest = + hash_alg == ::cast::channel::SHA256 ? EVP_sha256() : EVP_sha1(); + + const absl::Span<const uint8_t> tls_cert_der = + creds_provider_->GetCurrentTlsCertAsDer(); + const DeviceCredentials& device_creds = + creds_provider_->GetCurrentDeviceCredentials(); + if (tls_cert_der.empty() || device_creds.certs.empty() || + !device_creds.private_key) { + // TODO(btolsch): Add this to future error reporting. + router->SendMessage(virtual_conn, + GenerateErrorMessage(AuthError::INTERNAL_ERROR)); + return; + } + + std::unique_ptr<AuthResponse> auth_response(new AuthResponse()); + auth_response->set_client_auth_certificate(device_creds.certs[0]); + for (auto it = device_creds.certs.begin() + 1; it != device_creds.certs.end(); + ++it) { + auth_response->add_intermediate_certificate(*it); + } + auth_response->set_signature_algorithm(::cast::channel::RSASSA_PKCS1v15); + auth_response->set_hash_algorithm(hash_alg); + std::string sender_nonce; + if (challenge.has_sender_nonce()) { + sender_nonce = challenge.sender_nonce(); + auth_response->set_sender_nonce(sender_nonce); + } + + auth_response->set_crl(device_creds.serialized_crl); + + std::vector<uint8_t> to_be_signed; + to_be_signed.reserve(sender_nonce.size() + tls_cert_der.size()); + to_be_signed.insert(to_be_signed.end(), sender_nonce.begin(), + sender_nonce.end()); + to_be_signed.insert(to_be_signed.end(), tls_cert_der.begin(), + tls_cert_der.end()); + + ErrorOr<std::string> signature = + SignData(digest, device_creds.private_key.get(), to_be_signed); + if (!signature) { + router->SendMessage(virtual_conn, + GenerateErrorMessage(AuthError::INTERNAL_ERROR)); + return; + } + auth_response->set_signature(std::move(signature.value())); + + DeviceAuthMessage response_auth_message; + response_auth_message.set_allocated_response(auth_response.release()); + + std::string response_string; + response_auth_message.SerializeToString(&response_string); + CastMessage response; + response.set_protocol_version( + ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0); + response.set_namespace_(kAuthNamespace); + response.set_payload_type(::cast::channel::CastMessage_PayloadType_BINARY); + response.set_payload_binary(std::move(response_string)); + router->SendMessage(virtual_conn, std::move(response)); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler.h b/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler.h new file mode 100644 index 00000000000..ec918e33f16 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler.h @@ -0,0 +1,57 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_RECEIVER_CHANNEL_DEVICE_AUTH_NAMESPACE_HANDLER_H_ +#define CAST_RECEIVER_CHANNEL_DEVICE_AUTH_NAMESPACE_HANDLER_H_ + +#include <openssl/evp.h> + +#include <string> +#include <vector> + +#include "absl/types/span.h" +#include "cast/common/channel/cast_message_handler.h" + +namespace openscreen { +namespace cast { + +struct DeviceCredentials { + // The device's certificate chain in DER form, where |certs[0]| is the + // device's certificate and |certs[certs.size()-1]| is the last intermediate + // before a Cast root certificate. + std::vector<std::string> certs; + + // The device's private key that corresponds to the certificate in |certs[0]|. + bssl::UniquePtr<EVP_PKEY> private_key; + + // If non-empty, this contains a serialized CrlBundle protobuf. This may be + // used by the sender as part of verifying |certs|. + std::string serialized_crl; +}; + +class DeviceAuthNamespaceHandler final : public CastMessageHandler { + public: + class CredentialsProvider { + public: + virtual absl::Span<const uint8_t> GetCurrentTlsCertAsDer() = 0; + virtual const DeviceCredentials& GetCurrentDeviceCredentials() = 0; + }; + + // |creds_provider| must outlive |this|. + explicit DeviceAuthNamespaceHandler(CredentialsProvider* creds_provider); + ~DeviceAuthNamespaceHandler(); + + // CastMessageHandler overrides. + void OnMessage(VirtualConnectionRouter* router, + CastSocket* socket, + ::cast::channel::CastMessage message) override; + + private: + CredentialsProvider* const creds_provider_; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_RECEIVER_CHANNEL_DEVICE_AUTH_NAMESPACE_HANDLER_H_ diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler_unittest.cc b/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler_unittest.cc new file mode 100644 index 00000000000..698cbd541d5 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler_unittest.cc @@ -0,0 +1,219 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/receiver/channel/device_auth_namespace_handler.h" + +#include "cast/common/channel/cast_socket.h" +#include "cast/common/channel/message_util.h" +#include "cast/common/channel/proto/cast_channel.pb.h" +#include "cast/common/channel/testing/fake_cast_socket.h" +#include "cast/common/channel/testing/mock_socket_error_handler.h" +#include "cast/common/channel/virtual_connection_manager.h" +#include "cast/common/channel/virtual_connection_router.h" +#include "cast/receiver/channel/testing/device_auth_test_helpers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "testing/util/read_file.h" + +namespace openscreen { +namespace cast { +namespace { + +using ::cast::channel::AuthResponse; +using ::cast::channel::CastMessage; +using ::cast::channel::DeviceAuthMessage; +using ::cast::channel::SignatureAlgorithm; + +using ::testing::_; +using ::testing::ElementsAreArray; +using ::testing::Invoke; + +class DeviceAuthNamespaceHandlerTest : public ::testing::Test { + public: + void SetUp() override { + socket_ = fake_cast_socket_pair_.socket.get(); + router_.TakeSocket(&mock_error_handler_, + std::move(fake_cast_socket_pair_.socket)); + router_.AddHandlerForLocalId(kPlatformReceiverId, &auth_handler_); + } + + protected: + FakeCastSocketPair fake_cast_socket_pair_; + MockSocketErrorHandler mock_error_handler_; + CastSocket* socket_; + + StaticCredentialsProvider creds_; + VirtualConnectionManager manager_; + VirtualConnectionRouter router_{&manager_}; + DeviceAuthNamespaceHandler auth_handler_{&creds_}; +}; + +#define TEST_DATA_PREFIX OPENSCREEN_TEST_DATA_DIR "cast/receiver/channel/" + +// The tests in this file use a pre-recorded AuthChallenge as input and a +// matching pre-recorded AuthResponse for verification. This is to make it +// easier to keep sender and receiver code separate, because the code that would +// really generate an AuthChallenge and verify an AuthResponse is under +// //cast/sender. The pre-recorded messages come from an integration test which +// _does_ properly call both sender and receiver sides, but can optionally +// record the messages for use in these unit tests. That test is currently +// under //cast/test. See //cast/test/README.md for more information. +// +// The tests generally follow this procedure: +// 1. Read a fake device certificate chain + TLS certificate from disk. +// 2. Read a pre-recorded CastMessage proto containing an AuthChallenge. +// 3. Send this CastMessage over a CastSocket to a DeviceAuthNamespaceHandler. +// 4. Catch the CastMessage response and check that it has an AuthResponse. +// 5. Check the AuthResponse against another pre-recorded protobuf. + +TEST_F(DeviceAuthNamespaceHandlerTest, AuthResponse) { + InitStaticCredentialsFromFiles( + &creds_, nullptr, nullptr, TEST_DATA_PREFIX "device_key.pem", + TEST_DATA_PREFIX "device_chain.pem", TEST_DATA_PREFIX "device_tls.pem"); + + // Send an auth challenge. |auth_handler_| will automatically respond via + // |router_| and we will catch the result in |challenge_reply|. + CastMessage auth_challenge; + const std::string auth_challenge_string = + ReadEntireFileToString(TEST_DATA_PREFIX "auth_challenge.pb"); + ASSERT_TRUE(auth_challenge.ParseFromString(auth_challenge_string)); + + CastMessage challenge_reply; + EXPECT_CALL(fake_cast_socket_pair_.mock_peer_client, OnMessage(_, _)) + .WillOnce( + Invoke([&challenge_reply](CastSocket* socket, CastMessage message) { + challenge_reply = std::move(message); + })); + ASSERT_TRUE( + fake_cast_socket_pair_.peer_socket->SendMessage(std::move(auth_challenge)) + .ok()); + + const std::string auth_response_string = + ReadEntireFileToString(TEST_DATA_PREFIX "auth_response.pb"); + AuthResponse expected_auth_response; + ASSERT_TRUE(expected_auth_response.ParseFromString(auth_response_string)); + + DeviceAuthMessage auth_message; + ASSERT_EQ(challenge_reply.payload_type(), + ::cast::channel::CastMessage_PayloadType_BINARY); + ASSERT_TRUE(auth_message.ParseFromString(challenge_reply.payload_binary())); + ASSERT_TRUE(auth_message.has_response()); + ASSERT_FALSE(auth_message.has_challenge()); + ASSERT_FALSE(auth_message.has_error()); + const AuthResponse& auth_response = auth_message.response(); + + EXPECT_EQ(expected_auth_response.signature(), auth_response.signature()); + EXPECT_EQ(expected_auth_response.client_auth_certificate(), + auth_response.client_auth_certificate()); + EXPECT_EQ(expected_auth_response.signature_algorithm(), + auth_response.signature_algorithm()); + EXPECT_EQ(expected_auth_response.sender_nonce(), + auth_response.sender_nonce()); + EXPECT_EQ(expected_auth_response.hash_algorithm(), + auth_response.hash_algorithm()); + EXPECT_EQ(expected_auth_response.crl(), auth_response.crl()); + EXPECT_THAT( + auth_response.intermediate_certificate(), + ElementsAreArray(expected_auth_response.intermediate_certificate())); +} + +TEST_F(DeviceAuthNamespaceHandlerTest, BadNonce) { + InitStaticCredentialsFromFiles( + &creds_, nullptr, nullptr, TEST_DATA_PREFIX "device_key.pem", + TEST_DATA_PREFIX "device_chain.pem", TEST_DATA_PREFIX "device_tls.pem"); + + // Send an auth challenge. |auth_handler_| will automatically respond via + // |router_| and we will catch the result in |challenge_reply|. + CastMessage auth_challenge; + const std::string auth_challenge_string = + ReadEntireFileToString(TEST_DATA_PREFIX "auth_challenge.pb"); + ASSERT_TRUE(auth_challenge.ParseFromString(auth_challenge_string)); + + // Change the nonce to be different from what was used to record the correct + // response originally. + DeviceAuthMessage msg; + ASSERT_EQ(auth_challenge.payload_type(), + ::cast::channel::CastMessage_PayloadType_BINARY); + ASSERT_TRUE(msg.ParseFromString(auth_challenge.payload_binary())); + ASSERT_TRUE(msg.has_challenge()); + std::string* nonce = msg.mutable_challenge()->mutable_sender_nonce(); + (*nonce)[0] = ~(*nonce)[0]; + std::string new_payload; + ASSERT_TRUE(msg.SerializeToString(&new_payload)); + auth_challenge.set_payload_binary(new_payload); + + CastMessage challenge_reply; + EXPECT_CALL(fake_cast_socket_pair_.mock_peer_client, OnMessage(_, _)) + .WillOnce( + Invoke([&challenge_reply](CastSocket* socket, CastMessage message) { + challenge_reply = std::move(message); + })); + ASSERT_TRUE( + fake_cast_socket_pair_.peer_socket->SendMessage(std::move(auth_challenge)) + .ok()); + + const std::string auth_response_string = + ReadEntireFileToString(TEST_DATA_PREFIX "auth_response.pb"); + AuthResponse expected_auth_response; + ASSERT_TRUE(expected_auth_response.ParseFromString(auth_response_string)); + + DeviceAuthMessage auth_message; + ASSERT_EQ(challenge_reply.payload_type(), + ::cast::channel::CastMessage_PayloadType_BINARY); + ASSERT_TRUE(auth_message.ParseFromString(challenge_reply.payload_binary())); + ASSERT_TRUE(auth_message.has_response()); + ASSERT_FALSE(auth_message.has_challenge()); + ASSERT_FALSE(auth_message.has_error()); + const AuthResponse& auth_response = auth_message.response(); + + // NOTE: This is the ultimate result of the nonce-mismatch. + EXPECT_NE(expected_auth_response.signature(), auth_response.signature()); +} + +TEST_F(DeviceAuthNamespaceHandlerTest, UnsupportedSignatureAlgorithm) { + InitStaticCredentialsFromFiles( + &creds_, nullptr, nullptr, TEST_DATA_PREFIX "device_key.pem", + TEST_DATA_PREFIX "device_chain.pem", TEST_DATA_PREFIX "device_tls.pem"); + + // Send an auth challenge. |auth_handler_| will automatically respond via + // |router_| and we will catch the result in |challenge_reply|. + CastMessage auth_challenge; + const std::string auth_challenge_string = + ReadEntireFileToString(TEST_DATA_PREFIX "auth_challenge.pb"); + ASSERT_TRUE(auth_challenge.ParseFromString(auth_challenge_string)); + + // Change the signature algorithm an unsupported value. + DeviceAuthMessage msg; + ASSERT_EQ(auth_challenge.payload_type(), + ::cast::channel::CastMessage_PayloadType_BINARY); + ASSERT_TRUE(msg.ParseFromString(auth_challenge.payload_binary())); + ASSERT_TRUE(msg.has_challenge()); + msg.mutable_challenge()->set_signature_algorithm( + SignatureAlgorithm::RSASSA_PSS); + std::string new_payload; + ASSERT_TRUE(msg.SerializeToString(&new_payload)); + auth_challenge.set_payload_binary(new_payload); + + CastMessage challenge_reply; + EXPECT_CALL(fake_cast_socket_pair_.mock_peer_client, OnMessage(_, _)) + .WillOnce( + Invoke([&challenge_reply](CastSocket* socket, CastMessage message) { + challenge_reply = std::move(message); + })); + ASSERT_TRUE( + fake_cast_socket_pair_.peer_socket->SendMessage(std::move(auth_challenge)) + .ok()); + + DeviceAuthMessage auth_message; + ASSERT_EQ(challenge_reply.payload_type(), + ::cast::channel::CastMessage_PayloadType_BINARY); + ASSERT_TRUE(auth_message.ParseFromString(challenge_reply.payload_binary())); + ASSERT_FALSE(auth_message.has_response()); + ASSERT_FALSE(auth_message.has_challenge()); + ASSERT_TRUE(auth_message.has_error()); +} + +} // namespace +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/message_util.cc b/chromium/third_party/openscreen/src/cast/receiver/channel/message_util.cc new file mode 100644 index 00000000000..f011c21182d --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/receiver/channel/message_util.cc @@ -0,0 +1,65 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/receiver/channel/message_util.h" + +#include "util/json/json_serialization.h" +#include "util/json/json_value.h" +#include "util/logging.h" + +namespace openscreen { +namespace cast { + +using ::cast::channel::CastMessage; + +namespace { + +ErrorOr<CastMessage> CreateAppAvailabilityResponse( + int request_id, + const std::string& sender_id, + const std::string& app_id, + AppAvailabilityResult availability_result) { + CastMessage availability_response; + Json::Value dict(Json::ValueType::objectValue); + dict[kMessageKeyRequestId] = request_id; + Json::Value availability(Json::ValueType::objectValue); + availability[app_id.c_str()] = + availability_result == AppAvailabilityResult::kAvailable + ? kMessageValueAppAvailable + : kMessageValueAppUnavailable; + dict[kMessageKeyAvailability] = std::move(availability); + ErrorOr<std::string> serialized = json::Stringify(dict); + if (!serialized) { + return Error::Code::kJsonWriteError; + } + + availability_response.set_source_id(kPlatformReceiverId); + availability_response.set_destination_id(sender_id); + availability_response.set_namespace_(kReceiverNamespace); + availability_response.set_protocol_version( + ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0); + availability_response.set_payload_utf8(std::move(serialized.value())); + availability_response.set_payload_type( + ::cast::channel::CastMessage_PayloadType_STRING); + return availability_response; +} + +} // namespace + +ErrorOr<CastMessage> CreateAppAvailableResponse(int request_id, + const std::string& sender_id, + const std::string& app_id) { + return CreateAppAvailabilityResponse(request_id, sender_id, app_id, + AppAvailabilityResult::kAvailable); +} + +ErrorOr<CastMessage> CreateAppUnavailableResponse(int request_id, + const std::string& sender_id, + const std::string& app_id) { + return CreateAppAvailabilityResponse(request_id, sender_id, app_id, + AppAvailabilityResult::kUnavailable); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/message_util.h b/chromium/third_party/openscreen/src/cast/receiver/channel/message_util.h new file mode 100644 index 00000000000..191ff1e913a --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/receiver/channel/message_util.h @@ -0,0 +1,30 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_RECEIVER_CHANNEL_MESSAGE_UTIL_H_ +#define CAST_RECEIVER_CHANNEL_MESSAGE_UTIL_H_ + +#include "cast/common/channel/message_util.h" +#include "cast/common/channel/proto/cast_channel.pb.h" +#include "platform/base/error.h" + +namespace openscreen { +namespace cast { + +// Creates a message that responds to a previous app availability request with +// ID |request_id| which declares |app_id| to have availability of either +// available or unavailable respectively. +ErrorOr<::cast::channel::CastMessage> CreateAppAvailableResponse( + int request_id, + const std::string& sender_id, + const std::string& app_id); +ErrorOr<::cast::channel::CastMessage> CreateAppUnavailableResponse( + int request_id, + const std::string& sender_id, + const std::string& app_id); + +} // namespace cast +} // namespace openscreen + +#endif // CAST_RECEIVER_CHANNEL_MESSAGE_UTIL_H_ diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/receiver_socket_factory.cc b/chromium/third_party/openscreen/src/cast/receiver/channel/receiver_socket_factory.cc new file mode 100644 index 00000000000..f0a642e743b --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/receiver/channel/receiver_socket_factory.cc @@ -0,0 +1,52 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/receiver/channel/receiver_socket_factory.h" + +#include "util/logging.h" + +namespace openscreen { +namespace cast { + +ReceiverSocketFactory::ReceiverSocketFactory(Client* client, + CastSocket::Client* socket_client) + : client_(client), socket_client_(socket_client) { + OSP_DCHECK(client); + OSP_DCHECK(socket_client); +} + +ReceiverSocketFactory::~ReceiverSocketFactory() = default; + +void ReceiverSocketFactory::OnAccepted( + TlsConnectionFactory* factory, + std::vector<uint8_t> der_x509_peer_cert, + std::unique_ptr<TlsConnection> connection) { + IPEndpoint endpoint = connection->GetRemoteEndpoint(); + auto socket = + std::make_unique<CastSocket>(std::move(connection), socket_client_); + client_->OnConnected(this, endpoint, std::move(socket)); +} + +void ReceiverSocketFactory::OnConnected( + TlsConnectionFactory* factory, + std::vector<uint8_t> der_x509_peer_cert, + std::unique_ptr<TlsConnection> connection) { + OSP_NOTREACHED() << "This factory is accept-only."; +} + +void ReceiverSocketFactory::OnConnectionFailed( + TlsConnectionFactory* factory, + const IPEndpoint& remote_address) { + OSP_DVLOG << "Receiving connection from endpoint failed: " << remote_address; + client_->OnError(this, Error(Error::Code::kConnectionFailed, + "Accepting connection failed.")); +} + +void ReceiverSocketFactory::OnError(TlsConnectionFactory* factory, + Error error) { + client_->OnError(this, error); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/receiver_socket_factory.h b/chromium/third_party/openscreen/src/cast/receiver/channel/receiver_socket_factory.h new file mode 100644 index 00000000000..d8bde8cad6b --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/receiver/channel/receiver_socket_factory.h @@ -0,0 +1,51 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_RECEIVER_CHANNEL_RECEIVER_SOCKET_FACTORY_H_ +#define CAST_RECEIVER_CHANNEL_RECEIVER_SOCKET_FACTORY_H_ + +#include <vector> + +#include "cast/common/channel/cast_socket.h" +#include "platform/api/tls_connection_factory.h" +#include "platform/base/ip_address.h" + +namespace openscreen { +namespace cast { + +class ReceiverSocketFactory final : public TlsConnectionFactory::Client { + public: + class Client { + public: + virtual void OnConnected(ReceiverSocketFactory* factory, + const IPEndpoint& endpoint, + std::unique_ptr<CastSocket> socket) = 0; + virtual void OnError(ReceiverSocketFactory* factory, Error error) = 0; + }; + + // |client| and |socket_client| must outlive |this|. + // TODO(btolsch): Add TaskRunner argument just for sequence checking. + ReceiverSocketFactory(Client* client, CastSocket::Client* socket_client); + ~ReceiverSocketFactory(); + + // TlsConnectionFactory::Client overrides. + void OnAccepted(TlsConnectionFactory* factory, + std::vector<uint8_t> der_x509_peer_cert, + std::unique_ptr<TlsConnection> connection) override; + void OnConnected(TlsConnectionFactory* factory, + std::vector<uint8_t> der_x509_peer_cert, + std::unique_ptr<TlsConnection> connection) override; + void OnConnectionFailed(TlsConnectionFactory* factory, + const IPEndpoint& remote_address) override; + void OnError(TlsConnectionFactory* factory, Error error) override; + + private: + Client* const client_; + CastSocket::Client* const socket_client_; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_RECEIVER_CHANNEL_RECEIVER_SOCKET_FACTORY_H_ diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/testing/device_auth_test_helpers.cc b/chromium/third_party/openscreen/src/cast/receiver/channel/testing/device_auth_test_helpers.cc new file mode 100644 index 00000000000..51d7ebaa32a --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/receiver/channel/testing/device_auth_test_helpers.cc @@ -0,0 +1,51 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/receiver/channel/testing/device_auth_test_helpers.h" + +#include "gtest/gtest.h" + +namespace openscreen { +namespace cast { + +void InitStaticCredentialsFromFiles(StaticCredentialsProvider* creds, + bssl::UniquePtr<X509>* parsed_cert, + TrustStore* fake_trust_store, + absl::string_view privkey_filename, + absl::string_view chain_filename, + absl::string_view tls_filename) { + auto private_key = testing::ReadKeyFromPemFile(privkey_filename); + ASSERT_TRUE(private_key); + std::vector<std::string> certs = + testing::ReadCertificatesFromPemFile(chain_filename); + ASSERT_GT(certs.size(), 1u); + + // Use the root of the chain as the trust store for the test. + auto* data = reinterpret_cast<const uint8_t*>(certs.back().data()); + auto fake_root = + bssl::UniquePtr<X509>(d2i_X509(nullptr, &data, certs.back().size())); + ASSERT_TRUE(fake_root); + certs.pop_back(); + if (fake_trust_store) { + fake_trust_store->certs.emplace_back(fake_root.release()); + } + + creds->device_creds = DeviceCredentials{ + std::move(certs), std::move(private_key), std::string()}; + + const std::vector<std::string> tls_cert = + testing::ReadCertificatesFromPemFile(tls_filename); + ASSERT_EQ(tls_cert.size(), 1u); + data = reinterpret_cast<const uint8_t*>(tls_cert[0].data()); + if (parsed_cert) { + *parsed_cert = + bssl::UniquePtr<X509>(d2i_X509(nullptr, &data, tls_cert[0].size())); + ASSERT_TRUE(*parsed_cert); + } + const auto* begin = reinterpret_cast<const uint8_t*>(tls_cert[0].data()); + creds->tls_cert_der.assign(begin, begin + tls_cert[0].size()); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/testing/device_auth_test_helpers.h b/chromium/third_party/openscreen/src/cast/receiver/channel/testing/device_auth_test_helpers.h new file mode 100644 index 00000000000..65ddccf3b9e --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/receiver/channel/testing/device_auth_test_helpers.h @@ -0,0 +1,46 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_RECEIVER_CHANNEL_TESTING_DEVICE_AUTH_TEST_HELPERS_H_ +#define CAST_RECEIVER_CHANNEL_TESTING_DEVICE_AUTH_TEST_HELPERS_H_ + +#include <openssl/x509.h> + +#include <vector> + +#include "absl/strings/string_view.h" +#include "cast/common/certificate/testing/test_helpers.h" +#include "cast/receiver/channel/device_auth_namespace_handler.h" + +namespace openscreen { +namespace cast { + +class StaticCredentialsProvider final + : public DeviceAuthNamespaceHandler::CredentialsProvider { + public: + StaticCredentialsProvider() = default; + ~StaticCredentialsProvider() = default; + + absl::Span<const uint8_t> GetCurrentTlsCertAsDer() override { + return absl::Span<uint8_t>(tls_cert_der); + } + const DeviceCredentials& GetCurrentDeviceCredentials() override { + return device_creds; + } + + DeviceCredentials device_creds; + std::vector<uint8_t> tls_cert_der; +}; + +void InitStaticCredentialsFromFiles(StaticCredentialsProvider* creds, + bssl::UniquePtr<X509>* parsed_cert, + TrustStore* fake_trust_store, + absl::string_view privkey_filename, + absl::string_view chain_filename, + absl::string_view tls_filename); + +} // namespace cast +} // namespace openscreen + +#endif // CAST_RECEIVER_CHANNEL_TESTING_DEVICE_AUTH_TEST_HELPERS_H_ diff --git a/chromium/third_party/openscreen/src/cast/sender/BUILD.gn b/chromium/third_party/openscreen/src/cast/sender/BUILD.gn index 31c21215547..f500f766e86 100644 --- a/chromium/third_party/openscreen/src/cast/sender/BUILD.gn +++ b/chromium/third_party/openscreen/src/cast/sender/BUILD.gn @@ -13,7 +13,7 @@ source_set("channel") { ] deps = [ - "../../util", + "../common:channel", "../common/certificate/proto:certificate_proto", "../common/channel/proto:channel_proto", ] @@ -21,19 +21,75 @@ source_set("channel") { public_deps = [ "../../platform", "../../third_party/boringssl", + "../../util", + "../common:certificate", + "../common:channel", + ] +} + +source_set("sender") { + sources = [ + "cast_app_availability_tracker.cc", + "cast_app_availability_tracker.h", + "cast_app_discovery_service_impl.cc", + "cast_app_discovery_service_impl.h", + "cast_platform_client.cc", + "cast_platform_client.h", + "public/cast_app_discovery_service.cc", + "public/cast_app_discovery_service.h", + "public/cast_media_source.cc", + "public/cast_media_source.h", + ] + + public_deps = [ + ":channel", + "../../platform", + "../../third_party/abseil", + "../../util", + "../common:channel", + "../common:public", + ] +} + +source_set("test_helpers") { + testonly = true + sources = [ + "testing/test_helpers.cc", + "testing/test_helpers.h", + ] + + deps = [ + "../../third_party/googletest:gtest", + "../../util", + "../common:channel", + "../receiver:channel", + ] + + public_deps = [ + ":channel", ] } source_set("unittests") { testonly = true sources = [ + "cast_app_availability_tracker_unittest.cc", + "cast_app_discovery_service_impl_unittest.cc", + "cast_platform_client_unittest.cc", "channel/cast_auth_util_unittest.cc", ] deps = [ ":channel", + ":sender", + ":test_helpers", "../../platform", + "../../platform:test", + "../../testing/util", + "../../third_party/googletest:gmock", "../../third_party/googletest:gtest", + "../../util", + "../common:test_helpers", "../common/certificate/proto:certificate_proto", "../common/certificate/proto:certificate_unittest_proto", ] diff --git a/chromium/third_party/openscreen/src/cast/sender/DEPS b/chromium/third_party/openscreen/src/cast/sender/DEPS index 7ab7a51a926..48f4daeefbe 100644 --- a/chromium/third_party/openscreen/src/cast/sender/DEPS +++ b/chromium/third_party/openscreen/src/cast/sender/DEPS @@ -1,7 +1,7 @@ # -*- Mode: Python; -*- include_rules = [ - # libcast sender code must not depend on the receiver. - '+cast/common', - '+cast/sender' + # libcast sender code must not depend on the receiver. + '+cast/common', + '+cast/sender', ] diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker.cc b/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker.cc new file mode 100644 index 00000000000..6c21c799d29 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker.cc @@ -0,0 +1,165 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/sender/cast_app_availability_tracker.h" + +#include "util/logging.h" + +namespace openscreen { +namespace cast { + +CastAppAvailabilityTracker::CastAppAvailabilityTracker() = default; +CastAppAvailabilityTracker::~CastAppAvailabilityTracker() = default; + +std::vector<std::string> CastAppAvailabilityTracker::RegisterSource( + const CastMediaSource& source) { + if (registered_sources_.find(source.source_id()) != + registered_sources_.end()) { + return {}; + } + + registered_sources_.emplace(source.source_id(), source); + + std::vector<std::string> new_app_ids; + for (const std::string& app_id : source.app_ids()) { + if (++registration_count_by_app_id_[app_id] == 1) { + new_app_ids.push_back(app_id); + } + } + return new_app_ids; +} + +void CastAppAvailabilityTracker::UnregisterSource( + const CastMediaSource& source) { + UnregisterSource(source.source_id()); +} + +void CastAppAvailabilityTracker::UnregisterSource( + const std::string& source_id) { + auto it = registered_sources_.find(source_id); + if (it == registered_sources_.end()) { + return; + } + + for (const std::string& app_id : it->second.app_ids()) { + auto count_it = registration_count_by_app_id_.find(app_id); + OSP_DCHECK(count_it != registration_count_by_app_id_.end()); + if (--(count_it->second) == 0) { + registration_count_by_app_id_.erase(count_it); + } + } + + registered_sources_.erase(it); +} + +std::vector<CastMediaSource> CastAppAvailabilityTracker::UpdateAppAvailability( + const std::string& device_id, + const std::string& app_id, + AppAvailability availability) { + auto& availabilities = app_availabilities_[device_id]; + auto it = availabilities.find(app_id); + + AppAvailabilityResult old_availability = it == availabilities.end() + ? AppAvailabilityResult::kUnknown + : it->second.availability; + AppAvailabilityResult new_availability = availability.availability; + + // Updated if status changes from/to kAvailable. + bool updated = (old_availability == AppAvailabilityResult::kAvailable || + new_availability == AppAvailabilityResult::kAvailable) && + old_availability != new_availability; + availabilities[app_id] = availability; + + if (!updated) { + return {}; + } + + std::vector<CastMediaSource> affected_sources; + for (const auto& source : registered_sources_) { + if (source.second.ContainsAppId(app_id)) { + affected_sources.push_back(source.second); + } + } + return affected_sources; +} + +std::vector<CastMediaSource> CastAppAvailabilityTracker::RemoveResultsForDevice( + const std::string& device_id) { + auto affected_sources = GetSupportedSources(device_id); + app_availabilities_.erase(device_id); + return affected_sources; +} + +std::vector<CastMediaSource> CastAppAvailabilityTracker::GetSupportedSources( + const std::string& device_id) const { + auto it = app_availabilities_.find(device_id); + if (it == app_availabilities_.end()) { + return std::vector<CastMediaSource>(); + } + + // Find all app IDs that are available on the device. + std::vector<std::string> supported_app_ids; + for (const auto& availability : it->second) { + if (availability.second.availability == AppAvailabilityResult::kAvailable) { + supported_app_ids.push_back(availability.first); + } + } + + // Find all registered sources whose query results contain the device ID. + std::vector<CastMediaSource> sources; + for (const auto& source : registered_sources_) { + if (source.second.ContainsAnyAppIdFrom(supported_app_ids)) { + sources.push_back(source.second); + } + } + return sources; +} + +CastAppAvailabilityTracker::AppAvailability +CastAppAvailabilityTracker::GetAvailability(const std::string& device_id, + const std::string& app_id) const { + auto availabilities_it = app_availabilities_.find(device_id); + if (availabilities_it == app_availabilities_.end()) { + return {AppAvailabilityResult::kUnknown, Clock::time_point{}}; + } + + const auto& availability_map = availabilities_it->second; + auto availability_it = availability_map.find(app_id); + if (availability_it == availability_map.end()) { + return {AppAvailabilityResult::kUnknown, Clock::time_point{}}; + } + + return availability_it->second; +} + +std::vector<std::string> CastAppAvailabilityTracker::GetRegisteredApps() const { + std::vector<std::string> registered_apps; + for (const auto& app_ids_and_count : registration_count_by_app_id_) { + registered_apps.push_back(app_ids_and_count.first); + } + + return registered_apps; +} + +std::vector<std::string> CastAppAvailabilityTracker::GetAvailableDevices( + const CastMediaSource& source) const { + std::vector<std::string> device_ids; + // For each device, check if there is at least one available app in |source|. + for (const auto& availabilities : app_availabilities_) { + for (const std::string& app_id : source.app_ids()) { + const auto& availabilities_map = availabilities.second; + auto availability_it = availabilities_map.find(app_id); + if (availability_it != availabilities_map.end() && + availability_it->second.availability == + AppAvailabilityResult::kAvailable) { + device_ids.push_back(availabilities.first); + break; + } + } + } + return device_ids; +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker.h b/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker.h new file mode 100644 index 00000000000..c0bded96334 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker.h @@ -0,0 +1,125 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_SENDER_CAST_APP_AVAILABILITY_TRACKER_H_ +#define CAST_SENDER_CAST_APP_AVAILABILITY_TRACKER_H_ + +#include <map> +#include <string> +#include <vector> + +#include "cast/sender/channel/message_util.h" +#include "cast/sender/public/cast_media_source.h" +#include "platform/api/time.h" + +namespace openscreen { +namespace cast { + +// Tracks device queries and their extracted Cast app IDs and their +// availabilities on discovered devices. +// Example usage: +/// +// (1) A page is interested in a Cast URL (e.g. by creating a +// PresentationRequest with the URL) like "cast:foo". To register the source to +// be tracked: +// CastAppAvailabilityTracker tracker; +// auto source = CastMediaSource::From("cast:foo"); +// auto new_app_ids = tracker.RegisterSource(source.value()); +// +// (2) The set of app IDs returned by the tracker can then be used by the caller +// to send an app availability request to each of the discovered devices. +// +// (3) Once the caller knows the availability value for a (device, app) pair, it +// may inform the tracker to update its results: +// auto affected_sources = +// tracker.UpdateAppAvailability(device_id, app_id, {availability, now}); +// +// (4) The tracker returns a subset of discovered sources that were affected by +// the update. The caller can then call |GetAvailableDevices()| to get the +// updated results for each affected source. +// +// (5a): At any time, the caller may call |RemoveResultsForDevice()| to remove +// cached results pertaining to the device, when it detects that a device is +// removed or no longer valid. +// +// (5b): At any time, the caller may call |GetAvailableDevices()| (even before +// the source is registered) to determine if there are cached results available. +// TODO(crbug.com/openscreen/112): Device -> Receiver renaming. +class CastAppAvailabilityTracker { + public: + // The result of an app availability request and the time when it is obtained. + struct AppAvailability { + AppAvailabilityResult availability; + Clock::time_point time; + }; + + CastAppAvailabilityTracker(); + ~CastAppAvailabilityTracker(); + + CastAppAvailabilityTracker(const CastAppAvailabilityTracker&) = delete; + CastAppAvailabilityTracker& operator=(const CastAppAvailabilityTracker&) = + delete; + + // Registers |source| with the tracker. Returns a list of new app IDs that + // were previously not known to the tracker. + std::vector<std::string> RegisterSource(const CastMediaSource& source); + + // Unregisters the source given by |source| or |source_id| with the tracker. + void UnregisterSource(const std::string& source_id); + void UnregisterSource(const CastMediaSource& source); + + // Updates the availability of |app_id| on |device_id| to |availability|. + // Returns a list of registered CastMediaSources for which the set of + // available devices might have been updated by this call. The caller should + // call |GetAvailableDevices| with the returned CastMediaSources to get the + // updated lists. + std::vector<CastMediaSource> UpdateAppAvailability( + const std::string& device_id, + const std::string& app_id, + AppAvailability availability); + + // Removes all results associated with |device_id|, i.e. when the device + // becomes invalid. Returns a list of registered CastMediaSources for which + // the set of available devices might have been updated by this call. The + // caller should call |GetAvailableDevices| with the returned CastMediaSources + // to get the updated lists. + std::vector<CastMediaSource> RemoveResultsForDevice( + const std::string& device_id); + + // Returns a list of registered CastMediaSources supported by |device_id|. + std::vector<CastMediaSource> GetSupportedSources( + const std::string& device_id) const; + + // Returns the availability for |app_id| on |device_id| and the time at which + // the availability was determined. If availability is kUnknown, then the time + // may be null (e.g. if an availability request was never sent). + AppAvailability GetAvailability(const std::string& device_id, + const std::string& app_id) const; + + // Returns a list of registered app IDs. + std::vector<std::string> GetRegisteredApps() const; + + // Returns a list of device IDs compatible with |source|, using the current + // availability info. + std::vector<std::string> GetAvailableDevices( + const CastMediaSource& source) const; + + private: + // App ID to availability. + using AppAvailabilityMap = std::map<std::string, AppAvailability>; + + // Registered sources and corresponding CastMediaSources. + std::map<std::string, CastMediaSource> registered_sources_; + + // App IDs tracked and the number of registered sources containing them. + std::map<std::string, int> registration_count_by_app_id_; + + // IDs and app availabilities of known devices. + std::map<std::string, AppAvailabilityMap> app_availabilities_; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_SENDER_CAST_APP_AVAILABILITY_TRACKER_H_ diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker_unittest.cc b/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker_unittest.cc new file mode 100644 index 00000000000..b45d356365f --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker_unittest.cc @@ -0,0 +1,159 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/sender/cast_app_availability_tracker.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "platform/test/fake_clock.h" + +namespace openscreen { +namespace cast { +namespace { + +using ::testing::UnorderedElementsAreArray; + +MATCHER_P(CastMediaSourcesEqual, expected, "") { + if (expected.size() != arg.size()) + return false; + return std::equal( + expected.begin(), expected.end(), arg.begin(), + [](const CastMediaSource& source1, const CastMediaSource& source2) { + return source1.source_id() == source2.source_id(); + }); +} + +} // namespace + +class CastAppAvailabilityTrackerTest : public ::testing::Test { + public: + CastAppAvailabilityTrackerTest() : clock_(Clock::now()) {} + ~CastAppAvailabilityTrackerTest() override = default; + + Clock::time_point Now() const { return clock_.now(); } + + protected: + FakeClock clock_; + CastAppAvailabilityTracker tracker_; +}; + +TEST_F(CastAppAvailabilityTrackerTest, RegisterSource) { + CastMediaSource source1("cast:AAA?clientId=1", {"AAA"}); + CastMediaSource source2("cast:AAA?clientId=2", {"AAA"}); + + std::vector<std::string> expected_app_ids = {"AAA"}; + EXPECT_EQ(expected_app_ids, tracker_.RegisterSource(source1)); + + EXPECT_EQ(std::vector<std::string>{}, tracker_.RegisterSource(source1)); + EXPECT_EQ(std::vector<std::string>{}, tracker_.RegisterSource(source2)); + + tracker_.UnregisterSource(source1); + tracker_.UnregisterSource(source2); + + EXPECT_EQ(expected_app_ids, tracker_.RegisterSource(source1)); + EXPECT_EQ(expected_app_ids, tracker_.GetRegisteredApps()); +} + +TEST_F(CastAppAvailabilityTrackerTest, RegisterSourceReturnsMultipleAppIds) { + CastMediaSource source1("urn:x-org.chromium.media:source:tab:1", + {"0F5096E8", "85CDB22F"}); + + // Mirorring app ids. + std::vector<std::string> expected_app_ids = {"0F5096E8", "85CDB22F"}; + EXPECT_THAT(tracker_.RegisterSource(source1), + UnorderedElementsAreArray(expected_app_ids)); + EXPECT_THAT(tracker_.GetRegisteredApps(), + UnorderedElementsAreArray(expected_app_ids)); +} + +TEST_F(CastAppAvailabilityTrackerTest, MultipleAppIdsAlreadyTrackingOne) { + // One of the mirroring app IDs. + CastMediaSource source1("cast:0F5096E8?clientId=123", {"0F5096E8"}); + + std::vector<std::string> new_app_ids = {"0F5096E8"}; + std::vector<std::string> registered_app_ids = {"0F5096E8"}; + EXPECT_EQ(new_app_ids, tracker_.RegisterSource(source1)); + EXPECT_EQ(registered_app_ids, tracker_.GetRegisteredApps()); + + CastMediaSource source2("urn:x-org.chromium.media:source:tab:1", + {"0F5096E8", "85CDB22F"}); + + new_app_ids = {"85CDB22F"}; + registered_app_ids = {"0F5096E8", "85CDB22F"}; + + EXPECT_EQ(new_app_ids, tracker_.RegisterSource(source2)); + EXPECT_THAT(tracker_.GetRegisteredApps(), + UnorderedElementsAreArray(registered_app_ids)); +} + +TEST_F(CastAppAvailabilityTrackerTest, UpdateAppAvailability) { + CastMediaSource source1("cast:AAA?clientId=1", {"AAA"}); + CastMediaSource source2("cast:AAA?clientId=2", {"AAA"}); + CastMediaSource source3("cast:BBB?clientId=3", {"BBB"}); + + tracker_.RegisterSource(source3); + + // |source3| not affected. + EXPECT_THAT( + tracker_.UpdateAppAvailability( + "deviceId1", "AAA", {AppAvailabilityResult::kAvailable, Now()}), + CastMediaSourcesEqual(std::vector<CastMediaSource>())); + + std::vector<std::string> devices_1 = {"deviceId1"}; + std::vector<std::string> devices_1_2 = {"deviceId1", "deviceId2"}; + std::vector<CastMediaSource> sources_1 = {source1}; + std::vector<CastMediaSource> sources_1_2 = {source1, source2}; + + // Tracker returns available devices even though sources aren't registered. + EXPECT_EQ(devices_1, tracker_.GetAvailableDevices(source1)); + EXPECT_EQ(devices_1, tracker_.GetAvailableDevices(source2)); + EXPECT_TRUE(tracker_.GetAvailableDevices(source3).empty()); + + tracker_.RegisterSource(source1); + // Only |source1| is registered for this app. + EXPECT_THAT( + tracker_.UpdateAppAvailability( + "deviceId2", "AAA", {AppAvailabilityResult::kAvailable, Now()}), + CastMediaSourcesEqual(sources_1)); + EXPECT_THAT(tracker_.GetAvailableDevices(source1), + UnorderedElementsAreArray(devices_1_2)); + EXPECT_THAT(tracker_.GetAvailableDevices(source2), + UnorderedElementsAreArray(devices_1_2)); + EXPECT_TRUE(tracker_.GetAvailableDevices(source3).empty()); + + tracker_.RegisterSource(source2); + EXPECT_THAT( + tracker_.UpdateAppAvailability( + "deviceId2", "AAA", {AppAvailabilityResult::kUnavailable, Now()}), + CastMediaSourcesEqual(sources_1_2)); + EXPECT_EQ(devices_1, tracker_.GetAvailableDevices(source1)); + EXPECT_EQ(devices_1, tracker_.GetAvailableDevices(source2)); + EXPECT_TRUE(tracker_.GetAvailableDevices(source3).empty()); +} + +TEST_F(CastAppAvailabilityTrackerTest, RemoveResultsForDevice) { + CastMediaSource source1("cast:AAA?clientId=1", {"AAA"}); + + tracker_.UpdateAppAvailability("deviceId1", "AAA", + {AppAvailabilityResult::kAvailable, Now()}); + EXPECT_EQ(AppAvailabilityResult::kAvailable, + tracker_.GetAvailability("deviceId1", "AAA").availability); + + std::vector<std::string> expected_device_ids = {"deviceId1"}; + EXPECT_EQ(expected_device_ids, tracker_.GetAvailableDevices(source1)); + + // Unrelated device ID. + tracker_.RemoveResultsForDevice("deviceId2"); + EXPECT_EQ(AppAvailabilityResult::kAvailable, + tracker_.GetAvailability("deviceId1", "AAA").availability); + EXPECT_EQ(expected_device_ids, tracker_.GetAvailableDevices(source1)); + + tracker_.RemoveResultsForDevice("deviceId1"); + EXPECT_EQ(AppAvailabilityResult::kUnknown, + tracker_.GetAvailability("deviceId1", "AAA").availability); + EXPECT_EQ(std::vector<std::string>{}, tracker_.GetAvailableDevices(source1)); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl.cc b/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl.cc new file mode 100644 index 00000000000..69028aacff7 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl.cc @@ -0,0 +1,211 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/sender/cast_app_discovery_service_impl.h" + +#include <algorithm> +#include <chrono> + +#include "cast/sender/public/cast_media_source.h" +#include "util/logging.h" + +namespace openscreen { +namespace cast { +namespace { + +// The minimum time that must elapse before an app availability result can be +// force refreshed. +static constexpr std::chrono::minutes kRefreshThreshold = + std::chrono::minutes(1); + +} // namespace + +CastAppDiscoveryServiceImpl::CastAppDiscoveryServiceImpl( + CastPlatformClient* platform_client, + ClockNowFunctionPtr clock) + : platform_client_(platform_client), clock_(clock), weak_factory_(this) { + OSP_DCHECK(platform_client_); + OSP_DCHECK(clock_); +} + +CastAppDiscoveryServiceImpl::~CastAppDiscoveryServiceImpl() { + OSP_CHECK_EQ(avail_queries_.size(), 0u); +} + +CastAppDiscoveryService::Subscription +CastAppDiscoveryServiceImpl::StartObservingAvailability( + const CastMediaSource& source, + AvailabilityCallback callback) { + const std::string& source_id = source.source_id(); + + // Return cached results immediately, if available. + std::vector<std::string> cached_device_ids = + availability_tracker_.GetAvailableDevices(source); + if (!cached_device_ids.empty()) { + callback(source, GetReceiversByIds(cached_device_ids)); + } + + auto& callbacks = avail_queries_[source_id]; + uint32_t query_id = GetNextAvailabilityQueryId(); + callbacks.push_back({query_id, std::move(callback)}); + if (callbacks.size() == 1) { + // NOTE: Even though we retain availability results for an app unregistered + // from the tracker, we will refresh the results when the app is + // re-registered. + std::vector<std::string> new_app_ids = + availability_tracker_.RegisterSource(source); + for (const auto& app_id : new_app_ids) { + for (const auto& entry : receivers_by_id_) { + RequestAppAvailability(entry.first, app_id); + } + } + } + + return MakeSubscription(this, query_id); +} + +void CastAppDiscoveryServiceImpl::Refresh() { + const auto app_ids = availability_tracker_.GetRegisteredApps(); + for (const auto& entry : receivers_by_id_) { + for (const auto& app_id : app_ids) { + RequestAppAvailability(entry.first, app_id); + } + } +} + +void CastAppDiscoveryServiceImpl::AddOrUpdateReceiver( + const ServiceInfo& receiver) { + const std::string& device_id = receiver.unique_id; + receivers_by_id_[device_id] = receiver; + + // Any queries that currently contain this receiver should be updated. + UpdateAvailabilityQueries( + availability_tracker_.GetSupportedSources(device_id)); + + for (const std::string& app_id : availability_tracker_.GetRegisteredApps()) { + RequestAppAvailability(device_id, app_id); + } +} + +void CastAppDiscoveryServiceImpl::RemoveReceiver(const ServiceInfo& receiver) { + const std::string& device_id = receiver.unique_id; + receivers_by_id_.erase(device_id); + UpdateAvailabilityQueries( + availability_tracker_.RemoveResultsForDevice(device_id)); +} + +void CastAppDiscoveryServiceImpl::RequestAppAvailability( + const std::string& device_id, + const std::string& app_id) { + if (ShouldRefreshAppAvailability(device_id, app_id, clock_())) { + platform_client_->RequestAppAvailability( + device_id, app_id, + [self = weak_factory_.GetWeakPtr(), device_id]( + const std::string& app_id, AppAvailabilityResult availability) { + if (self) { + self->UpdateAppAvailability(device_id, app_id, availability); + } + }); + } +} + +void CastAppDiscoveryServiceImpl::UpdateAppAvailability( + const std::string& device_id, + const std::string& app_id, + AppAvailabilityResult availability) { + if (receivers_by_id_.find(device_id) == receivers_by_id_.end()) { + return; + } + + OSP_DVLOG << "App " << app_id << " on receiver " << device_id << " is " + << ToString(availability); + + UpdateAvailabilityQueries(availability_tracker_.UpdateAppAvailability( + device_id, app_id, {availability, clock_()})); +} + +void CastAppDiscoveryServiceImpl::UpdateAvailabilityQueries( + const std::vector<CastMediaSource>& sources) { + for (const auto& source : sources) { + const std::string& source_id = source.source_id(); + auto it = avail_queries_.find(source_id); + if (it == avail_queries_.end()) + continue; + std::vector<std::string> device_ids = + availability_tracker_.GetAvailableDevices(source); + std::vector<ServiceInfo> receivers = GetReceiversByIds(device_ids); + for (const auto& callback : it->second) { + callback.callback(source, receivers); + } + } +} + +std::vector<ServiceInfo> CastAppDiscoveryServiceImpl::GetReceiversByIds( + const std::vector<std::string>& device_ids) const { + std::vector<ServiceInfo> receivers; + for (const std::string& device_id : device_ids) { + auto entry = receivers_by_id_.find(device_id); + if (entry != receivers_by_id_.end()) { + receivers.push_back(entry->second); + } + } + return receivers; +} + +bool CastAppDiscoveryServiceImpl::ShouldRefreshAppAvailability( + const std::string& device_id, + const std::string& app_id, + Clock::time_point now) const { + // TODO(btolsch): Consider an exponential backoff mechanism instead. + // Receivers will typically respond with "unavailable" immediately after boot + // and then become available 10-30 seconds later. + auto availability = availability_tracker_.GetAvailability(device_id, app_id); + switch (availability.availability) { + case AppAvailabilityResult::kAvailable: + return false; + case AppAvailabilityResult::kUnavailable: + return (now - availability.time) > kRefreshThreshold; + // TODO(btolsch): Should there be a background task for periodically + // refreshing kUnknown (or even kUnavailable) results? + case AppAvailabilityResult::kUnknown: + return true; + } + + OSP_NOTREACHED(); + return false; +} + +uint32_t CastAppDiscoveryServiceImpl::GetNextAvailabilityQueryId() { + if (free_query_ids_.empty()) { + return next_avail_query_id_++; + } else { + uint32_t id = free_query_ids_.back(); + free_query_ids_.pop_back(); + return id; + } +} + +void CastAppDiscoveryServiceImpl::RemoveAvailabilityCallback(uint32_t id) { + for (auto entry = avail_queries_.begin(); entry != avail_queries_.end(); + ++entry) { + const std::string& source_id = entry->first; + auto& callbacks = entry->second; + auto it = + std::find_if(callbacks.begin(), callbacks.end(), + [id](const AvailabilityCallbackEntry& callback_entry) { + return callback_entry.id == id; + }); + if (it != callbacks.end()) { + callbacks.erase(it); + if (callbacks.empty()) { + availability_tracker_.UnregisterSource(source_id); + avail_queries_.erase(entry); + } + return; + } + } +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl.h b/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl.h new file mode 100644 index 00000000000..4093311af35 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl.h @@ -0,0 +1,99 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_SENDER_CAST_APP_DISCOVERY_SERVICE_IMPL_H_ +#define CAST_SENDER_CAST_APP_DISCOVERY_SERVICE_IMPL_H_ + +#include <map> +#include <string> +#include <vector> + +#include "cast/common/public/service_info.h" +#include "cast/sender/cast_app_availability_tracker.h" +#include "cast/sender/cast_platform_client.h" +#include "cast/sender/public/cast_app_discovery_service.h" +#include "platform/api/time.h" +#include "util/weak_ptr.h" + +namespace openscreen { +namespace cast { + +// Keeps track of availability queries, receives receiver updates, and issues +// app availability requests based on these signals. +class CastAppDiscoveryServiceImpl : public CastAppDiscoveryService { + public: + // |platform_client| must outlive |this|. + CastAppDiscoveryServiceImpl(CastPlatformClient* platform_client, + ClockNowFunctionPtr clock); + ~CastAppDiscoveryServiceImpl() override; + + // CastAppDiscoveryService implementation. + Subscription StartObservingAvailability( + const CastMediaSource& source, + AvailabilityCallback callback) override; + + // Reissues app availability requests for currently registered (device_id, + // app_id) pairs whose status is kUnavailable or kUnknown. + void Refresh() override; + + void AddOrUpdateReceiver(const ServiceInfo& receiver); + void RemoveReceiver(const ServiceInfo& receiver); + + private: + struct AvailabilityCallbackEntry { + uint32_t id; + AvailabilityCallback callback; + }; + + // Issues an app availability request for |app_id| to the receiver given by + // |device_id|. + void RequestAppAvailability(const std::string& device_id, + const std::string& app_id); + + // Updates the availability result for |device_id| and |app_id| with |result|, + // and notifies callbacks with updated availability query results. + void UpdateAppAvailability(const std::string& device_id, + const std::string& app_id, + AppAvailabilityResult result); + + // Updates the availability query results for |sources|. + void UpdateAvailabilityQueries(const std::vector<CastMediaSource>& sources); + + std::vector<ServiceInfo> GetReceiversByIds( + const std::vector<std::string>& device_ids) const; + + // Returns true if an app availability request should be issued for + // |device_id| and |app_id|. |now| is used for checking whether previously + // cached results should be refreshed. + bool ShouldRefreshAppAvailability(const std::string& device_id, + const std::string& app_id, + Clock::time_point now) const; + + uint32_t GetNextAvailabilityQueryId(); + + void RemoveAvailabilityCallback(uint32_t id) override; + + std::map<std::string, ServiceInfo> receivers_by_id_; + + // Registered availability queries and their associated callbacks keyed by + // media source IDs. + std::map<std::string, std::vector<AvailabilityCallbackEntry>> avail_queries_; + + // Callback ID tracking. + uint32_t next_avail_query_id_; + std::vector<uint32_t> free_query_ids_; + + CastPlatformClient* const platform_client_; + + CastAppAvailabilityTracker availability_tracker_; + + const ClockNowFunctionPtr clock_; + + WeakPtrFactory<CastAppDiscoveryServiceImpl> weak_factory_; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_SENDER_CAST_APP_DISCOVERY_SERVICE_IMPL_H_ diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl_unittest.cc b/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl_unittest.cc new file mode 100644 index 00000000000..863eb3472c4 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl_unittest.cc @@ -0,0 +1,350 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/sender/cast_app_discovery_service_impl.h" + +#include "cast/common/channel/testing/fake_cast_socket.h" +#include "cast/common/channel/testing/mock_socket_error_handler.h" +#include "cast/common/channel/virtual_connection_manager.h" +#include "cast/common/channel/virtual_connection_router.h" +#include "cast/common/public/service_info.h" +#include "cast/sender/testing/test_helpers.h" +#include "gtest/gtest.h" +#include "platform/test/fake_clock.h" +#include "platform/test/fake_task_runner.h" +#include "util/logging.h" + +namespace openscreen { +namespace cast { + +using ::cast::channel::CastMessage; + +using ::testing::_; + +class CastAppDiscoveryServiceImplTest : public ::testing::Test { + public: + void SetUp() override { + socket_id_ = fake_cast_socket_pair_.socket->socket_id(); + router_.TakeSocket(&mock_error_handler_, + std::move(fake_cast_socket_pair_.socket)); + + receiver_.v4_address = fake_cast_socket_pair_.remote_endpoint.address; + receiver_.port = fake_cast_socket_pair_.remote_endpoint.port; + receiver_.unique_id = "deviceId1"; + receiver_.friendly_name = "Some Name"; + } + + protected: + CastSocket& peer_socket() { return *fake_cast_socket_pair_.peer_socket; } + MockCastSocketClient& peer_client() { + return fake_cast_socket_pair_.mock_peer_client; + } + + void AddOrUpdateReceiver(const ServiceInfo& receiver, int32_t socket_id) { + platform_client_.AddOrUpdateReceiver(receiver, socket_id); + app_discovery_service_.AddOrUpdateReceiver(receiver); + } + + CastAppDiscoveryService::Subscription StartObservingAvailability( + const CastMediaSource& source, + std::vector<ServiceInfo>* save_receivers) { + return app_discovery_service_.StartObservingAvailability( + source, [save_receivers](const CastMediaSource& source, + const std::vector<ServiceInfo>& receivers) { + *save_receivers = receivers; + }); + } + + CastAppDiscoveryService::Subscription StartSourceA1Query( + std::vector<ServiceInfo>* receivers, + int* request_id, + std::string* sender_id) { + auto subscription = StartObservingAvailability(source_a_1_, receivers); + + // Adding a receiver after app registered causes app availability request to + // be sent. + *request_id = -1; + *sender_id = ""; + EXPECT_CALL(peer_client(), OnMessage(_, _)) + .WillOnce([request_id, sender_id](CastSocket*, CastMessage message) { + VerifyAppAvailabilityRequest(message, "AAA", request_id, sender_id); + }); + + AddOrUpdateReceiver(receiver_, socket_id_); + + return subscription; + } + + FakeCastSocketPair fake_cast_socket_pair_; + int32_t socket_id_; + MockSocketErrorHandler mock_error_handler_; + VirtualConnectionManager manager_; + VirtualConnectionRouter router_{&manager_}; + FakeClock clock_{Clock::now()}; + FakeTaskRunner task_runner_{&clock_}; + CastPlatformClient platform_client_{&router_, &manager_, &FakeClock::now, + &task_runner_}; + CastAppDiscoveryServiceImpl app_discovery_service_{&platform_client_, + &FakeClock::now}; + + CastMediaSource source_a_1_{"cast:AAA?clientId=1", {"AAA"}}; + CastMediaSource source_a_2_{"cast:AAA?clientId=2", {"AAA"}}; + CastMediaSource source_b_1_{"cast:BBB?clientId=1", {"BBB"}}; + + ServiceInfo receiver_; +}; + +TEST_F(CastAppDiscoveryServiceImplTest, StartObservingAvailability) { + std::vector<ServiceInfo> receivers1; + int request_id; + std::string sender_id; + auto subscription1 = StartSourceA1Query(&receivers1, &request_id, &sender_id); + + // Same app ID should not trigger another request. + EXPECT_CALL(peer_client(), OnMessage(_, _)).Times(0); + std::vector<ServiceInfo> receivers2; + auto subscription2 = StartObservingAvailability(source_a_2_, &receivers2); + + CastMessage availability_response = + CreateAppAvailableResponseChecked(request_id, sender_id, "AAA"); + EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok()); + ASSERT_EQ(receivers1.size(), 1u); + ASSERT_EQ(receivers2.size(), 1u); + EXPECT_EQ(receivers1[0].unique_id, "deviceId1"); + EXPECT_EQ(receivers2[0].unique_id, "deviceId1"); + + // No more updates for |source_a_1_| (i.e. |receivers1|). + subscription1.Reset(); + platform_client_.RemoveReceiver(receiver_); + app_discovery_service_.RemoveReceiver(receiver_); + ASSERT_EQ(receivers1.size(), 1u); + EXPECT_EQ(receivers2.size(), 0u); + EXPECT_EQ(receivers1[0].unique_id, "deviceId1"); +} + +TEST_F(CastAppDiscoveryServiceImplTest, ReAddAvailQueryUsesCachedValue) { + std::vector<ServiceInfo> receivers1; + int request_id; + std::string sender_id; + auto subscription1 = StartSourceA1Query(&receivers1, &request_id, &sender_id); + + CastMessage availability_response = + CreateAppAvailableResponseChecked(request_id, sender_id, "AAA"); + EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok()); + ASSERT_EQ(receivers1.size(), 1u); + EXPECT_EQ(receivers1[0].unique_id, "deviceId1"); + + subscription1.Reset(); + receivers1.clear(); + + // Request not re-sent; cached kAvailable value is used. + EXPECT_CALL(peer_client(), OnMessage(_, _)).Times(0); + subscription1 = StartObservingAvailability(source_a_1_, &receivers1); + ASSERT_EQ(receivers1.size(), 1u); + EXPECT_EQ(receivers1[0].unique_id, "deviceId1"); +} + +TEST_F(CastAppDiscoveryServiceImplTest, AvailQueryUpdatedOnReceiverUpdate) { + std::vector<ServiceInfo> receivers1; + int request_id; + std::string sender_id; + auto subscription1 = StartSourceA1Query(&receivers1, &request_id, &sender_id); + + // Result set now includes |receiver_|. + CastMessage availability_response = + CreateAppAvailableResponseChecked(request_id, sender_id, "AAA"); + EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok()); + ASSERT_EQ(receivers1.size(), 1u); + EXPECT_EQ(receivers1[0].unique_id, "deviceId1"); + + // Updating |receiver_| causes |source_a_1_| query to be updated, but it's too + // soon for a new message to be sent. + EXPECT_CALL(peer_client(), OnMessage(_, _)).Times(0); + receiver_.friendly_name = "New Name"; + + AddOrUpdateReceiver(receiver_, socket_id_); + + ASSERT_EQ(receivers1.size(), 1u); + EXPECT_EQ(receivers1[0].friendly_name, "New Name"); +} + +TEST_F(CastAppDiscoveryServiceImplTest, Refresh) { + std::vector<ServiceInfo> receivers1; + auto subscription1 = StartObservingAvailability(source_a_1_, &receivers1); + std::vector<ServiceInfo> receivers2; + auto subscription2 = StartObservingAvailability(source_b_1_, &receivers2); + + // Adding a receiver after app registered causes two separate app availability + // requests to be sent. + int request_idA = -1; + int request_idB = -1; + std::string sender_id = ""; + EXPECT_CALL(peer_client(), OnMessage(_, _)) + .Times(2) + .WillRepeatedly([&request_idA, &request_idB, &sender_id]( + CastSocket*, CastMessage message) { + std::string app_id; + int request_id = -1; + VerifyAppAvailabilityRequest(message, &app_id, &request_id, &sender_id); + if (app_id == "AAA") { + EXPECT_EQ(request_idA, -1); + request_idA = request_id; + } else if (app_id == "BBB") { + EXPECT_EQ(request_idB, -1); + request_idB = request_id; + } else { + EXPECT_TRUE(false); + } + }); + + AddOrUpdateReceiver(receiver_, socket_id_); + + CastMessage availability_response = + CreateAppAvailableResponseChecked(request_idA, sender_id, "AAA"); + EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok()); + availability_response = + CreateAppUnavailableResponseChecked(request_idB, sender_id, "BBB"); + EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok()); + ASSERT_EQ(receivers1.size(), 1u); + ASSERT_EQ(receivers2.size(), 0u); + EXPECT_EQ(receivers1[0].unique_id, "deviceId1"); + + // Not enough time has passed for a refresh. + clock_.Advance(std::chrono::seconds(30)); + EXPECT_CALL(peer_client(), OnMessage(_, _)).Times(0); + app_discovery_service_.Refresh(); + + // Refresh will now query again for unavailable app IDs. + clock_.Advance(std::chrono::minutes(2)); + EXPECT_CALL(peer_client(), OnMessage(_, _)) + .WillOnce([&request_idB, &sender_id](CastSocket*, CastMessage message) { + VerifyAppAvailabilityRequest(message, "BBB", &request_idB, &sender_id); + }); + app_discovery_service_.Refresh(); +} + +TEST_F(CastAppDiscoveryServiceImplTest, + StartObservingAvailabilityAfterReceiverAdded) { + // No registered apps. + EXPECT_CALL(peer_client(), OnMessage(_, _)).Times(0); + AddOrUpdateReceiver(receiver_, socket_id_); + + // Registering apps immediately sends requests to |receiver_|. + int request_idA = -1; + std::string sender_id = ""; + EXPECT_CALL(peer_client(), OnMessage(_, _)) + .WillOnce([&request_idA, &sender_id](CastSocket*, CastMessage message) { + VerifyAppAvailabilityRequest(message, "AAA", &request_idA, &sender_id); + }); + std::vector<ServiceInfo> receivers1; + auto subscription1 = StartObservingAvailability(source_a_1_, &receivers1); + + int request_idB = -1; + EXPECT_CALL(peer_client(), OnMessage(_, _)) + .WillOnce([&request_idB, &sender_id](CastSocket*, CastMessage message) { + VerifyAppAvailabilityRequest(message, "BBB", &request_idB, &sender_id); + }); + std::vector<ServiceInfo> receivers2; + auto subscription2 = StartObservingAvailability(source_b_1_, &receivers2); + + // Add a new receiver with a corresponding socket. + FakeCastSocketPair fake_sockets2({{192, 168, 1, 17}, 2345}, + {{192, 168, 1, 19}, 2345}); + CastSocket* socket2 = fake_sockets2.socket.get(); + router_.TakeSocket(&mock_error_handler_, std::move(fake_sockets2.socket)); + ServiceInfo receiver2; + receiver2.unique_id = "deviceId2"; + receiver2.v4_address = fake_sockets2.remote_endpoint.address; + receiver2.port = fake_sockets2.remote_endpoint.port; + + // Adding new receiver causes availability requests for both apps to be sent + // to the new receiver. + request_idA = -1; + request_idB = -1; + EXPECT_CALL(fake_sockets2.mock_peer_client, OnMessage(_, _)) + .Times(2) + .WillRepeatedly([&request_idA, &request_idB, &sender_id]( + CastSocket*, CastMessage message) { + std::string app_id; + int request_id = -1; + VerifyAppAvailabilityRequest(message, &app_id, &request_id, &sender_id); + if (app_id == "AAA") { + EXPECT_EQ(request_idA, -1); + request_idA = request_id; + } else if (app_id == "BBB") { + EXPECT_EQ(request_idB, -1); + request_idB = request_id; + } else { + EXPECT_TRUE(false); + } + }); + + AddOrUpdateReceiver(receiver2, socket2->socket_id()); +} + +TEST_F(CastAppDiscoveryServiceImplTest, StartObservingAvailabilityCachedValue) { + std::vector<ServiceInfo> receivers1; + int request_id; + std::string sender_id; + auto subscription1 = StartSourceA1Query(&receivers1, &request_id, &sender_id); + + CastMessage availability_response = + CreateAppAvailableResponseChecked(request_id, sender_id, "AAA"); + EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok()); + ASSERT_EQ(receivers1.size(), 1u); + EXPECT_EQ(receivers1[0].unique_id, "deviceId1"); + + // Same app ID should not trigger another request, but it should return + // cached value. + EXPECT_CALL(peer_client(), OnMessage(_, _)).Times(0); + std::vector<ServiceInfo> receivers2; + auto subscription2 = StartObservingAvailability(source_a_2_, &receivers2); + ASSERT_EQ(receivers2.size(), 1u); + EXPECT_EQ(receivers2[0].unique_id, "deviceId1"); +} + +TEST_F(CastAppDiscoveryServiceImplTest, AvailabilityUnknownOrUnavailable) { + std::vector<ServiceInfo> receivers1; + int request_id; + std::string sender_id; + auto subscription1 = StartSourceA1Query(&receivers1, &request_id, &sender_id); + + // The request will timeout resulting in unknown app availability. + clock_.Advance(std::chrono::seconds(10)); + task_runner_.RunTasksUntilIdle(); + EXPECT_EQ(receivers1.size(), 0u); + + // Receiver updated together with unknown app availability will cause a + // request to be sent again. + request_id = -1; + EXPECT_CALL(peer_client(), OnMessage(_, _)) + .WillOnce([&request_id, &sender_id](CastSocket*, CastMessage message) { + VerifyAppAvailabilityRequest(message, "AAA", &request_id, &sender_id); + }); + AddOrUpdateReceiver(receiver_, socket_id_); + + CastMessage availability_response = + CreateAppUnavailableResponseChecked(request_id, sender_id, "AAA"); + EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok()); + + // Known availability so no request sent. + EXPECT_CALL(peer_client(), OnMessage(_, _)).Times(0); + AddOrUpdateReceiver(receiver_, socket_id_); + + // Removing the receiver will also remove previous availability information. + // Next time the receiver is added, a new request will be sent. + platform_client_.RemoveReceiver(receiver_); + app_discovery_service_.RemoveReceiver(receiver_); + + request_id = -1; + EXPECT_CALL(peer_client(), OnMessage(_, _)) + .WillOnce([&request_id, &sender_id](CastSocket*, CastMessage message) { + VerifyAppAvailabilityRequest(message, "AAA", &request_id, &sender_id); + }); + + AddOrUpdateReceiver(receiver_, socket_id_); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_platform_client.cc b/chromium/third_party/openscreen/src/cast/sender/cast_platform_client.cc new file mode 100644 index 00000000000..6e3d13da4bf --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/sender/cast_platform_client.cc @@ -0,0 +1,224 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/sender/cast_platform_client.h" + +#include <random> + +#include "absl/strings/str_cat.h" +#include "cast/common/channel/cast_socket.h" +#include "cast/common/channel/virtual_connection_manager.h" +#include "cast/common/channel/virtual_connection_router.h" +#include "cast/common/public/service_info.h" +#include "util/json/json_serialization.h" +#include "util/logging.h" +#include "util/stringprintf.h" + +namespace openscreen { +namespace cast { + +static constexpr std::chrono::seconds kRequestTimeout = std::chrono::seconds(5); + +namespace { + +std::string MakeRandomSenderId() { + static auto& rd = *new std::random_device(); + static auto& gen = *new std::mt19937(rd()); + static auto& dist = *new std::uniform_int_distribution<>(1, 1000000); + return absl::StrCat("sender-", dist(gen)); +} + +} // namespace + +CastPlatformClient::CastPlatformClient(VirtualConnectionRouter* router, + VirtualConnectionManager* manager, + ClockNowFunctionPtr clock, + TaskRunner* task_runner) + : sender_id_(MakeRandomSenderId()), + virtual_conn_router_(router), + virtual_conn_manager_(manager), + clock_(clock), + task_runner_(task_runner) { + OSP_DCHECK(virtual_conn_manager_); + OSP_DCHECK(clock_); + OSP_DCHECK(task_runner_); + virtual_conn_router_->AddHandlerForLocalId(sender_id_, this); +} + +CastPlatformClient::~CastPlatformClient() { + virtual_conn_router_->RemoveHandlerForLocalId(sender_id_); + + for (auto& pending_requests : pending_requests_by_device_id_) { + for (auto& avail_request : pending_requests.second.availability) { + avail_request.callback(avail_request.app_id, + AppAvailabilityResult::kUnknown); + } + } +} + +absl::optional<int> CastPlatformClient::RequestAppAvailability( + const std::string& device_id, + const std::string& app_id, + AppAvailabilityCallback callback) { + auto entry = socket_id_by_device_id_.find(device_id); + if (entry == socket_id_by_device_id_.end()) { + callback(app_id, AppAvailabilityResult::kUnknown); + return absl::nullopt; + } + int socket_id = entry->second; + + int request_id = GetNextRequestId(); + ErrorOr<::cast::channel::CastMessage> message = + CreateAppAvailabilityRequest(sender_id_, request_id, app_id); + OSP_DCHECK(message); + + PendingRequests& pending_requests = pending_requests_by_device_id_[device_id]; + auto timeout = std::make_unique<Alarm>(clock_, task_runner_); + timeout->ScheduleFromNow( + [this, request_id]() { CancelAppAvailabilityRequest(request_id); }, + kRequestTimeout); + pending_requests.availability.push_back(AvailabilityRequest{ + request_id, app_id, std::move(timeout), std::move(callback)}); + + VirtualConnection virtual_conn{sender_id_, kPlatformReceiverId, socket_id}; + if (!virtual_conn_manager_->GetConnectionData(virtual_conn)) { + virtual_conn_manager_->AddConnection(virtual_conn, + VirtualConnection::AssociatedData{}); + } + + virtual_conn_router_->SendMessage(std::move(virtual_conn), + std::move(message.value())); + + return request_id; +} + +void CastPlatformClient::AddOrUpdateReceiver(const ServiceInfo& device, + int socket_id) { + socket_id_by_device_id_[device.unique_id] = socket_id; +} + +void CastPlatformClient::RemoveReceiver(const ServiceInfo& device) { + auto pending_requests_it = + pending_requests_by_device_id_.find(device.unique_id); + if (pending_requests_it != pending_requests_by_device_id_.end()) { + for (const AvailabilityRequest& availability : + pending_requests_it->second.availability) { + availability.callback(availability.app_id, + AppAvailabilityResult::kUnknown); + } + pending_requests_by_device_id_.erase(pending_requests_it); + } + socket_id_by_device_id_.erase(device.unique_id); +} + +void CastPlatformClient::CancelRequest(int request_id) { + for (auto entry = pending_requests_by_device_id_.begin(); + entry != pending_requests_by_device_id_.end(); ++entry) { + auto& pending_requests = entry->second; + auto it = std::find_if(pending_requests.availability.begin(), + pending_requests.availability.end(), + [request_id](const AvailabilityRequest& request) { + return request.request_id == request_id; + }); + if (it != pending_requests.availability.end()) { + pending_requests.availability.erase(it); + break; + } + } +} + +void CastPlatformClient::OnMessage(VirtualConnectionRouter* router, + CastSocket* socket, + ::cast::channel::CastMessage message) { + if (message.payload_type() != + ::cast::channel::CastMessage_PayloadType_STRING || + message.namespace_() != kReceiverNamespace || + message.source_id() != kPlatformReceiverId) { + return; + } + ErrorOr<Json::Value> dict_or_error = json::Parse(message.payload_utf8()); + if (dict_or_error.is_error()) { + OSP_DVLOG << "Failed to deserialize CastMessage payload."; + return; + } + + Json::Value& dict = dict_or_error.value(); + absl::optional<int> request_id = + MaybeGetInt(dict, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyRequestId)); + if (request_id) { + auto entry = std::find_if( + socket_id_by_device_id_.begin(), socket_id_by_device_id_.end(), + [socket](const std::pair<std::string, int>& entry) { + return entry.second == socket->socket_id(); + }); + if (entry != socket_id_by_device_id_.end()) { + HandleResponse(entry->first, request_id.value(), dict); + } + } +} + +void CastPlatformClient::HandleResponse(const std::string& device_id, + int request_id, + const Json::Value& message) { + auto entry = pending_requests_by_device_id_.find(device_id); + if (entry == pending_requests_by_device_id_.end()) { + return; + } + PendingRequests& pending_requests = entry->second; + auto it = std::find_if(pending_requests.availability.begin(), + pending_requests.availability.end(), + [request_id](const AvailabilityRequest& request) { + return request.request_id == request_id; + }); + if (it != pending_requests.availability.end()) { + // TODO(btolsch): Can all of this manual parsing/checking be cleaned up into + // a single parsing API along with other message handling? + const Json::Value* maybe_availability = + message.find(JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyAvailability)); + if (maybe_availability && maybe_availability->isObject()) { + absl::optional<absl::string_view> result = + MaybeGetString(*maybe_availability, &it->app_id[0], + &it->app_id[0] + it->app_id.size()); + if (result) { + AppAvailabilityResult availability_result = + AppAvailabilityResult::kUnknown; + if (result.value() == kMessageValueAppAvailable) { + availability_result = AppAvailabilityResult::kAvailable; + } else if (result.value() == kMessageValueAppUnavailable) { + availability_result = AppAvailabilityResult::kUnavailable; + } else { + OSP_DVLOG << "Invalid availability result: " << result.value(); + } + it->callback(it->app_id, availability_result); + } + } + pending_requests.availability.erase(it); + } +} + +void CastPlatformClient::CancelAppAvailabilityRequest(int request_id) { + for (auto& entry : pending_requests_by_device_id_) { + PendingRequests& pending_requests = entry.second; + auto it = std::find_if(pending_requests.availability.begin(), + pending_requests.availability.end(), + [request_id](const AvailabilityRequest& request) { + return request.request_id == request_id; + }); + if (it != pending_requests.availability.end()) { + it->callback(it->app_id, AppAvailabilityResult::kUnknown); + pending_requests.availability.erase(it); + } + } +} + +// static +int CastPlatformClient::GetNextRequestId() { + return next_request_id_++; +} + +// static +int CastPlatformClient::next_request_id_ = 0; + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_platform_client.h b/chromium/third_party/openscreen/src/cast/sender/cast_platform_client.h new file mode 100644 index 00000000000..41ad7fc704b --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/sender/cast_platform_client.h @@ -0,0 +1,97 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_SENDER_CAST_PLATFORM_CLIENT_H_ +#define CAST_SENDER_CAST_PLATFORM_CLIENT_H_ + +#include <functional> +#include <map> +#include <string> + +#include "absl/types/optional.h" +#include "cast/common/channel/cast_message_handler.h" +#include "cast/sender/channel/message_util.h" +#include "util/alarm.h" +#include "util/json/json_value.h" + +namespace openscreen { +namespace cast { + +struct ServiceInfo; +class VirtualConnectionManager; +class VirtualConnectionRouter; + +// This class handles Cast messages that generally relate to the "platform", in +// other words not a specific app currently running (e.g. app availability, +// receiver status). These messages follow a request/response format, so each +// request requires a corresponding response callback. These requests will also +// timeout if there is no response after a certain amount of time (currently 5 +// seconds). The timeout callbacks will be called on the thread managed by +// |task_runner|. +class CastPlatformClient final : public CastMessageHandler { + public: + using AppAvailabilityCallback = + std::function<void(const std::string& app_id, AppAvailabilityResult)>; + + CastPlatformClient(VirtualConnectionRouter* router, + VirtualConnectionManager* manager, + ClockNowFunctionPtr clock, + TaskRunner* task_runner); + ~CastPlatformClient() override; + + // Requests availability information for |app_id| from the receiver identified + // by |device_id|. |callback| will be called exactly once with a result. + absl::optional<int> RequestAppAvailability(const std::string& device_id, + const std::string& app_id, + AppAvailabilityCallback callback); + + // Notifies this object about general receiver connectivity or property + // changes. + void AddOrUpdateReceiver(const ServiceInfo& device, int socket_id); + void RemoveReceiver(const ServiceInfo& device); + + void CancelRequest(int request_id); + + private: + struct AvailabilityRequest { + int request_id; + std::string app_id; + std::unique_ptr<Alarm> timeout; + AppAvailabilityCallback callback; + }; + + struct PendingRequests { + std::vector<AvailabilityRequest> availability; + }; + + // CastMessageHandler overrides. + void OnMessage(VirtualConnectionRouter* router, + CastSocket* socket, + ::cast::channel::CastMessage message) override; + + void HandleResponse(const std::string& device_id, + int request_id, + const Json::Value& message); + + void CancelAppAvailabilityRequest(int request_id); + + static int GetNextRequestId(); + + static int next_request_id_; + + const std::string sender_id_; + VirtualConnectionRouter* const virtual_conn_router_; + VirtualConnectionManager* const virtual_conn_manager_; + std::map<std::string /* device_id */, int> socket_id_by_device_id_; + std::map<std::string /* device_id */, PendingRequests> + pending_requests_by_device_id_; + + const ClockNowFunctionPtr clock_; + TaskRunner* const task_runner_; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_SENDER_CAST_PLATFORM_CLIENT_H_ diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_platform_client_unittest.cc b/chromium/third_party/openscreen/src/cast/sender/cast_platform_client_unittest.cc new file mode 100644 index 00000000000..44e99660b06 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/sender/cast_platform_client_unittest.cc @@ -0,0 +1,110 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/sender/cast_platform_client.h" + +#include "cast/common/channel/testing/fake_cast_socket.h" +#include "cast/common/channel/testing/mock_socket_error_handler.h" +#include "cast/common/channel/virtual_connection_manager.h" +#include "cast/common/channel/virtual_connection_router.h" +#include "cast/common/public/service_info.h" +#include "cast/sender/testing/test_helpers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "platform/test/fake_clock.h" +#include "platform/test/fake_task_runner.h" +#include "util/json/json_serialization.h" +#include "util/json/json_value.h" + +namespace openscreen { +namespace cast { + +using ::cast::channel::CastMessage; + +using ::testing::_; + +class CastPlatformClientTest : public ::testing::Test { + public: + void SetUp() override { + socket_ = fake_cast_socket_pair_.socket.get(); + router_.TakeSocket(&mock_error_handler_, + std::move(fake_cast_socket_pair_.socket)); + + receiver_.v4_address = IPAddress{192, 168, 0, 17}; + receiver_.port = 4434; + receiver_.unique_id = "deviceId1"; + platform_client_.AddOrUpdateReceiver(receiver_, socket_->socket_id()); + } + + protected: + CastSocket& peer_socket() { return *fake_cast_socket_pair_.peer_socket; } + MockCastSocketClient& peer_client() { + return fake_cast_socket_pair_.mock_peer_client; + } + + FakeCastSocketPair fake_cast_socket_pair_; + CastSocket* socket_ = nullptr; + MockSocketErrorHandler mock_error_handler_; + VirtualConnectionManager manager_; + VirtualConnectionRouter router_{&manager_}; + FakeClock clock_{Clock::now()}; + FakeTaskRunner task_runner_{&clock_}; + CastPlatformClient platform_client_{&router_, &manager_, &FakeClock::now, + &task_runner_}; + ServiceInfo receiver_; +}; + +TEST_F(CastPlatformClientTest, AppAvailability) { + int request_id = -1; + std::string sender_id; + EXPECT_CALL(peer_client(), OnMessage(_, _)) + .WillOnce([&request_id, &sender_id](CastSocket* socket, + CastMessage message) { + VerifyAppAvailabilityRequest(message, "AAA", &request_id, &sender_id); + }); + bool ran = false; + platform_client_.RequestAppAvailability( + "deviceId1", "AAA", + [&ran](const std::string& app_id, AppAvailabilityResult availability) { + EXPECT_EQ("AAA", app_id); + EXPECT_EQ(availability, AppAvailabilityResult::kAvailable); + ran = true; + }); + + CastMessage availability_response = + CreateAppAvailableResponseChecked(request_id, sender_id, "AAA"); + EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok()); + EXPECT_TRUE(ran); + + // NOTE: Callback should only fire once, so it should not fire again here. + ran = false; + EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok()); + EXPECT_FALSE(ran); +} + +TEST_F(CastPlatformClientTest, CancelRequest) { + int request_id = -1; + std::string sender_id; + EXPECT_CALL(peer_client(), OnMessage(_, _)) + .WillOnce([&request_id, &sender_id](CastSocket* socket, + CastMessage message) { + VerifyAppAvailabilityRequest(message, "AAA", &request_id, &sender_id); + }); + absl::optional<int> maybe_request_id = + platform_client_.RequestAppAvailability( + "deviceId1", "AAA", + [](const std::string& app_id, AppAvailabilityResult availability) { + EXPECT_TRUE(false); + }); + ASSERT_TRUE(maybe_request_id); + int local_request_id = maybe_request_id.value(); + platform_client_.CancelRequest(local_request_id); + + CastMessage availability_response = + CreateAppAvailableResponseChecked(request_id, sender_id, "AAA"); + EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok()); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.cc b/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.cc index 9f706259d0d..fc103e9bc90 100644 --- a/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.cc +++ b/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.cc @@ -6,17 +6,24 @@ #include <openssl/rand.h> -#include <vector> +#include <algorithm> #include "cast/common/certificate/cast_cert_validator.h" #include "cast/common/certificate/cast_cert_validator_internal.h" #include "cast/common/certificate/cast_crl.h" +#include "cast/common/channel/proto/cast_channel.pb.h" #include "platform/api/time.h" #include "platform/base/error.h" #include "util/logging.h" +namespace openscreen { namespace cast { -namespace channel { + +using ::cast::channel::AuthResponse; +using ::cast::channel::CastMessage; +using ::cast::channel::DeviceAuthMessage; +using ::cast::channel::HashAlgorithm; + namespace { #define PARSE_ERROR_PREFIX "Failed to parse auth message: " @@ -30,40 +37,35 @@ const int kNonceSizeInBytes = 16; // The number of hours after which a nonce is regenerated. long kNonceExpirationTimeInHours = 24; -using CastCertError = openscreen::Error::Code; - // Extracts an embedded DeviceAuthMessage payload from an auth challenge reply // message. -openscreen::Error ParseAuthMessage(const CastMessage& challenge_reply, - DeviceAuthMessage* auth_message) { - if (challenge_reply.payload_type() != CastMessage_PayloadType_BINARY) { - return openscreen::Error(CastCertError::kCastV2WrongPayloadType, - PARSE_ERROR_PREFIX - "Wrong payload type in challenge reply"); +Error ParseAuthMessage(const CastMessage& challenge_reply, + DeviceAuthMessage* auth_message) { + if (challenge_reply.payload_type() != + ::cast::channel::CastMessage_PayloadType_BINARY) { + return Error(Error::Code::kCastV2WrongPayloadType, + PARSE_ERROR_PREFIX "Wrong payload type in challenge reply"); } if (!challenge_reply.has_payload_binary()) { - return openscreen::Error( - CastCertError::kCastV2NoPayload, PARSE_ERROR_PREFIX - "Payload type is binary but payload_binary field not set"); + return Error(Error::Code::kCastV2NoPayload, PARSE_ERROR_PREFIX + "Payload type is binary but payload_binary field not set"); } if (!auth_message->ParseFromString(challenge_reply.payload_binary())) { - return openscreen::Error( - CastCertError::kCastV2PayloadParsingFailed, PARSE_ERROR_PREFIX - "Cannot parse binary payload into DeviceAuthMessage"); + return Error(Error::Code::kCastV2PayloadParsingFailed, PARSE_ERROR_PREFIX + "Cannot parse binary payload into DeviceAuthMessage"); } if (auth_message->has_error()) { std::stringstream ss; ss << PARSE_ERROR_PREFIX "Auth message error: " << auth_message->error().error_type(); - return openscreen::Error(CastCertError::kCastV2MessageError, ss.str()); + return Error(Error::Code::kCastV2MessageError, ss.str()); } if (!auth_message->has_response()) { - return openscreen::Error(CastCertError::kCastV2NoResponse, - PARSE_ERROR_PREFIX - "Auth message has no response field"); + return Error(Error::Code::kCastV2NoResponse, + PARSE_ERROR_PREFIX "Auth message has no response field"); } - return openscreen::Error::None(); + return Error::None(); } class CastNonce { @@ -84,11 +86,11 @@ class CastNonce { OSP_CHECK_EQ( RAND_bytes(reinterpret_cast<uint8_t*>(&nonce_[0]), kNonceSizeInBytes), 1); - nonce_generation_time_ = openscreen::platform::GetWallTimeSinceUnixEpoch(); + nonce_generation_time_ = GetWallTimeSinceUnixEpoch(); } void EnsureNonceTimely() { - if (openscreen::platform::GetWallTimeSinceUnixEpoch() > + if (GetWallTimeSinceUnixEpoch() > (nonce_generation_time_ + std::chrono::hours(kNonceExpirationTimeInHours))) { GenerateNonce(); @@ -101,62 +103,60 @@ class CastNonce { std::chrono::seconds nonce_generation_time_; }; -// Maps CastCertError from certificate verification to openscreen::Error. +// Maps Error::Code from certificate verification to Error. // If crl_required is set to false, all revocation related errors are ignored. -openscreen::Error MapToOpenscreenError(CastCertError error, bool crl_required) { +Error MapToOpenscreenError(Error::Code error, bool crl_required) { switch (error) { - case CastCertError::kErrCertsMissing: - return openscreen::Error(CastCertError::kCastV2PeerCertEmpty, - "Failed to locate certificates."); - case CastCertError::kErrCertsParse: - return openscreen::Error(CastCertError::kErrCertsParse, - "Failed to parse certificates."); - case CastCertError::kErrCertsDateInvalid: - return openscreen::Error(CastCertError::kCastV2CertNotSignedByTrustedCa, - "Failed date validity check."); - case CastCertError::kErrCertsVerifyGeneric: - return openscreen::Error( - CastCertError::kCastV2CertNotSignedByTrustedCa, - "Failed with a generic certificate verification error."); - case CastCertError::kErrCertsRestrictions: - return openscreen::Error(CastCertError::kCastV2CertNotSignedByTrustedCa, - "Failed certificate restrictions."); - case CastCertError::kErrCrlInvalid: + case Error::Code::kErrCertsMissing: + return Error(Error::Code::kCastV2PeerCertEmpty, + "Failed to locate certificates."); + case Error::Code::kErrCertsParse: + return Error(Error::Code::kErrCertsParse, + "Failed to parse certificates."); + case Error::Code::kErrCertsDateInvalid: + return Error(Error::Code::kCastV2CertNotSignedByTrustedCa, + "Failed date validity check."); + case Error::Code::kErrCertsVerifyGeneric: + return Error(Error::Code::kCastV2CertNotSignedByTrustedCa, + "Failed with a generic certificate verification error."); + case Error::Code::kErrCertsRestrictions: + return Error(Error::Code::kCastV2CertNotSignedByTrustedCa, + "Failed certificate restrictions."); + case Error::Code::kErrCrlInvalid: // This error is only encountered if |crl_required| is true. OSP_DCHECK(crl_required); - return openscreen::Error(CastCertError::kErrCrlInvalid, - "Failed to provide a valid CRL."); - case CastCertError::kErrCertsRevoked: - return openscreen::Error(CastCertError::kErrCertsRevoked, - "Failed certificate revocation check."); - case CastCertError::kNone: - return openscreen::Error::None(); + return Error(Error::Code::kErrCrlInvalid, + "Failed to provide a valid CRL."); + case Error::Code::kErrCertsRevoked: + return Error(Error::Code::kErrCertsRevoked, + "Failed certificate revocation check."); + case Error::Code::kNone: + return Error::None(); default: - return openscreen::Error(CastCertError::kCastV2CertNotSignedByTrustedCa, - "Failed verifying cast device certificate."); + return Error(Error::Code::kCastV2CertNotSignedByTrustedCa, + "Failed verifying cast device certificate."); } - return openscreen::Error::None(); + return Error::None(); } -openscreen::Error VerifyAndMapDigestAlgorithm( - HashAlgorithm response_digest_algorithm, - certificate::DigestAlgorithm* digest_algorithm, - bool enforce_sha256_checking) { +Error VerifyAndMapDigestAlgorithm(HashAlgorithm response_digest_algorithm, + DigestAlgorithm* digest_algorithm, + bool enforce_sha256_checking) { switch (response_digest_algorithm) { - case SHA1: + case ::cast::channel::SHA1: if (enforce_sha256_checking) { - return openscreen::Error(CastCertError::kCastV2DigestUnsupported, - "Unsupported digest algorithm."); + return Error(Error::Code::kCastV2DigestUnsupported, + "Unsupported digest algorithm."); } - *digest_algorithm = certificate::DigestAlgorithm::kSha1; + *digest_algorithm = DigestAlgorithm::kSha1; break; - case SHA256: - *digest_algorithm = certificate::DigestAlgorithm::kSha256; + case ::cast::channel::SHA256: + *digest_algorithm = DigestAlgorithm::kSha256; break; default: - return CastCertError::kCastV2DigestUnsupported; + return Error::Code::kCastV2DigestUnsupported; } - return openscreen::Error::None(); + return Error::None(); } } // namespace @@ -170,70 +170,80 @@ AuthContext::AuthContext(const std::string& nonce) : nonce_(nonce) {} AuthContext::~AuthContext() {} -openscreen::Error AuthContext::VerifySenderNonce( - const std::string& nonce_response, - bool enforce_nonce_checking) const { +Error AuthContext::VerifySenderNonce(const std::string& nonce_response, + bool enforce_nonce_checking) const { if (nonce_ != nonce_response) { if (enforce_nonce_checking) { - return openscreen::Error(CastCertError::kCastV2SenderNonceMismatch, - "Sender nonce mismatched."); + return Error(Error::Code::kCastV2SenderNonceMismatch, + "Sender nonce mismatched."); } } - return openscreen::Error::None(); + return Error::None(); } -openscreen::Error VerifyTLSCertificateValidity( - X509* peer_cert, - std::chrono::seconds verification_time) { +Error VerifyTLSCertificateValidity(X509* peer_cert, + std::chrono::seconds verification_time) { // Ensure the peer cert is valid and doesn't have an excessive remaining // lifetime. Although it is not verified as an X.509 certificate, the entire // structure is signed by the AuthResponse, so the validity field from X.509 // is repurposed as this signature's expiration. - certificate::DateTime not_before; - certificate::DateTime not_after; - if (!certificate::GetCertValidTimeRange(peer_cert, ¬_before, ¬_after)) { - return openscreen::Error(CastCertError::kErrCertsParse, PARSE_ERROR_PREFIX - "Parsing validity fields failed."); + DateTime not_before; + DateTime not_after; + if (!GetCertValidTimeRange(peer_cert, ¬_before, ¬_after)) { + return Error(Error::Code::kErrCertsParse, + PARSE_ERROR_PREFIX "Parsing validity fields failed."); } std::chrono::seconds lifetime_limit = verification_time + std::chrono::hours(24 * kMaxSelfSignedCertLifetimeInDays); - certificate::DateTime verification_time_exploded = {}; - certificate::DateTime lifetime_limit_exploded = {}; - OSP_CHECK(certificate::DateTimeFromSeconds(verification_time.count(), - &verification_time_exploded)); - OSP_CHECK(certificate::DateTimeFromSeconds(lifetime_limit.count(), - &lifetime_limit_exploded)); + DateTime verification_time_exploded = {}; + DateTime lifetime_limit_exploded = {}; + OSP_CHECK(DateTimeFromSeconds(verification_time.count(), + &verification_time_exploded)); + OSP_CHECK( + DateTimeFromSeconds(lifetime_limit.count(), &lifetime_limit_exploded)); if (verification_time_exploded < not_before) { - return openscreen::Error( - CastCertError::kCastV2TlsCertValidStartDateInFuture, - PARSE_ERROR_PREFIX "Certificate's valid start date is in the future."); + return Error(Error::Code::kCastV2TlsCertValidStartDateInFuture, + PARSE_ERROR_PREFIX + "Certificate's valid start date is in the future."); } if (not_after < verification_time_exploded) { - return openscreen::Error(CastCertError::kCastV2TlsCertExpired, - PARSE_ERROR_PREFIX "Certificate has expired."); + return Error(Error::Code::kCastV2TlsCertExpired, + PARSE_ERROR_PREFIX "Certificate has expired."); } if (lifetime_limit_exploded < not_after) { - return openscreen::Error(CastCertError::kCastV2TlsCertValidityPeriodTooLong, - PARSE_ERROR_PREFIX - "Peer cert lifetime is too long."); + return Error(Error::Code::kCastV2TlsCertValidityPeriodTooLong, + PARSE_ERROR_PREFIX "Peer cert lifetime is too long."); } - return openscreen::Error::None(); + return Error::None(); } -ErrorOr<CastDeviceCertPolicy> AuthenticateChallengeReply( +ErrorOr<CastDeviceCertPolicy> VerifyCredentialsImpl( + const AuthResponse& response, + const std::vector<uint8_t>& signature_input, + const CRLPolicy& crl_policy, + TrustStore* cast_trust_store, + TrustStore* crl_trust_store, + const DateTime& verification_time, + bool enforce_sha256_checking); + +ErrorOr<CastDeviceCertPolicy> AuthenticateChallengeReplyImpl( const CastMessage& challenge_reply, X509* peer_cert, - const AuthContext& auth_context) { + const AuthContext& auth_context, + const CRLPolicy& crl_policy, + TrustStore* cast_trust_store, + TrustStore* crl_trust_store, + const DateTime& verification_time) { DeviceAuthMessage auth_message; - openscreen::Error result = ParseAuthMessage(challenge_reply, &auth_message); + Error result = ParseAuthMessage(challenge_reply, &auth_message); if (!result.ok()) { return result; } - result = VerifyTLSCertificateValidity( - peer_cert, openscreen::platform::GetWallTimeSinceUnixEpoch()); + result = VerifyTLSCertificateValidity(peer_cert, + DateTimeToSeconds(verification_time)); if (!result.ok()) { return result; } @@ -246,22 +256,48 @@ ErrorOr<CastDeviceCertPolicy> AuthenticateChallengeReply( return result; } - int len = i2d_X509(peer_cert, nullptr); - if (len <= 0) { - return openscreen::Error(CastCertError::kErrCertsParse, - "Serializing cert failed."); + int cert_len = i2d_X509(peer_cert, nullptr); + if (cert_len <= 0) { + return Error(Error::Code::kErrCertsParse, "Serializing cert failed."); } - std::string peer_cert_der(len, 0); - uint8_t* data = reinterpret_cast<uint8_t*>(&peer_cert_der[0]); + size_t nonce_response_size = nonce_response.size(); + std::vector<uint8_t> nonce_plus_peer_cert_der(nonce_response_size + cert_len, + 0); + std::copy(nonce_response.begin(), nonce_response.end(), + &nonce_plus_peer_cert_der[0]); + uint8_t* data = &nonce_plus_peer_cert_der[nonce_response_size]; if (!i2d_X509(peer_cert, &data)) { - return openscreen::Error(CastCertError::kErrCertsParse, - "Serializing cert failed."); + return Error(Error::Code::kErrCertsParse, "Serializing cert failed."); } - size_t actual_size = data - reinterpret_cast<uint8_t*>(&peer_cert_der[0]); - OSP_DCHECK_EQ(actual_size, peer_cert_der.size()); - peer_cert_der.resize(actual_size); - return VerifyCredentials(response, nonce_response + peer_cert_der); + return VerifyCredentialsImpl(response, nonce_plus_peer_cert_der, crl_policy, + cast_trust_store, crl_trust_store, + verification_time, false); +} + +ErrorOr<CastDeviceCertPolicy> AuthenticateChallengeReply( + const CastMessage& challenge_reply, + X509* peer_cert, + const AuthContext& auth_context) { + DateTime now = {}; + OSP_CHECK(DateTimeFromSeconds(GetWallTimeSinceUnixEpoch().count(), &now)); + CRLPolicy policy = CRLPolicy::kCrlOptional; + return AuthenticateChallengeReplyImpl( + challenge_reply, peer_cert, auth_context, policy, + /* cast_trust_store */ nullptr, /* crl_trust_store */ nullptr, now); +} + +ErrorOr<CastDeviceCertPolicy> AuthenticateChallengeReplyForTest( + const CastMessage& challenge_reply, + X509* peer_cert, + const AuthContext& auth_context, + CRLPolicy crl_policy, + TrustStore* cast_trust_store, + TrustStore* crl_trust_store, + const DateTime& verification_time) { + return AuthenticateChallengeReplyImpl( + challenge_reply, peer_cert, auth_context, crl_policy, cast_trust_store, + crl_trust_store, verification_time); } // This function does the following @@ -283,19 +319,18 @@ ErrorOr<CastDeviceCertPolicy> AuthenticateChallengeReply( // |signature_input| by |response.client_auth_certificate|'s public key. ErrorOr<CastDeviceCertPolicy> VerifyCredentialsImpl( const AuthResponse& response, - const std::string& signature_input, - const certificate::CRLPolicy& crl_policy, - certificate::TrustStore* cast_trust_store, - certificate::TrustStore* crl_trust_store, - const certificate::DateTime& verification_time, + const std::vector<uint8_t>& signature_input, + const CRLPolicy& crl_policy, + TrustStore* cast_trust_store, + TrustStore* crl_trust_store, + const DateTime& verification_time, bool enforce_sha256_checking) { if (response.signature().empty() && !signature_input.empty()) { - return openscreen::Error(CastCertError::kCastV2SignatureEmpty, - "Signature is empty."); + return Error(Error::Code::kCastV2SignatureEmpty, "Signature is empty."); } // Verify the certificate - std::unique_ptr<certificate::CertVerificationContext> verification_context; + std::unique_ptr<CertVerificationContext> verification_context; // Build a single vector containing the certificate chain. std::vector<std::string> cert_chain; @@ -305,43 +340,41 @@ ErrorOr<CastDeviceCertPolicy> VerifyCredentialsImpl( response.intermediate_certificate().end()); // Parse the CRL. - std::unique_ptr<certificate::CastCRL> crl; + std::unique_ptr<CastCRL> crl; if (!response.crl().empty()) { - crl = certificate::ParseAndVerifyCRL(response.crl(), verification_time, - crl_trust_store); + crl = ParseAndVerifyCRL(response.crl(), verification_time, crl_trust_store); } // Perform certificate verification. - certificate::CastDeviceCertPolicy device_policy; - openscreen::Error verify_result = certificate::VerifyDeviceCert( - cert_chain, verification_time, &verification_context, &device_policy, - crl.get(), crl_policy, cast_trust_store); + CastDeviceCertPolicy device_policy; + Error verify_result = + VerifyDeviceCert(cert_chain, verification_time, &verification_context, + &device_policy, crl.get(), crl_policy, cast_trust_store); // Handle and report errors. - openscreen::Error result = MapToOpenscreenError( - verify_result.code(), crl_policy == certificate::CRLPolicy::kCrlRequired); + Error result = MapToOpenscreenError(verify_result.code(), + crl_policy == CRLPolicy::kCrlRequired); if (!result.ok()) { return result; } // The certificate is verified at this point. - certificate::DigestAlgorithm digest_algorithm; - openscreen::Error digest_result = VerifyAndMapDigestAlgorithm( + DigestAlgorithm digest_algorithm; + Error digest_result = VerifyAndMapDigestAlgorithm( response.hash_algorithm(), &digest_algorithm, enforce_sha256_checking); if (!digest_result.ok()) { return digest_result; } - certificate::ConstDataSpan signature = { + ConstDataSpan signature = { reinterpret_cast<const uint8_t*>(response.signature().data()), static_cast<uint32_t>(response.signature().size())}; - certificate::ConstDataSpan siginput = { - reinterpret_cast<const uint8_t*>(signature_input.data()), - static_cast<uint32_t>(signature_input.size())}; + ConstDataSpan siginput = {signature_input.data(), + static_cast<uint32_t>(signature_input.size())}; if (!verification_context->VerifySignatureOverData(signature, siginput, digest_algorithm)) { - return openscreen::Error(CastCertError::kCastV2SignedBlobsMismatch, - "Failed verifying signature over data."); + return Error(Error::Code::kCastV2SignedBlobsMismatch, + "Failed verifying signature over data."); } return device_policy; @@ -349,31 +382,29 @@ ErrorOr<CastDeviceCertPolicy> VerifyCredentialsImpl( ErrorOr<CastDeviceCertPolicy> VerifyCredentials( const AuthResponse& response, - const std::string& signature_input, + const std::vector<uint8_t>& signature_input, bool enforce_revocation_checking, bool enforce_sha256_checking) { - certificate::DateTime now = {}; - OSP_CHECK(certificate::DateTimeFromSeconds( - openscreen::platform::GetWallTimeSinceUnixEpoch().count(), &now)); - certificate::CRLPolicy policy = (enforce_revocation_checking) - ? certificate::CRLPolicy::kCrlRequired - : certificate::CRLPolicy::kCrlOptional; + DateTime now = {}; + OSP_CHECK(DateTimeFromSeconds(GetWallTimeSinceUnixEpoch().count(), &now)); + CRLPolicy policy = (enforce_revocation_checking) ? CRLPolicy::kCrlRequired + : CRLPolicy::kCrlOptional; return VerifyCredentialsImpl(response, signature_input, policy, nullptr, nullptr, now, enforce_sha256_checking); } ErrorOr<CastDeviceCertPolicy> VerifyCredentialsForTest( const AuthResponse& response, - const std::string& signature_input, - certificate::CRLPolicy crl_policy, - certificate::TrustStore* cast_trust_store, - certificate::TrustStore* crl_trust_store, - const certificate::DateTime& verification_time, + const std::vector<uint8_t>& signature_input, + CRLPolicy crl_policy, + TrustStore* cast_trust_store, + TrustStore* crl_trust_store, + const DateTime& verification_time, bool enforce_sha256_checking) { return VerifyCredentialsImpl(response, signature_input, crl_policy, cast_trust_store, crl_trust_store, verification_time, enforce_sha256_checking); } -} // namespace channel } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.h b/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.h index 35f1d028f5f..4ca6df5e17c 100644 --- a/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.h +++ b/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.h @@ -7,28 +7,26 @@ #include <openssl/x509.h> +#include <chrono> // NOLINT #include <string> +#include <vector> #include "cast/common/certificate/cast_cert_validator.h" -#include "cast/common/channel/proto/cast_channel.pb.h" #include "platform/base/error.h" namespace cast { -namespace certificate { -enum class CRLPolicy; -struct DateTime; -struct TrustStore; -} // namespace certificate -} // namespace cast - -namespace cast { namespace channel { - class AuthResponse; class CastMessage; +} // namespace channel +} // namespace cast + +namespace openscreen { +namespace cast { -using openscreen::ErrorOr; -using CastDeviceCertPolicy = certificate::CastDeviceCertPolicy; +enum class CRLPolicy; +struct DateTime; +struct TrustStore; class AuthContext { public: @@ -40,9 +38,8 @@ class AuthContext { // Verifies the nonce received in the response is equivalent to the one sent. // Returns success if |nonce_response| matches nonce_ - openscreen::Error VerifySenderNonce( - const std::string& nonce_response, - bool enforce_nonce_checking = false) const; + Error VerifySenderNonce(const std::string& nonce_response, + bool enforce_nonce_checking = false) const; // The nonce challenge. const std::string& nonce() const { return nonce_; } @@ -55,40 +52,52 @@ class AuthContext { // Authenticates the given |challenge_reply|: // 1. Signature contained in the reply is valid. -// 2. Certficate used to sign is rooted to a trusted CA. +// 2. certificate used to sign is rooted to a trusted CA. ErrorOr<CastDeviceCertPolicy> AuthenticateChallengeReply( - const CastMessage& challenge_reply, + const ::cast::channel::CastMessage& challenge_reply, X509* peer_cert, const AuthContext& auth_context); -// Performs a quick check of the TLS certificate for time validity requirements. -openscreen::Error VerifyTLSCertificateValidity( +// Exposed for testing only. +// +// Overloaded version of AuthenticateChallengeReply that allows modifying the +// crl policy, trust stores, and verification times. +ErrorOr<CastDeviceCertPolicy> AuthenticateChallengeReplyForTest( + const ::cast::channel::CastMessage& challenge_reply, X509* peer_cert, - std::chrono::seconds verification_time); + const AuthContext& auth_context, + CRLPolicy crl_policy, + TrustStore* cast_trust_store, + TrustStore* crl_trust_store, + const DateTime& verification_time); + +// Performs a quick check of the TLS certificate for time validity requirements. +Error VerifyTLSCertificateValidity(X509* peer_cert, + std::chrono::seconds verification_time); // Auth-library specific implementation of cryptographic signature verification // routines. Verifies that |response| contains a valid signature of // |signature_input|. ErrorOr<CastDeviceCertPolicy> VerifyCredentials( - const AuthResponse& response, - const std::string& signature_input, + const ::cast::channel::AuthResponse& response, + const std::vector<uint8_t>& signature_input, bool enforce_revocation_checking = false, bool enforce_sha256_checking = false); // Exposed for testing only. // -// Overloaded version of VerifyCredentials that allows modifying -// the crl policy, trust stores, and verification times. +// Overloaded version of VerifyCredentials that allows modifying the crl policy, +// trust stores, and verification times. ErrorOr<CastDeviceCertPolicy> VerifyCredentialsForTest( - const AuthResponse& response, - const std::string& signature_input, - certificate::CRLPolicy crl_policy, - certificate::TrustStore* cast_trust_store, - certificate::TrustStore* crl_trust_store, - const certificate::DateTime& verification_time, + const ::cast::channel::AuthResponse& response, + const std::vector<uint8_t>& signature_input, + CRLPolicy crl_policy, + TrustStore* cast_trust_store, + TrustStore* crl_trust_store, + const DateTime& verification_time, bool enforce_sha256_checking = false); -} // namespace channel } // namespace cast +} // namespace openscreen #endif // CAST_SENDER_CHANNEL_CAST_AUTH_UTIL_H_ diff --git a/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util_unittest.cc b/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util_unittest.cc index 6ea4ea6a141..0b76c18ba3d 100644 --- a/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util_unittest.cc @@ -9,19 +9,27 @@ #include "cast/common/certificate/cast_cert_validator.h" #include "cast/common/certificate/cast_crl.h" #include "cast/common/certificate/proto/test_suite.pb.h" -#include "cast/common/certificate/test_helpers.h" +#include "cast/common/certificate/testing/test_helpers.h" #include "cast/common/channel/proto/cast_channel.pb.h" #include "gtest/gtest.h" #include "platform/api/time.h" +#include "testing/util/read_file.h" #include "util/logging.h" +namespace openscreen { namespace cast { -namespace channel { + +// TODO(crbug.com/openscreen/90): Remove these after Chromium is migrated to +// openscreen::cast +using DeviceCertTestSuite = ::cast::certificate::DeviceCertTestSuite; +using VerificationResult = ::cast::certificate::VerificationResult; +using DeviceCertTest = ::cast::certificate::DeviceCertTest; + namespace { -using ErrorCode = openscreen::Error::Code; +using ::cast::channel::AuthResponse; -bool ConvertTimeSeconds(const certificate::DateTime& time, uint64_t* seconds) { +bool ConvertTimeSeconds(const DateTime& time, uint64_t* seconds) { static constexpr uint64_t kDaysPerYear = 365; static constexpr uint64_t kHoursPerDay = 24; static constexpr uint64_t kMinutesPerHour = 60; @@ -101,7 +109,7 @@ bool ConvertTimeSeconds(const certificate::DateTime& time, uint64_t* seconds) { #define TEST_DATA_PREFIX OPENSCREEN_TEST_DATA_DIR "cast/common/certificate/" -class CastAuthUtilTest : public testing::Test { +class CastAuthUtilTest : public ::testing::Test { public: CastAuthUtilTest() {} ~CastAuthUtilTest() override {} @@ -109,141 +117,151 @@ class CastAuthUtilTest : public testing::Test { void SetUp() override {} protected: - static AuthResponse CreateAuthResponse(std::string* signed_data, - HashAlgorithm digest_algorithm) { - std::vector<std::string> chain = - certificate::testing::ReadCertificatesFromPemFile( - TEST_DATA_PREFIX "certificates/chromecast_gen1.pem"); + static AuthResponse CreateAuthResponse( + std::vector<uint8_t>* signed_data, + ::cast::channel::HashAlgorithm digest_algorithm) { + std::vector<std::string> chain = testing::ReadCertificatesFromPemFile( + TEST_DATA_PREFIX "certificates/chromecast_gen1.pem"); OSP_CHECK(!chain.empty()); - certificate::testing::SignatureTestData signatures = - certificate::testing::ReadSignatureTestData( - TEST_DATA_PREFIX "signeddata/2ZZBG9_FA8FCA3EF91A.pem"); + testing::SignatureTestData signatures = testing::ReadSignatureTestData( + TEST_DATA_PREFIX "signeddata/2ZZBG9_FA8FCA3EF91A.pem"); AuthResponse response; response.set_client_auth_certificate(chain[0]); - for (size_t i = 1; i < chain.size(); ++i) + for (size_t i = 1; i < chain.size(); ++i) { response.add_intermediate_certificate(chain[i]); + } response.set_hash_algorithm(digest_algorithm); switch (digest_algorithm) { - case SHA1: + case ::cast::channel::SHA1: response.set_signature( std::string(reinterpret_cast<const char*>(signatures.sha1.data), signatures.sha1.length)); break; - case SHA256: + case ::cast::channel::SHA256: response.set_signature( std::string(reinterpret_cast<const char*>(signatures.sha256.data), signatures.sha256.length)); break; } - signed_data->assign(reinterpret_cast<const char*>(signatures.message.data), - signatures.message.length); + *signed_data = std::vector<uint8_t>( + signatures.message.data, + signatures.message.data + signatures.message.length); return response; } // Mangles a string by inverting the first byte. static void MangleString(std::string* str) { (*str)[0] = ~(*str)[0]; } + + // Mangles a vector by inverting the first byte. + static void MangleData(std::vector<uint8_t>* data) { + (*data)[0] = ~(*data)[0]; + } }; // Note on expiration: VerifyCredentials() depends on the system clock. In // practice this shouldn't be a problem though since the certificate chain // being verified doesn't expire until 2032. TEST_F(CastAuthUtilTest, VerifySuccess) { - std::string signed_data; - AuthResponse auth_response = CreateAuthResponse(&signed_data, SHA256); - certificate::DateTime now = {}; - ASSERT_TRUE(certificate::DateTimeFromSeconds( - openscreen::platform::GetWallTimeSinceUnixEpoch().count(), &now)); - ErrorOr<CastDeviceCertPolicy> result = VerifyCredentialsForTest( - auth_response, signed_data, certificate::CRLPolicy::kCrlOptional, nullptr, - nullptr, now); + std::vector<uint8_t> signed_data; + AuthResponse auth_response = + CreateAuthResponse(&signed_data, ::cast::channel::SHA256); + DateTime now = {}; + ASSERT_TRUE(DateTimeFromSeconds(GetWallTimeSinceUnixEpoch().count(), &now)); + ErrorOr<CastDeviceCertPolicy> result = + VerifyCredentialsForTest(auth_response, signed_data, + CRLPolicy::kCrlOptional, nullptr, nullptr, now); EXPECT_TRUE(result); - EXPECT_EQ(certificate::CastDeviceCertPolicy::kUnrestricted, result.value()); + EXPECT_EQ(CastDeviceCertPolicy::kUnrestricted, result.value()); } TEST_F(CastAuthUtilTest, VerifyBadCA) { - std::string signed_data; - AuthResponse auth_response = CreateAuthResponse(&signed_data, SHA256); + std::vector<uint8_t> signed_data; + AuthResponse auth_response = + CreateAuthResponse(&signed_data, ::cast::channel::SHA256); MangleString(auth_response.mutable_intermediate_certificate(0)); ErrorOr<CastDeviceCertPolicy> result = VerifyCredentials(auth_response, signed_data); EXPECT_FALSE(result); - EXPECT_EQ(ErrorCode::kErrCertsParse, result.error().code()); + EXPECT_EQ(Error::Code::kErrCertsParse, result.error().code()); } TEST_F(CastAuthUtilTest, VerifyBadClientAuthCert) { - std::string signed_data; - AuthResponse auth_response = CreateAuthResponse(&signed_data, SHA256); + std::vector<uint8_t> signed_data; + AuthResponse auth_response = + CreateAuthResponse(&signed_data, ::cast::channel::SHA256); MangleString(auth_response.mutable_client_auth_certificate()); ErrorOr<CastDeviceCertPolicy> result = VerifyCredentials(auth_response, signed_data); EXPECT_FALSE(result); - EXPECT_EQ(ErrorCode::kErrCertsParse, result.error().code()); + EXPECT_EQ(Error::Code::kErrCertsParse, result.error().code()); } TEST_F(CastAuthUtilTest, VerifyBadSignature) { - std::string signed_data; - AuthResponse auth_response = CreateAuthResponse(&signed_data, SHA256); + std::vector<uint8_t> signed_data; + AuthResponse auth_response = + CreateAuthResponse(&signed_data, ::cast::channel::SHA256); MangleString(auth_response.mutable_signature()); ErrorOr<CastDeviceCertPolicy> result = VerifyCredentials(auth_response, signed_data); EXPECT_FALSE(result); - EXPECT_EQ(ErrorCode::kCastV2SignedBlobsMismatch, result.error().code()); + EXPECT_EQ(Error::Code::kCastV2SignedBlobsMismatch, result.error().code()); } TEST_F(CastAuthUtilTest, VerifyEmptySignature) { - std::string signed_data; - AuthResponse auth_response = CreateAuthResponse(&signed_data, SHA256); + std::vector<uint8_t> signed_data; + AuthResponse auth_response = + CreateAuthResponse(&signed_data, ::cast::channel::SHA256); auth_response.mutable_signature()->clear(); ErrorOr<CastDeviceCertPolicy> result = VerifyCredentials(auth_response, signed_data); EXPECT_FALSE(result); - EXPECT_EQ(ErrorCode::kCastV2SignatureEmpty, result.error().code()); + EXPECT_EQ(Error::Code::kCastV2SignatureEmpty, result.error().code()); } TEST_F(CastAuthUtilTest, VerifyUnsupportedDigest) { - std::string signed_data; - AuthResponse auth_response = CreateAuthResponse(&signed_data, SHA1); - certificate::DateTime now = {}; - ASSERT_TRUE(certificate::DateTimeFromSeconds( - openscreen::platform::GetWallTimeSinceUnixEpoch().count(), &now)); + std::vector<uint8_t> signed_data; + AuthResponse auth_response = + CreateAuthResponse(&signed_data, ::cast::channel::SHA1); + DateTime now = {}; + ASSERT_TRUE(DateTimeFromSeconds(GetWallTimeSinceUnixEpoch().count(), &now)); ErrorOr<CastDeviceCertPolicy> result = VerifyCredentialsForTest( - auth_response, signed_data, certificate::CRLPolicy::kCrlOptional, nullptr, - nullptr, now, true); + auth_response, signed_data, CRLPolicy::kCrlOptional, nullptr, nullptr, + now, true); EXPECT_FALSE(result); - EXPECT_EQ(ErrorCode::kCastV2DigestUnsupported, result.error().code()); + EXPECT_EQ(Error::Code::kCastV2DigestUnsupported, result.error().code()); } TEST_F(CastAuthUtilTest, VerifyBackwardsCompatibleDigest) { - std::string signed_data; - AuthResponse auth_response = CreateAuthResponse(&signed_data, SHA1); - certificate::DateTime now = {}; - ASSERT_TRUE(certificate::DateTimeFromSeconds( - openscreen::platform::GetWallTimeSinceUnixEpoch().count(), &now)); - ErrorOr<CastDeviceCertPolicy> result = VerifyCredentialsForTest( - auth_response, signed_data, certificate::CRLPolicy::kCrlOptional, nullptr, - nullptr, now); + std::vector<uint8_t> signed_data; + AuthResponse auth_response = + CreateAuthResponse(&signed_data, ::cast::channel::SHA1); + DateTime now = {}; + ASSERT_TRUE(DateTimeFromSeconds(GetWallTimeSinceUnixEpoch().count(), &now)); + ErrorOr<CastDeviceCertPolicy> result = + VerifyCredentialsForTest(auth_response, signed_data, + CRLPolicy::kCrlOptional, nullptr, nullptr, now); EXPECT_TRUE(result); } TEST_F(CastAuthUtilTest, VerifyBadPeerCert) { - std::string signed_data; - AuthResponse auth_response = CreateAuthResponse(&signed_data, SHA256); - MangleString(&signed_data); + std::vector<uint8_t> signed_data; + AuthResponse auth_response = + CreateAuthResponse(&signed_data, ::cast::channel::SHA256); + MangleData(&signed_data); ErrorOr<CastDeviceCertPolicy> result = VerifyCredentials(auth_response, signed_data); EXPECT_FALSE(result); - EXPECT_EQ(ErrorCode::kCastV2SignedBlobsMismatch, result.error().code()); + EXPECT_EQ(Error::Code::kCastV2SignedBlobsMismatch, result.error().code()); } TEST_F(CastAuthUtilTest, VerifySenderNonceMatch) { AuthContext context = AuthContext::Create(); - const openscreen::Error result = - context.VerifySenderNonce(context.nonce(), true); + const Error result = context.VerifySenderNonce(context.nonce(), true); EXPECT_TRUE(result.ok()); } @@ -254,7 +272,7 @@ TEST_F(CastAuthUtilTest, VerifySenderNonceMismatch) { ErrorOr<CastDeviceCertPolicy> result = context.VerifySenderNonce(received_nonce, true); EXPECT_FALSE(result); - EXPECT_EQ(ErrorCode::kCastV2SenderNonceMismatch, result.error().code()); + EXPECT_EQ(Error::Code::kCastV2SenderNonceMismatch, result.error().code()); } TEST_F(CastAuthUtilTest, VerifySenderNonceMissing) { @@ -264,40 +282,36 @@ TEST_F(CastAuthUtilTest, VerifySenderNonceMissing) { ErrorOr<CastDeviceCertPolicy> result = context.VerifySenderNonce(received_nonce, true); EXPECT_FALSE(result); - EXPECT_EQ(ErrorCode::kCastV2SenderNonceMismatch, result.error().code()); + EXPECT_EQ(Error::Code::kCastV2SenderNonceMismatch, result.error().code()); } TEST_F(CastAuthUtilTest, VerifyTLSCertificateSuccess) { - std::vector<std::string> tls_cert_der = - certificate::testing::ReadCertificatesFromPemFile( - TEST_DATA_PREFIX "certificates/test_tls_cert.pem"); + std::vector<std::string> tls_cert_der = testing::ReadCertificatesFromPemFile( + TEST_DATA_PREFIX "certificates/test_tls_cert.pem"); std::string& der_cert = tls_cert_der[0]; const uint8_t* data = (const uint8_t*)der_cert.data(); X509* tls_cert = d2i_X509(nullptr, &data, der_cert.size()); - certificate::DateTime not_before; - certificate::DateTime not_after; - ASSERT_TRUE( - certificate::GetCertValidTimeRange(tls_cert, ¬_before, ¬_after)); + DateTime not_before; + DateTime not_after; + ASSERT_TRUE(GetCertValidTimeRange(tls_cert, ¬_before, ¬_after)); uint64_t x; ASSERT_TRUE(ConvertTimeSeconds(not_before, &x)); std::chrono::seconds s(x); - const openscreen::Error result = VerifyTLSCertificateValidity(tls_cert, s); + const Error result = VerifyTLSCertificateValidity(tls_cert, s); EXPECT_TRUE(result.ok()); X509_free(tls_cert); } TEST_F(CastAuthUtilTest, VerifyTLSCertificateTooEarly) { - std::vector<std::string> tls_cert_der = - certificate::testing::ReadCertificatesFromPemFile( - TEST_DATA_PREFIX "certificates/test_tls_cert.pem"); + std::vector<std::string> tls_cert_der = testing::ReadCertificatesFromPemFile( + TEST_DATA_PREFIX "certificates/test_tls_cert.pem"); std::string& der_cert = tls_cert_der[0]; const uint8_t* data = (const uint8_t*)der_cert.data(); X509* tls_cert = d2i_X509(nullptr, &data, der_cert.size()); - certificate::DateTime not_before; - certificate::DateTime not_after; - ASSERT_TRUE( - certificate::GetCertValidTimeRange(tls_cert, ¬_before, ¬_after)); + DateTime not_before; + DateTime not_after; + ASSERT_TRUE(GetCertValidTimeRange(tls_cert, ¬_before, ¬_after)); uint64_t x; ASSERT_TRUE(ConvertTimeSeconds(not_before, &x)); std::chrono::seconds s(x - 1); @@ -305,22 +319,20 @@ TEST_F(CastAuthUtilTest, VerifyTLSCertificateTooEarly) { ErrorOr<CastDeviceCertPolicy> result = VerifyTLSCertificateValidity(tls_cert, s); EXPECT_FALSE(result); - EXPECT_EQ(ErrorCode::kCastV2TlsCertValidStartDateInFuture, + EXPECT_EQ(Error::Code::kCastV2TlsCertValidStartDateInFuture, result.error().code()); X509_free(tls_cert); } TEST_F(CastAuthUtilTest, VerifyTLSCertificateTooLate) { - std::vector<std::string> tls_cert_der = - certificate::testing::ReadCertificatesFromPemFile( - TEST_DATA_PREFIX "certificates/test_tls_cert.pem"); + std::vector<std::string> tls_cert_der = testing::ReadCertificatesFromPemFile( + TEST_DATA_PREFIX "certificates/test_tls_cert.pem"); std::string& der_cert = tls_cert_der[0]; const uint8_t* data = (const uint8_t*)der_cert.data(); X509* tls_cert = d2i_X509(nullptr, &data, der_cert.size()); - certificate::DateTime not_before; - certificate::DateTime not_after; - ASSERT_TRUE( - certificate::GetCertValidTimeRange(tls_cert, ¬_before, ¬_after)); + DateTime not_before; + DateTime not_after; + ASSERT_TRUE(GetCertValidTimeRange(tls_cert, ¬_before, ¬_after)); uint64_t x; ASSERT_TRUE(ConvertTimeSeconds(not_after, &x)); std::chrono::seconds s(x + 2); @@ -328,7 +340,7 @@ TEST_F(CastAuthUtilTest, VerifyTLSCertificateTooLate) { ErrorOr<CastDeviceCertPolicy> result = VerifyTLSCertificateValidity(tls_cert, s); EXPECT_FALSE(result); - EXPECT_EQ(ErrorCode::kCastV2TlsCertExpired, result.error().code()); + EXPECT_EQ(Error::Code::kCastV2TlsCertExpired, result.error().code()); X509_free(tls_cert); } @@ -346,39 +358,40 @@ enum TestStepResult { ErrorOr<CastDeviceCertPolicy> TestVerifyRevocation( const std::vector<std::string>& certificate_chain, const std::string& crl_bundle, - const certificate::DateTime& verification_time, + const DateTime& verification_time, bool crl_required, - certificate::TrustStore* cast_trust_store, - certificate::TrustStore* crl_trust_store) { + TrustStore* cast_trust_store, + TrustStore* crl_trust_store) { AuthResponse response; if (certificate_chain.size() > 0) { response.set_client_auth_certificate(certificate_chain[0]); - for (size_t i = 1; i < certificate_chain.size(); ++i) + for (size_t i = 1; i < certificate_chain.size(); ++i) { response.add_intermediate_certificate(certificate_chain[i]); + } } response.set_crl(crl_bundle); - certificate::CRLPolicy crl_policy = certificate::CRLPolicy::kCrlRequired; + CRLPolicy crl_policy = CRLPolicy::kCrlRequired; if (!crl_required && crl_bundle.empty()) - crl_policy = certificate::CRLPolicy::kCrlOptional; - ErrorOr<CastDeviceCertPolicy> result = - VerifyCredentialsForTest(response, "", crl_policy, cast_trust_store, - crl_trust_store, verification_time); + crl_policy = CRLPolicy::kCrlOptional; + ErrorOr<CastDeviceCertPolicy> result = VerifyCredentialsForTest( + response, std::vector<uint8_t>(), crl_policy, cast_trust_store, + crl_trust_store, verification_time); // This test doesn't set the signature so it will just fail there. EXPECT_FALSE(result); return result; } // Runs a single test case. -bool RunTest(const certificate::DeviceCertTest& test_case) { - std::unique_ptr<certificate::TrustStore> crl_trust_store; - std::unique_ptr<certificate::TrustStore> cast_trust_store; +bool RunTest(const DeviceCertTest& test_case) { + std::unique_ptr<TrustStore> crl_trust_store; + std::unique_ptr<TrustStore> cast_trust_store; if (test_case.use_test_trust_anchors()) { - crl_trust_store = certificate::testing::CreateTrustStoreFromPemFile( + crl_trust_store = testing::CreateTrustStoreFromPemFile( TEST_DATA_PREFIX "certificates/cast_crl_test_root_ca.pem"); - cast_trust_store = certificate::testing::CreateTrustStoreFromPemFile( + cast_trust_store = testing::CreateTrustStoreFromPemFile( TEST_DATA_PREFIX "certificates/cast_test_root_ca.pem"); EXPECT_FALSE(crl_trust_store->certs.empty()); @@ -391,51 +404,49 @@ bool RunTest(const certificate::DeviceCertTest& test_case) { } // CastAuthUtil verifies the CRL at the same time as the certificate. - certificate::DateTime verification_time; + DateTime verification_time; uint64_t cert_verify_time = test_case.cert_verification_time_seconds(); if (!cert_verify_time) { cert_verify_time = test_case.crl_verification_time_seconds(); } - OSP_DCHECK( - certificate::DateTimeFromSeconds(cert_verify_time, &verification_time)); + OSP_DCHECK(DateTimeFromSeconds(cert_verify_time, &verification_time)); std::string crl_bundle = test_case.crl_bundle(); - ErrorOr<CastDeviceCertPolicy> result( - certificate::CastDeviceCertPolicy::kUnrestricted); + ErrorOr<CastDeviceCertPolicy> result(CastDeviceCertPolicy::kUnrestricted); switch (test_case.expected_result()) { - case certificate::PATH_VERIFICATION_FAILED: + case ::cast::certificate::PATH_VERIFICATION_FAILED: result = TestVerifyRevocation( certificate_chain, crl_bundle, verification_time, false, cast_trust_store.get(), crl_trust_store.get()); EXPECT_EQ(result.error().code(), - ErrorCode::kCastV2CertNotSignedByTrustedCa); + Error::Code::kCastV2CertNotSignedByTrustedCa); return result.error().code() == - ErrorCode::kCastV2CertNotSignedByTrustedCa; - case certificate::CRL_VERIFICATION_FAILED: + Error::Code::kCastV2CertNotSignedByTrustedCa; + case ::cast::certificate::CRL_VERIFICATION_FAILED: // Fall-through intended. - case certificate::REVOCATION_CHECK_FAILED_WITHOUT_CRL: + case ::cast::certificate::REVOCATION_CHECK_FAILED_WITHOUT_CRL: result = TestVerifyRevocation( certificate_chain, crl_bundle, verification_time, true, cast_trust_store.get(), crl_trust_store.get()); - EXPECT_EQ(result.error().code(), ErrorCode::kErrCrlInvalid); - return result.error().code() == ErrorCode::kErrCrlInvalid; - case certificate::CRL_EXPIRED_AFTER_INITIAL_VERIFICATION: + EXPECT_EQ(result.error().code(), Error::Code::kErrCrlInvalid); + return result.error().code() == Error::Code::kErrCrlInvalid; + case ::cast::certificate::CRL_EXPIRED_AFTER_INITIAL_VERIFICATION: // By-pass this test because CRL is always verified at the time the // certificate is verified. return true; - case certificate::REVOCATION_CHECK_FAILED: + case ::cast::certificate::REVOCATION_CHECK_FAILED: result = TestVerifyRevocation( certificate_chain, crl_bundle, verification_time, true, cast_trust_store.get(), crl_trust_store.get()); - EXPECT_EQ(result.error().code(), ErrorCode::kErrCertsRevoked); - return result.error().code() == ErrorCode::kErrCertsRevoked; - case certificate::SUCCESS: + EXPECT_EQ(result.error().code(), Error::Code::kErrCertsRevoked); + return result.error().code() == Error::Code::kErrCertsRevoked; + case ::cast::certificate::SUCCESS: result = TestVerifyRevocation( certificate_chain, crl_bundle, verification_time, false, cast_trust_store.get(), crl_trust_store.get()); - EXPECT_EQ(result.error().code(), ErrorCode::kCastV2SignedBlobsMismatch); - return result.error().code() == ErrorCode::kCastV2SignedBlobsMismatch; - case certificate::UNSPECIFIED: + EXPECT_EQ(result.error().code(), Error::Code::kCastV2SignedBlobsMismatch); + return result.error().code() == Error::Code::kCastV2SignedBlobsMismatch; + case ::cast::certificate::UNSPECIFIED: return false; } return false; @@ -446,9 +457,8 @@ bool RunTest(const certificate::DeviceCertTest& test_case) { // To see the description of the test, execute the test. // These tests are generated by a test generator in google3. void RunTestSuite(const std::string& test_suite_file_name) { - std::string testsuite_raw = - certificate::testing::ReadEntireFileToString(test_suite_file_name); - certificate::DeviceCertTestSuite test_suite; + std::string testsuite_raw = ReadEntireFileToString(test_suite_file_name); + DeviceCertTestSuite test_suite; EXPECT_TRUE(test_suite.ParseFromString(testsuite_raw)); uint16_t successes = 0; @@ -467,5 +477,5 @@ TEST_F(CastAuthUtilTest, CRLTestSuite) { } } // namespace -} // namespace channel } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/sender/channel/message_util.cc b/chromium/third_party/openscreen/src/cast/sender/channel/message_util.cc index ab3ed5d807d..6d96b730197 100644 --- a/chromium/third_party/openscreen/src/cast/sender/channel/message_util.cc +++ b/chromium/third_party/openscreen/src/cast/sender/channel/message_util.cc @@ -5,9 +5,14 @@ #include "cast/sender/channel/message_util.h" #include "cast/sender/channel/cast_auth_util.h" +#include "util/json/json_serialization.h" +namespace openscreen { namespace cast { -namespace channel { + +using ::cast::channel::AuthChallenge; +using ::cast::channel::CastMessage; +using ::cast::channel::DeviceAuthMessage; CastMessage CreateAuthChallengeMessage(const AuthContext& auth_context) { CastMessage message; @@ -15,7 +20,7 @@ CastMessage CreateAuthChallengeMessage(const AuthContext& auth_context) { AuthChallenge* challenge = auth_message.mutable_challenge(); challenge->set_sender_nonce(auth_context.nonce()); - challenge->set_hash_algorithm(SHA256); + challenge->set_hash_algorithm(::cast::channel::SHA256); std::string auth_message_string; auth_message.SerializeToString(&auth_message_string); @@ -24,11 +29,39 @@ CastMessage CreateAuthChallengeMessage(const AuthContext& auth_context) { message.set_source_id(kPlatformSenderId); message.set_destination_id(kPlatformReceiverId); message.set_namespace_(kAuthNamespace); - message.set_payload_type(CastMessage_PayloadType_BINARY); + message.set_payload_type(::cast::channel::CastMessage_PayloadType_BINARY); message.set_payload_binary(auth_message_string); return message; } -} // namespace channel +ErrorOr<CastMessage> CreateAppAvailabilityRequest(const std::string& sender_id, + int request_id, + const std::string& app_id) { + Json::Value dict(Json::ValueType::objectValue); + dict[kMessageKeyType] = Json::Value( + CastMessageTypeToString(CastMessageType::kGetAppAvailability)); + Json::Value app_id_value(Json::ValueType::arrayValue); + app_id_value.append(Json::Value(app_id)); + dict[kMessageKeyAppId] = std::move(app_id_value); + dict[kMessageKeyRequestId] = Json::Value(request_id); + + CastMessage message; + message.set_payload_type(::cast::channel::CastMessage_PayloadType_STRING); + ErrorOr<std::string> serialized = json::Stringify(dict); + if (serialized.is_error()) { + return serialized.error(); + } + message.set_payload_utf8(serialized.value()); + + message.set_protocol_version( + ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0); + message.set_source_id(sender_id); + message.set_destination_id(kPlatformReceiverId); + message.set_namespace_(kReceiverNamespace); + + return message; +} + } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/sender/channel/message_util.h b/chromium/third_party/openscreen/src/cast/sender/channel/message_util.h index e2da0cd8215..1a8c7717e32 100644 --- a/chromium/third_party/openscreen/src/cast/sender/channel/message_util.h +++ b/chromium/third_party/openscreen/src/cast/sender/channel/message_util.h @@ -7,15 +7,23 @@ #include "cast/common/channel/message_util.h" #include "cast/common/channel/proto/cast_channel.pb.h" +#include "platform/base/error.h" +namespace openscreen { namespace cast { -namespace channel { class AuthContext; -CastMessage CreateAuthChallengeMessage(const AuthContext& auth_context); +::cast::channel::CastMessage CreateAuthChallengeMessage( + const AuthContext& auth_context); + +// |request_id| must be unique for |sender_id|. +ErrorOr<::cast::channel::CastMessage> CreateAppAvailabilityRequest( + const std::string& sender_id, + int request_id, + const std::string& app_id); -} // namespace channel } // namespace cast +} // namespace openscreen #endif // CAST_SENDER_CHANNEL_MESSAGE_UTIL_H_ diff --git a/chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.cc b/chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.cc index 3e75c9776ee..bf89de88fae 100644 --- a/chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.cc +++ b/chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.cc @@ -5,34 +5,47 @@ #include "cast/sender/channel/sender_socket_factory.h" #include "cast/common/channel/cast_socket.h" +#include "cast/common/channel/proto/cast_channel.pb.h" #include "cast/sender/channel/message_util.h" #include "platform/base/tls_connect_options.h" #include "util/crypto/certificate_utils.h" +#include "util/logging.h" -namespace cast { -namespace channel { +using ::cast::channel::CastMessage; -using openscreen::platform::TlsConnectOptions; +namespace openscreen { +namespace cast { bool operator<(const std::unique_ptr<SenderSocketFactory::PendingAuth>& a, - uint32_t b) { + int b) { return a && a->socket->socket_id() < b; } -bool operator<(uint32_t a, +bool operator<(int a, const std::unique_ptr<SenderSocketFactory::PendingAuth>& b) { return b && a < b->socket->socket_id(); } -SenderSocketFactory::SenderSocketFactory(Client* client) : client_(client) { +SenderSocketFactory::SenderSocketFactory(Client* client, + TaskRunner* task_runner) + : client_(client), task_runner_(task_runner) { OSP_DCHECK(client); + OSP_DCHECK(task_runner); } -SenderSocketFactory::~SenderSocketFactory() = default; +SenderSocketFactory::~SenderSocketFactory() { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); +} + +void SenderSocketFactory::set_factory(TlsConnectionFactory* factory) { + OSP_DCHECK(factory); + factory_ = factory; +} void SenderSocketFactory::Connect(const IPEndpoint& endpoint, DeviceMediaPolicy media_policy, CastSocket::Client* client) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DCHECK(client); auto it = FindPendingConnection(endpoint); if (it == pending_connections_.end()) { @@ -64,15 +77,15 @@ void SenderSocketFactory::OnConnected( CastSocket::Client* client = it->client; pending_connections_.erase(it); - ErrorOr<bssl::UniquePtr<X509>> peer_cert = openscreen::ImportCertificate( - der_x509_peer_cert.data(), der_x509_peer_cert.size()); + ErrorOr<bssl::UniquePtr<X509>> peer_cert = + ImportCertificate(der_x509_peer_cert.data(), der_x509_peer_cert.size()); if (!peer_cert) { client_->OnError(this, endpoint, peer_cert.error()); return; } - auto socket = std::make_unique<CastSocket>(std::move(connection), this, - GetNextSocketId()); + auto socket = + MakeSerialDelete<CastSocket>(task_runner_, std::move(connection), this); pending_auth_.emplace_back( new PendingAuth{endpoint, media_policy, std::move(socket), client, AuthContext::Create(), std::move(peer_cert.value())}); @@ -149,22 +162,27 @@ void SenderSocketFactory::OnMessage(CastSocket* socket, CastMessage message) { } ErrorOr<CastDeviceCertPolicy> policy_or_error = AuthenticateChallengeReply( - message, (*it)->peer_cert.get(), (*it)->auth_context); + message, pending->peer_cert.get(), pending->auth_context); if (policy_or_error.is_error()) { + OSP_DLOG_WARN << "Authentication failed for " << pending->endpoint + << " with error: " << policy_or_error.error(); client_->OnError(this, pending->endpoint, policy_or_error.error()); return; } if (policy_or_error.value() == CastDeviceCertPolicy::kAudioOnly && - pending->media_policy != DeviceMediaPolicy::kAudioOnly) { + pending->media_policy == DeviceMediaPolicy::kIncludesVideo) { client_->OnError(this, pending->endpoint, Error::Code::kCastV2ChannelPolicyMismatch); return; } + pending->socket->set_audio_only(policy_or_error.value() == + CastDeviceCertPolicy::kAudioOnly); pending->socket->SetClient(pending->client); - client_->OnConnected(this, pending->endpoint, std::move(pending->socket)); + client_->OnConnected(this, pending->endpoint, + std::unique_ptr<CastSocket>(pending->socket.release())); } -} // namespace channel } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.h b/chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.h index 63998674bf9..7c788536187 100644 --- a/chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.h +++ b/chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.h @@ -12,19 +12,15 @@ #include <vector> #include "cast/common/channel/cast_socket.h" +#include "cast/common/channel/proto/cast_channel.pb.h" #include "cast/sender/channel/cast_auth_util.h" +#include "platform/api/task_runner.h" #include "platform/api/tls_connection_factory.h" #include "platform/base/ip_address.h" -#include "util/logging.h" +#include "util/serial_delete_ptr.h" +namespace openscreen { namespace cast { -namespace channel { - -using openscreen::Error; -using openscreen::IPEndpoint; -using openscreen::IPEndpointComparator; -using openscreen::platform::TlsConnection; -using openscreen::platform::TlsConnectionFactory; class SenderSocketFactory final : public TlsConnectionFactory::Client, public CastSocket::Client { @@ -40,19 +36,25 @@ class SenderSocketFactory final : public TlsConnectionFactory::Client, }; enum class DeviceMediaPolicy { + kNone = 0, kAudioOnly, kIncludesVideo, }; - // |client| must outlive |this|. - explicit SenderSocketFactory(Client* client); + // |client| and |task_runner| must outlive |this|. + SenderSocketFactory(Client* client, TaskRunner* task_runner); ~SenderSocketFactory(); - void set_factory(TlsConnectionFactory* factory) { - OSP_DCHECK(factory); - factory_ = factory; - } + // |factory| cannot be nullptr and must outlive |this|. + void set_factory(TlsConnectionFactory* factory); + // Begins connecting to a Cast device at |endpoint|. If a successful + // connection is made, including device authentication, the new CastSocket + // will be passed to |client_|'s OnConnected method. The new CastSocket will + // have its client set to |client|. If any part of the connection process + // fails, |client_|'s OnError method is called instead. This includes if the + // device's media policy, as determined by authentication, is audio-only and + // |media_policy| is kIncludesVideo. void Connect(const IPEndpoint& endpoint, DeviceMediaPolicy media_policy, CastSocket::Client* client); @@ -78,29 +80,31 @@ class SenderSocketFactory final : public TlsConnectionFactory::Client, struct PendingAuth { IPEndpoint endpoint; DeviceMediaPolicy media_policy; - std::unique_ptr<CastSocket> socket; + SerialDeletePtr<CastSocket> socket; CastSocket::Client* client; AuthContext auth_context; bssl::UniquePtr<X509> peer_cert; }; - friend bool operator<(const std::unique_ptr<PendingAuth>& a, uint32_t b); - friend bool operator<(uint32_t a, const std::unique_ptr<PendingAuth>& b); + friend bool operator<(const std::unique_ptr<PendingAuth>& a, int b); + friend bool operator<(int a, const std::unique_ptr<PendingAuth>& b); std::vector<PendingConnection>::iterator FindPendingConnection( const IPEndpoint& endpoint); // CastSocket::Client overrides. void OnError(CastSocket* socket, Error error) override; - void OnMessage(CastSocket* socket, CastMessage message) override; + void OnMessage(CastSocket* socket, + ::cast::channel::CastMessage message) override; Client* const client_; + TaskRunner* const task_runner_; TlsConnectionFactory* factory_ = nullptr; std::vector<PendingConnection> pending_connections_; std::vector<std::unique_ptr<PendingAuth>> pending_auth_; }; -} // namespace channel } // namespace cast +} // namespace openscreen #endif // CAST_SENDER_CHANNEL_SENDER_SOCKET_FACTORY_H_ diff --git a/chromium/third_party/openscreen/src/cast/sender/public/DEPS b/chromium/third_party/openscreen/src/cast/sender/public/DEPS index 44de6584c7b..8a61a25853f 100644 --- a/chromium/third_party/openscreen/src/cast/sender/public/DEPS +++ b/chromium/third_party/openscreen/src/cast/sender/public/DEPS @@ -1,12 +1,9 @@ # -*- Mode: Python; -*- include_rules = [ - # By default, openscreen implementation libraries should not be exposed - # through public APIs. - '-base', - '-platform', - # Dependencies on the implementation are not allowed in public/. '-cast/sender', - '+cast/sender/public' + '+cast/sender/public', + '-cast/common', + '+cast/common/public', ] diff --git a/chromium/third_party/openscreen/src/cast/sender/public/README.md b/chromium/third_party/openscreen/src/cast/sender/public/README.md index ace86163571..b670d1100a6 100644 --- a/chromium/third_party/openscreen/src/cast/sender/public/README.md +++ b/chromium/third_party/openscreen/src/cast/sender/public/README.md @@ -1,4 +1,4 @@ -# cast/receiver/public +# cast/sender/public This module contains an implementation of the Cast "sender", i.e. the client that discovers Cast devices on the LAN and launches apps on them. diff --git a/chromium/third_party/openscreen/src/cast/sender/public/cast_app_discovery_service.cc b/chromium/third_party/openscreen/src/cast/sender/public/cast_app_discovery_service.cc new file mode 100644 index 00000000000..c561b3865b7 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/sender/public/cast_app_discovery_service.cc @@ -0,0 +1,48 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/sender/public/cast_app_discovery_service.h" + +namespace openscreen { +namespace cast { + +CastAppDiscoveryService::Subscription::Subscription( + CastAppDiscoveryService* discovery_service, + uint32_t id) + : discovery_service_(discovery_service), id_(id) {} + +CastAppDiscoveryService::Subscription::Subscription(Subscription&& other) + : discovery_service_(other.discovery_service_), id_(other.id_) { + other.discovery_service_ = nullptr; +} + +CastAppDiscoveryService::Subscription::~Subscription() { + Reset(); +} + +CastAppDiscoveryService::Subscription& CastAppDiscoveryService::Subscription:: +operator=(Subscription other) { + Swap(other); + return *this; +} + +void CastAppDiscoveryService::Subscription::Reset() { + if (discovery_service_) { + discovery_service_->RemoveAvailabilityCallback(id_); + } + discovery_service_ = nullptr; +} + +void CastAppDiscoveryService::Subscription::Swap(Subscription& other) { + CastAppDiscoveryService* service = other.discovery_service_; + other.discovery_service_ = discovery_service_; + discovery_service_ = service; + + uint32_t id = other.id_; + other.id_ = id_; + id_ = id; +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/sender/public/cast_app_discovery_service.h b/chromium/third_party/openscreen/src/cast/sender/public/cast_app_discovery_service.h new file mode 100644 index 00000000000..ccb586ff816 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/sender/public/cast_app_discovery_service.h @@ -0,0 +1,75 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_SENDER_PUBLIC_CAST_APP_DISCOVERY_SERVICE_H_ +#define CAST_SENDER_PUBLIC_CAST_APP_DISCOVERY_SERVICE_H_ + +#include <vector> + +#include "cast/common/public/service_info.h" + +namespace openscreen { +namespace cast { + +class CastMediaSource; + +// Interface for app discovery for Cast devices. +class CastAppDiscoveryService { + public: + using AvailabilityCallback = + std::function<void(const CastMediaSource& source, + const std::vector<ServiceInfo>& devices)>; + + class Subscription { + public: + Subscription(Subscription&&); + ~Subscription(); + Subscription& operator=(Subscription); + + void Reset(); + + private: + friend class CastAppDiscoveryService; + + Subscription(CastAppDiscoveryService* discovery_service, uint32_t id); + + void Swap(Subscription& other); + + CastAppDiscoveryService* discovery_service_; + uint32_t id_; + }; + + virtual ~CastAppDiscoveryService() = default; + + // Adds an availability query for |source|. Results will be continuously + // returned via |callback| until the returned Subscription is destroyed by the + // caller. If there are cached results available, |callback| will be invoked + // before this method returns. |callback| may be invoked with an empty list + // if all devices respond to the respective queries with "unavailable" or + // don't respond before a timeout. |callback| may be invoked successively + // with the same list. + virtual Subscription StartObservingAvailability( + const CastMediaSource& source, + AvailabilityCallback callback) = 0; + + // Refreshes the state of app discovery in the service. It is suitable to call + // this method when the user initiates a user gesture. + virtual void Refresh() = 0; + + protected: + Subscription MakeSubscription(CastAppDiscoveryService* discovery_service, + uint32_t id) { + return Subscription(discovery_service, id); + } + + private: + friend class Subscription; + + virtual void RemoveAvailabilityCallback(uint32_t id) = 0; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_SENDER_PUBLIC_CAST_APP_DISCOVERY_SERVICE_H_ diff --git a/chromium/third_party/openscreen/src/cast/sender/public/cast_media_source.cc b/chromium/third_party/openscreen/src/cast/sender/public/cast_media_source.cc new file mode 100644 index 00000000000..ebc9a3fe014 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/sender/public/cast_media_source.cc @@ -0,0 +1,45 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/sender/public/cast_media_source.h" + +#include <algorithm> + +#include "util/logging.h" + +namespace openscreen { +namespace cast { + +// static +ErrorOr<CastMediaSource> CastMediaSource::From(const std::string& source) { + // TODO(btolsch): Implement when we have URL parsing. + OSP_UNIMPLEMENTED(); + return Error::Code::kUnknownError; +} + +CastMediaSource::CastMediaSource(std::string source, + std::vector<std::string> app_ids) + : source_id_(std::move(source)), app_ids_(std::move(app_ids)) {} + +CastMediaSource::CastMediaSource(const CastMediaSource& other) = default; +CastMediaSource::CastMediaSource(CastMediaSource&& other) = default; + +CastMediaSource::~CastMediaSource() = default; + +CastMediaSource& CastMediaSource::operator=(const CastMediaSource& other) = + default; +CastMediaSource& CastMediaSource::operator=(CastMediaSource&& other) = default; + +bool CastMediaSource::ContainsAppId(const std::string& app_id) const { + return std::find(app_ids_.begin(), app_ids_.end(), app_id) != app_ids_.end(); +} + +bool CastMediaSource::ContainsAnyAppIdFrom( + const std::vector<std::string>& app_ids) const { + return std::find_first_of(app_ids_.begin(), app_ids_.end(), app_ids.begin(), + app_ids.end()) != app_ids_.end(); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/sender/public/cast_media_source.h b/chromium/third_party/openscreen/src/cast/sender/public/cast_media_source.h new file mode 100644 index 00000000000..18af80f8694 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/sender/public/cast_media_source.h @@ -0,0 +1,42 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_SENDER_PUBLIC_CAST_MEDIA_SOURCE_H_ +#define CAST_SENDER_PUBLIC_CAST_MEDIA_SOURCE_H_ + +#include <string> +#include <vector> + +#include "platform/base/error.h" + +namespace openscreen { +namespace cast { + +class CastMediaSource { + public: + static ErrorOr<CastMediaSource> From(const std::string& source); + + CastMediaSource(std::string source, std::vector<std::string> app_ids); + CastMediaSource(const CastMediaSource& other); + CastMediaSource(CastMediaSource&& other); + ~CastMediaSource(); + + CastMediaSource& operator=(const CastMediaSource& other); + CastMediaSource& operator=(CastMediaSource&& other); + + bool ContainsAppId(const std::string& app_id) const; + bool ContainsAnyAppIdFrom(const std::vector<std::string>& app_ids) const; + + const std::string& source_id() const { return source_id_; } + const std::vector<std::string>& app_ids() const { return app_ids_; } + + private: + std::string source_id_; + std::vector<std::string> app_ids_; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_SENDER_PUBLIC_CAST_MEDIA_SOURCE_H_ diff --git a/chromium/third_party/openscreen/src/cast/sender/testing/DEPS b/chromium/third_party/openscreen/src/cast/sender/testing/DEPS new file mode 100644 index 00000000000..99039c594a4 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/sender/testing/DEPS @@ -0,0 +1,4 @@ +include_rules = [ + # Sender tests can use receiver code for simulation/validation. + '+cast/receiver', +] diff --git a/chromium/third_party/openscreen/src/cast/sender/testing/test_helpers.cc b/chromium/third_party/openscreen/src/cast/sender/testing/test_helpers.cc new file mode 100644 index 00000000000..ff9a45751ad --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/sender/testing/test_helpers.cc @@ -0,0 +1,87 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/sender/testing/test_helpers.h" + +#include "cast/common/channel/message_util.h" +#include "cast/receiver/channel/message_util.h" +#include "cast/sender/channel/message_util.h" +#include "gtest/gtest.h" +#include "util/json/json_serialization.h" +#include "util/json/json_value.h" +#include "util/logging.h" + +namespace openscreen { +namespace cast { + +using ::cast::channel::CastMessage; + +void VerifyAppAvailabilityRequest(const CastMessage& message, + const std::string& expected_app_id, + int* request_id_out, + std::string* sender_id_out) { + std::string app_id_out; + VerifyAppAvailabilityRequest(message, &app_id_out, request_id_out, + sender_id_out); + EXPECT_EQ(app_id_out, expected_app_id); +} + +void VerifyAppAvailabilityRequest(const CastMessage& message, + std::string* app_id_out, + int* request_id_out, + std::string* sender_id_out) { + EXPECT_EQ(message.namespace_(), kReceiverNamespace); + EXPECT_EQ(message.destination_id(), kPlatformReceiverId); + EXPECT_EQ(message.payload_type(), + ::cast::channel::CastMessage_PayloadType_STRING); + EXPECT_NE(message.source_id(), kPlatformSenderId); + *sender_id_out = message.source_id(); + + ErrorOr<Json::Value> maybe_value = json::Parse(message.payload_utf8()); + ASSERT_TRUE(maybe_value); + Json::Value& value = maybe_value.value(); + + absl::optional<absl::string_view> maybe_type = + MaybeGetString(value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType)); + ASSERT_TRUE(maybe_type); + EXPECT_EQ(maybe_type.value(), + CastMessageTypeToString(CastMessageType::kGetAppAvailability)); + + absl::optional<int> maybe_id = + MaybeGetInt(value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyRequestId)); + ASSERT_TRUE(maybe_id); + *request_id_out = maybe_id.value(); + + const Json::Value* maybe_app_ids = + value.find(JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyAppId)); + ASSERT_TRUE(maybe_app_ids); + ASSERT_TRUE(maybe_app_ids->isArray()); + ASSERT_EQ(maybe_app_ids->size(), 1u); + Json::Value app_id_value = maybe_app_ids->get(0u, Json::Value("")); + absl::optional<absl::string_view> maybe_app_id = MaybeGetString(app_id_value); + ASSERT_TRUE(maybe_app_id); + *app_id_out = + std::string(maybe_app_id.value().begin(), maybe_app_id.value().end()); +} + +CastMessage CreateAppAvailableResponseChecked(int request_id, + const std::string& sender_id, + const std::string& app_id) { + ErrorOr<CastMessage> message = + CreateAppAvailableResponse(request_id, sender_id, app_id); + OSP_CHECK(message); + return std::move(message.value()); +} + +CastMessage CreateAppUnavailableResponseChecked(int request_id, + const std::string& sender_id, + const std::string& app_id) { + ErrorOr<CastMessage> message = + CreateAppUnavailableResponse(request_id, sender_id, app_id); + OSP_CHECK(message); + return std::move(message.value()); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/sender/testing/test_helpers.h b/chromium/third_party/openscreen/src/cast/sender/testing/test_helpers.h new file mode 100644 index 00000000000..c9a68c20e2a --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/sender/testing/test_helpers.h @@ -0,0 +1,43 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_SENDER_TESTING_TEST_HELPERS_H_ +#define CAST_SENDER_TESTING_TEST_HELPERS_H_ + +#include <cstdint> +#include <string> + +#include "cast/sender/channel/message_util.h" + +namespace cast { +namespace channel { +class CastMessage; +} // namespace channel +} // namespace cast + +namespace openscreen { +namespace cast { + +void VerifyAppAvailabilityRequest(const ::cast::channel::CastMessage& message, + const std::string& expected_app_id, + int* request_id_out, + std::string* sender_id_out); +void VerifyAppAvailabilityRequest(const ::cast::channel::CastMessage& message, + std::string* app_id_out, + int* request_id_out, + std::string* sender_id_out); + +::cast::channel::CastMessage CreateAppAvailableResponseChecked( + int request_id, + const std::string& sender_id, + const std::string& app_id); +::cast::channel::CastMessage CreateAppUnavailableResponseChecked( + int request_id, + const std::string& sender_id, + const std::string& app_id); + +} // namespace cast +} // namespace openscreen + +#endif // CAST_SENDER_TESTING_TEST_HELPERS_H_ diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/BUILD.gn b/chromium/third_party/openscreen/src/cast/standalone_receiver/BUILD.gn index 696f732ba3b..ba54d33f414 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/BUILD.gn +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/BUILD.gn @@ -2,39 +2,38 @@ # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. +import("//build/config/external_libraries.gni") import("//build_overrides/build.gni") -declare_args() { - # These are only relevant for building the demo apps, which require external - # headers/libraries be installed. Set them to true if your local system has - # SDL2/FFMPEG installed. On Debian-like systems, the following should install - # all the required headers and libraries: - # - # sudo apt-get install libsdl2-2.0 libsdl2-dev libavcodec libavcodec-dev \ - # libavformat libavformat-dev libavutil libavutil-dev - have_sdl_for_demo_apps = false - have_ffmpeg_for_demo_apps = false -} - # Define the executable target only when the build is configured to use the # standalone platform implementation; since this is itself a standalone # application. if (!build_with_chromium) { executable("cast_receiver") { sources = [ + "cast_agent.cc", + "cast_agent.h", + "cast_socket_message_port.cc", + "cast_socket_message_port.h", "main.cc", - "memory_util.h", + "streaming_playback_controller.cc", + "streaming_playback_controller.h", ] deps = [ "../../platform", + "../../third_party/jsoncpp", + "../common:public", + "../common/channel/proto:channel_proto", + "../receiver:channel", "../streaming:receiver", ] defines = [] include_dirs = [] + lib_dirs = [] libs = [] - if (have_sdl_for_demo_apps && have_ffmpeg_for_demo_apps) { - defines += [ "CAST_STREAMING_HAVE_EXTERNAL_LIBS_FOR_DEMO_APPS" ] + if (have_ffmpeg && have_libsdl2) { + defines += [ "CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS" ] sources += [ "avcodec_glue.h", "decoder.cc", @@ -48,11 +47,9 @@ if (!build_with_chromium) { "sdl_video_player.cc", "sdl_video_player.h", ] - libs += [ - "-lSDL2", - "-lavcodec", - "-lavutil", - ] + include_dirs += ffmpeg_include_dirs + libsdl2_include_dirs + lib_dirs += ffmpeg_lib_dirs + libsdl2_lib_dirs + libs += ffmpeg_libs + libsdl2_libs } else { sources += [ "dummy_player.cc", diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/DEPS b/chromium/third_party/openscreen/src/cast/standalone_receiver/DEPS index 788c38cba93..f2040e13bd3 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/DEPS +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/DEPS @@ -5,4 +5,6 @@ include_rules = [ '+cast', '+platform/impl', + '+discovery/common', + '+discovery/public', ] diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/avcodec_glue.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/avcodec_glue.h index d1dfae78075..aa516175d34 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/avcodec_glue.h +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/avcodec_glue.h @@ -14,8 +14,8 @@ extern "C" { #include <libavutil/samplefmt.h> } +namespace openscreen { namespace cast { -namespace streaming { // Macro that, for an AVFoo, generates code for: // @@ -50,7 +50,7 @@ DEFINE_AV_UNIQUE_PTR(AVFrame, av_frame_alloc, av_frame_free(&obj)); #undef DEFINE_AV_UNIQUE_PTR -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STANDALONE_RECEIVER_AVCODEC_GLUE_H_ diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.cc new file mode 100644 index 00000000000..b614fa80889 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.cc @@ -0,0 +1,162 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/standalone_receiver/cast_agent.h" + +#include <fstream> +#include <sstream> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/str_cat.h" +#include "cast/standalone_receiver/cast_socket_message_port.h" +#include "cast/standalone_receiver/private_key_der.h" +#include "cast/streaming/constants.h" +#include "cast/streaming/offer_messages.h" +#include "platform/base/tls_credentials.h" +#include "platform/base/tls_listen_options.h" +#include "util/crypto/certificate_utils.h" +#include "util/json/json_serialization.h" +#include "util/logging.h" + +namespace openscreen { +namespace cast { +namespace { + +constexpr int kDefaultMaxBacklogSize = 64; +const TlsListenOptions kDefaultListenOptions{kDefaultMaxBacklogSize}; + +constexpr int kThreeDaysInSeconds = 3 * 24 * 60 * 60; +constexpr auto kCertificateDuration = std::chrono::seconds(kThreeDaysInSeconds); + +// Generates a valid set of credentials for use with the TLS Server socket, +// including a generated X509 certificate generated from the static private key +// stored in private_key_der.h. The certificate is valid for +// kCertificateDuration from when this function is called. +ErrorOr<TlsCredentials> CreateCredentials(const IPEndpoint& endpoint) { + ErrorOr<bssl::UniquePtr<EVP_PKEY>> private_key = + ImportRSAPrivateKey(kPrivateKeyDer.data(), kPrivateKeyDer.size()); + OSP_CHECK(private_key); + + ErrorOr<bssl::UniquePtr<X509>> cert = CreateSelfSignedX509Certificate( + endpoint.ToString(), kCertificateDuration, *private_key.value()); + if (!cert) { + return cert.error(); + } + + auto cert_bytes = ExportX509CertificateToDer(*cert.value()); + if (!cert_bytes) { + return cert_bytes.error(); + } + + // TODO(jophba): either refactor the TLS server socket to use the public key + // and add a valid key here, or remove from the TlsCredentials struct. + return TlsCredentials( + std::vector<uint8_t>(kPrivateKeyDer.begin(), kPrivateKeyDer.end()), + std::vector<uint8_t>{}, std::move(cert_bytes.value())); +} + +} // namespace + +CastAgent::CastAgent(TaskRunner* task_runner, InterfaceInfo interface) + : task_runner_(task_runner) { + // Create the Environment that holds the required injected dependencies + // (clock, task runner) used throughout the system, and owns the UDP socket + // over which all communication occurs with the Sender. + IPEndpoint receive_endpoint{IPAddress::kV4LoopbackAddress, kDefaultCastPort}; + receive_endpoint.address = interface.GetIpAddressV4() + ? interface.GetIpAddressV4() + : interface.GetIpAddressV6(); + OSP_DCHECK(receive_endpoint.address); + environment_ = std::make_unique<Environment>(&Clock::now, task_runner_, + receive_endpoint); + receive_endpoint_ = std::move(receive_endpoint); +} + +CastAgent::~CastAgent() = default; + +Error CastAgent::Start() { + OSP_DCHECK(!current_session_); + + task_runner_->PostTask( + [this] { this->wake_lock_ = ScopedWakeLock::Create(); }); + + // TODO(jophba): add command line argument for setting the private key. + ErrorOr<TlsCredentials> credentials = CreateCredentials(receive_endpoint_); + if (!credentials) { + return credentials.error(); + } + + // TODO(jophba, rwkeane): begin discovery process before creating TLS + // connection factory instance. + socket_factory_ = + std::make_unique<ReceiverSocketFactory>(this, &message_port_); + task_runner_->PostTask([this, creds = std::move(credentials.value())] { + connection_factory_ = TlsConnectionFactory::CreateFactory( + socket_factory_.get(), task_runner_); + connection_factory_->SetListenCredentials(creds); + connection_factory_->Listen(receive_endpoint_, kDefaultListenOptions); + }); + + OSP_LOG_INFO << "Listening for connections at: " << receive_endpoint_; + return Error::None(); +} + +Error CastAgent::Stop() { + controller_.reset(); + current_session_.reset(); + return Error::None(); +} + +void CastAgent::OnConnected(ReceiverSocketFactory* factory, + const IPEndpoint& endpoint, + std::unique_ptr<CastSocket> socket) { + OSP_DCHECK(factory); + + if (current_session_) { + OSP_LOG_WARN << "Already connected, dropping peer at: " << endpoint; + return; + } + + OSP_LOG_INFO << "Received connection from peer at: " << endpoint; + message_port_.SetSocket(std::move(socket)); + controller_ = + std::make_unique<StreamingPlaybackController>(task_runner_, this); + current_session_ = std::make_unique<ReceiverSession>( + controller_.get(), environment_.get(), &message_port_, + ReceiverSession::Preferences{}); +} + +void CastAgent::OnError(ReceiverSocketFactory* factory, Error error) { + OSP_LOG_ERROR << "Cast agent received socket factory error: " << error; +} + +// Currently we don't do anything with the receiver output--the session +// is automatically linked to the playback controller when it is constructed, so +// we don't actually have to interface with the receivers. If we end up caring +// about the receiver configurations we will have to handle OnNegotiated here. +void CastAgent::OnNegotiated(const ReceiverSession* session, + ReceiverSession::ConfiguredReceivers receivers) { + OSP_LOG_INFO << "Successfully negotiated with sender."; +} + +void CastAgent::OnConfiguredReceiversDestroyed(const ReceiverSession* session) { + OSP_LOG_INFO << "Receiver instances destroyed."; +} + +// Currently, we just kill the session if an error is encountered. +void CastAgent::OnError(const ReceiverSession* session, Error error) { + OSP_LOG_ERROR << "Cast agent received receiver session error: " << error; + current_session_.reset(); +} + +void CastAgent::OnPlaybackError(StreamingPlaybackController* controller, + Error error) { + OSP_LOG_ERROR << "Cast agent received playback error: " << error; + current_session_.reset(); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.h new file mode 100644 index 00000000000..873563e6604 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.h @@ -0,0 +1,84 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STANDALONE_RECEIVER_CAST_AGENT_H_ +#define CAST_STANDALONE_RECEIVER_CAST_AGENT_H_ + +#include <openssl/x509.h> + +#include <memory> + +#include "cast/common/channel/cast_socket.h" +#include "cast/receiver/channel/receiver_socket_factory.h" +#include "cast/standalone_receiver/cast_socket_message_port.h" +#include "cast/standalone_receiver/streaming_playback_controller.h" +#include "cast/streaming/environment.h" +#include "cast/streaming/receiver_session.h" +#include "platform/api/scoped_wake_lock.h" +#include "platform/base/error.h" +#include "platform/base/interface_info.h" +#include "platform/impl/task_runner.h" +#include "util/serial_delete_ptr.h" + +namespace openscreen { +namespace cast { + +// This class manages sender connections, starting with listening over TLS for +// connection attempts, constructing ReceiverSessions when OFFER messages are +// received, and linking Receivers to the output decoder and SDL visualizer. +// +// Consumers of this class are expected to provide a single threaded task runner +// implementation, and a network interface information struct that will be used +// both for TLS listening and UDP messaging. +class CastAgent : public ReceiverSocketFactory::Client, + public ReceiverSession::Client, + public StreamingPlaybackController::Client { + public: + CastAgent(TaskRunner* task_runner, InterfaceInfo interface); + ~CastAgent(); + + // Initialization occurs as part of construction, however to actually bind + // for discovery and listening over TLS, the CastAgent must be started. + Error Start(); + Error Stop(); + + // ReceiverSocketFactory::Client overrides. + void OnConnected(ReceiverSocketFactory* factory, + const IPEndpoint& endpoint, + std::unique_ptr<CastSocket> socket) override; + void OnError(ReceiverSocketFactory* factory, Error error) override; + + // ReceiverSession::Client overrides. + void OnNegotiated(const ReceiverSession* session, + ReceiverSession::ConfiguredReceivers receivers) override; + void OnConfiguredReceiversDestroyed(const ReceiverSession* session) override; + void OnError(const ReceiverSession* session, Error error) override; + + // StreamingPlaybackController::Client overrides + void OnPlaybackError(StreamingPlaybackController* controller, + Error error) override; + + private: + // Member variables set as part of construction. + std::unique_ptr<Environment> environment_; + TaskRunner* const task_runner_; + IPEndpoint receive_endpoint_; + CastSocketMessagePort message_port_; + + // Member variables set as part of starting up. + std::unique_ptr<TlsConnectionFactory> connection_factory_; + std::unique_ptr<ReceiverSocketFactory> socket_factory_; + std::unique_ptr<ScopedWakeLock> wake_lock_; + + // Member variables set as part of a sender connection. + // NOTE: currently we only support a single sender connection and a + // single streaming session. + std::unique_ptr<ReceiverSession> current_session_; + std::unique_ptr<StreamingPlaybackController> controller_; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STANDALONE_RECEIVER_CAST_AGENT_H_ diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.cc new file mode 100644 index 00000000000..706fd58ff6d --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.cc @@ -0,0 +1,59 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/standalone_receiver/cast_socket_message_port.h" + +#include <utility> + +namespace openscreen { +namespace cast { + +CastSocketMessagePort::CastSocketMessagePort() = default; +CastSocketMessagePort::~CastSocketMessagePort() = default; + +// NOTE: we assume here that this message port is already the client for +// the passed in socket, so leave the socket's client unchanged. However, +// since sockets should map one to one with receiver sessions, we reset our +// client. The consumer of this message port should call SetClient with the new +// message port client after setting the socket. +void CastSocketMessagePort::SetSocket(std::unique_ptr<CastSocket> socket) { + client_ = nullptr; + socket_ = std::move(socket); +} + +void CastSocketMessagePort::SetClient(MessagePort::Client* client) { + client_ = client; +} + +void CastSocketMessagePort::OnError(CastSocket* socket, Error error) { + if (client_) { + client_->OnError(error); + } +} + +void CastSocketMessagePort::OnMessage(CastSocket* socket, + ::cast::channel::CastMessage message) { + if (client_) { + client_->OnMessage(message.source_id(), message.namespace_(), + message.payload_utf8()); + } +} + +void CastSocketMessagePort::PostMessage(absl::string_view sender_id, + absl::string_view message_namespace, + absl::string_view message) { + ::cast::channel::CastMessage cast_message; + cast_message.set_source_id(sender_id.data(), sender_id.size()); + cast_message.set_namespace_(message_namespace.data(), + message_namespace.size()); + cast_message.set_payload_utf8(message.data(), message.size()); + + Error error = socket_->SendMessage(cast_message); + if (!error.ok()) { + client_->OnError(error); + } +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.h new file mode 100644 index 00000000000..c596f985705 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.h @@ -0,0 +1,44 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STANDALONE_RECEIVER_CAST_SOCKET_MESSAGE_PORT_H_ +#define CAST_STANDALONE_RECEIVER_CAST_SOCKET_MESSAGE_PORT_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "cast/common/channel/cast_socket.h" +#include "cast/streaming/receiver_session.h" + +namespace openscreen { +namespace cast { + +class CastSocketMessagePort : public MessagePort, public CastSocket::Client { + public: + CastSocketMessagePort(); + ~CastSocketMessagePort() override; + + void SetSocket(std::unique_ptr<CastSocket> socket); + + // MessagePort overrides. + void SetClient(MessagePort::Client* client) override; + void PostMessage(absl::string_view sender_id, + absl::string_view message_namespace, + absl::string_view message) override; + + // CastSocket::Client overrides. + void OnError(CastSocket* socket, Error error) override; + void OnMessage(CastSocket* socket, + ::cast::channel::CastMessage message) override; + + private: + MessagePort::Client* client_ = nullptr; + std::unique_ptr<CastSocket> socket_; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STANDALONE_RECEIVER_CAST_SOCKET_MESSAGE_PORT_H_ diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.cc index 8421cd0f2e8..5d84d6a77fc 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.cc @@ -1,11 +1,18 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + #include "cast/standalone_receiver/decoder.h" +#include <algorithm> #include <sstream> +#include <thread> // NOLINT #include "util/logging.h" +#include "util/trace_logging.h" +namespace openscreen { namespace cast { -namespace streaming { Decoder::Buffer::Buffer() { Resize(0); @@ -36,16 +43,14 @@ absl::Span<uint8_t> Decoder::Buffer::GetSpan() { Decoder::Client::Client() = default; Decoder::Client::~Client() = default; -Decoder::Decoder() = default; +Decoder::Decoder(const std::string& codec_name) : codec_name_(codec_name) {} + Decoder::~Decoder() = default; void Decoder::Decode(FrameId frame_id, const Decoder::Buffer& buffer) { - if (!codec_) { - InitFromFirstBuffer(buffer); - if (!codec_) { - OnError("unable to detect codec", AVERROR(EINVAL), frame_id); - return; - } + TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver); + if (!codec_ && !Initialize()) { + return; } // Parse the buffer for the required metadata and the packet to send to the @@ -95,24 +100,64 @@ void Decoder::Decode(FrameId frame_id, const Decoder::Buffer& buffer) { } } -void Decoder::InitFromFirstBuffer(const Buffer& buffer) { - const AVCodecID codec_id = Detect(buffer); - if (codec_id == AV_CODEC_ID_NONE) { - return; +bool Decoder::Initialize() { + TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver); + // NOTE: The codec_name values found in OFFER messages, such as "vp8" or + // "h264" or "opus" are valid input strings to FFMPEG's look-up function, so + // no translation is required here. + codec_ = avcodec_find_decoder_by_name(codec_name_.c_str()); + if (!codec_) { + HandleInitializationError("codec not available", AVERROR(EINVAL)); + return false; + } + OSP_LOG_INFO << "Found codec: " << codec_name_ << " (known to FFMPEG as " + << avcodec_get_name(codec_->id) << ')'; + + parser_ = MakeUniqueAVCodecParserContext(codec_->id); + if (!parser_) { + HandleInitializationError("failed to allocate parser context", + AVERROR(ENOMEM)); + return false; } - codec_ = avcodec_find_decoder(codec_id); - OSP_CHECK(codec_); - parser_ = MakeUniqueAVCodecParserContext(codec_id); - OSP_CHECK(parser_); context_ = MakeUniqueAVCodecContext(codec_); - OSP_CHECK(context_); + if (!context_) { + HandleInitializationError("failed to allocate codec context", + AVERROR(ENOMEM)); + return false; + } + + // This should always be greater than zero, so that decoding doesn't block the + // main thread of this receiver app and cause playback timing issues. The + // actual number should be tuned, based on the number of CPU cores. + // + // This should also be 16 or less, since the encoder implementations emit + // warnings about too many encode threads. FFMPEG's VP8 implementation + // actually silently freezes if this is 10 or more. Thus, 8 is used for the + // max here, just to be safe. + // + // TODO(jophba): determine a better number after running benchmarking. + context_->thread_count = + std::min(std::max<int>(std::thread::hardware_concurrency(), 1), 8); const int open_result = avcodec_open2(context_.get(), codec_, nullptr); - OSP_CHECK_EQ(open_result, 0); + if (open_result < 0) { + HandleInitializationError("failed to open codec", open_result); + return false; + } + packet_ = MakeUniqueAVPacket(); - OSP_CHECK(packet_); + if (!packet_) { + HandleInitializationError("failed to allocate AVPacket", AVERROR(ENOMEM)); + return false; + } + decoded_frame_ = MakeUniqueAVFrame(); - OSP_CHECK(decoded_frame_); + if (!decoded_frame_) { + HandleInitializationError("failed to allocate AVFrame", AVERROR(ENOMEM)); + return false; + } + + return true; } FrameId Decoder::DidReceiveFrameFromDecoder() { @@ -123,6 +168,26 @@ FrameId Decoder::DidReceiveFrameFromDecoder() { return frame_id; } +void Decoder::HandleInitializationError(const char* what, int av_errnum) { + // If the codec was found, get FFMPEG's canonical name for it. + const char* const canonical_name = + codec_ ? avcodec_get_name(codec_->id) : nullptr; + + codec_ = nullptr; // Set null to mean "not initialized." + + if (!client_) { + return; // Nowhere to emit error to, so don't bother. + } + + std::ostringstream error; + error << "Could not initialize codec " << codec_name_; + if (canonical_name) { + error << " (known to FFMPEG as " << canonical_name << ')'; + } + error << " because " << what << " (" << av_err2str(av_errnum) << ")."; + client_->OnFatalError(error.str()); +} + void Decoder::OnError(const char* what, int av_errnum, FrameId frame_id) { if (!client_) { return; @@ -133,7 +198,11 @@ void Decoder::OnError(const char* what, int av_errnum, FrameId frame_id) { if (!frame_id.is_null()) { error << "frame: " << frame_id << "; "; } - error << "what: " << what << "; error: " << av_err2str(av_errnum); + + char human_readable_error[AV_ERROR_MAX_STRING_SIZE]{0}; + av_make_error_string(human_readable_error, AV_ERROR_MAX_STRING_SIZE, + av_errnum); + error << "what: " << what << "; error: " << human_readable_error; // Dispatch to either the fatal error handler, or the one for decode errors, // as appropriate. @@ -149,52 +218,5 @@ void Decoder::OnError(const char* what, int av_errnum, FrameId frame_id) { } } -// static -AVCodecID Decoder::Detect(const Buffer& buffer) { - static constexpr AVCodecID kCodecsToTry[] = { - AV_CODEC_ID_VP8, AV_CODEC_ID_VP9, AV_CODEC_ID_H264, - AV_CODEC_ID_H265, AV_CODEC_ID_OPUS, AV_CODEC_ID_FLAC, - }; - - const absl::Span<const uint8_t> input = buffer.GetSpan(); - for (AVCodecID codec_id : kCodecsToTry) { - AVCodec* const codec = avcodec_find_decoder(codec_id); - if (!codec) { - OSP_LOG_WARN << "Video codec not available to try: " - << avcodec_get_name(codec_id); - continue; - } - const auto parser = MakeUniqueAVCodecParserContext(codec_id); - if (!parser) { - OSP_LOG_ERROR << "Failed to init parser for codec: " - << avcodec_get_name(codec_id); - continue; - } - const auto context = MakeUniqueAVCodecContext(codec); - if (!context || avcodec_open2(context.get(), codec, nullptr) != 0) { - OSP_LOG_ERROR << "Failed to open codec: " << avcodec_get_name(codec_id); - continue; - } - const auto packet = MakeUniqueAVPacket(); - OSP_CHECK(packet); - if (av_parser_parse2(parser.get(), context.get(), &packet->data, - &packet->size, input.data(), input.size(), - AV_NOPTS_VALUE, AV_NOPTS_VALUE, 0) < 0 || - !packet->data || packet->size == 0) { - OSP_VLOG << "Does not parse as codec: " << avcodec_get_name(codec_id); - continue; - } - if (avcodec_send_packet(context.get(), packet.get()) < 0) { - OSP_VLOG << "avcodec_send_packet() failed, probably wrong codec version: " - << avcodec_get_name(codec_id); - continue; - } - OSP_VLOG << "Detected codec: " << avcodec_get_name(codec_id); - return codec_id; - } - - return AV_CODEC_ID_NONE; -} - -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.h index 08d61e4335c..1d4d07917eb 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.h +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.h @@ -14,10 +14,10 @@ #include "cast/standalone_receiver/avcodec_glue.h" #include "cast/streaming/frame_id.h" +namespace openscreen { namespace cast { -namespace streaming { -// Wraps libavcodec to auto-detect and decode audio or video. +// Wraps libavcodec to decode audio or video. class Decoder { public: // A buffer backed by storage that is compatible with FFMPEG (i.e., includes @@ -48,7 +48,8 @@ class Decoder { Client(); }; - Decoder(); + // |codec_name| should be the codec_name field from an OFFER message. + explicit Decoder(const std::string& codec_name); ~Decoder(); Client* client() const { return client_; } @@ -62,21 +63,23 @@ class Decoder { void Decode(FrameId frame_id, const Buffer& buffer); private: - // Helper to auto-detect the codec being used and initialize the FFMPEG - // decoder; called for the first frame being decoded. - void InitFromFirstBuffer(const Buffer& buffer); + // Helper to initialize the FFMPEG decoder and supporting objects. Returns + // false if this failed (and the Client was notified). + bool Initialize(); // Helper to get the FrameId that is associated with the next frame coming out // of the FFMPEG decoder. FrameId DidReceiveFrameFromDecoder(); - // Called when any transient or fatal error occurs, generating an - // openscreen::Error and notifying the Client of it. - void OnError(const char* what, int av_errnum, FrameId frame_id); + // Helper to handle a codec initialization error and notify the Client of the + // fatal error. + void HandleInitializationError(const char* what, int av_errnum); - // Auto-detects the codec needed to decode the data in |buffer|. - static AVCodecID Detect(const Buffer& buffer); + // Called when any transient or fatal error occurs, generating an Error and + // notifying the Client of it. + void OnError(const char* what, int av_errnum, FrameId frame_id); + const std::string codec_name_; AVCodec* codec_ = nullptr; AVCodecParserContextUniquePtr parser_; AVCodecContextUniquePtr context_; @@ -90,7 +93,7 @@ class Decoder { std::vector<FrameId> frames_decoding_; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STANDALONE_RECEIVER_DECODER_H_ diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.cc index 8d2af1f74ba..cce96ee816b 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.cc @@ -12,8 +12,8 @@ using std::chrono::microseconds; +namespace openscreen { namespace cast { -namespace streaming { DummyPlayer::DummyPlayer(Receiver* receiver) : receiver_(receiver) { OSP_DCHECK(receiver_); @@ -41,5 +41,5 @@ void DummyPlayer::OnFramesReady(int buffer_size) { << buffer_size << " bytes"; } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.h index 373d32329fd..e8db1bf7a49 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.h +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.h @@ -13,14 +13,14 @@ #include "platform/api/task_runner.h" #include "platform/api/time.h" +namespace openscreen { namespace cast { -namespace streaming { // Consumes frames from a Receiver, but does nothing other than OSP_LOG_INFO // each one's FrameId, timestamp and size. This is only useful for confirming a // Receiver is successfully receiving a stream, for platforms where // SDLVideoPlayer cannot be built. -class DummyPlayer : public Receiver::Consumer { +class DummyPlayer final : public Receiver::Consumer { public: explicit DummyPlayer(Receiver* receiver); @@ -34,7 +34,7 @@ class DummyPlayer : public Receiver::Consumer { std::vector<uint8_t> buffer_; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STANDALONE_RECEIVER_DUMMY_PLAYER_H_ diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/install_demo_deps_debian.sh b/chromium/third_party/openscreen/src/cast/standalone_receiver/install_demo_deps_debian.sh new file mode 100755 index 00000000000..c082455cd41 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/install_demo_deps_debian.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env sh + +# Installs dependencies necessary for libSDL and libAVcodec on Debian systems. + +sudo apt-get install libsdl2-2.0 libsdl2-dev libavcodec libavcodec-dev \ + libavformat libavformat-dev libavutil libavutil-dev \ + libswresample libswresample-dev
\ No newline at end of file diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/install_demo_deps_raspian.sh b/chromium/third_party/openscreen/src/cast/standalone_receiver/install_demo_deps_raspian.sh new file mode 100755 index 00000000000..91acaaa6bfb --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/install_demo_deps_raspian.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env sh + +# Installs dependencies necessary for libSDL and libAVcodec on +# Raspberry PI units running Raspian. + +sudo apt-get install libavcodec58=7:4.1.4* libavcodec-dev=7:4.1.4* \ + libsdl2-2.0-0=2.0.9* libsdl2-dev=2.0.9* \ + libavformat-dev=7:4.1.4* diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/main.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/main.cc index 76e7a8d2c54..ed2dffd913d 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/main.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/main.cc @@ -2,181 +2,193 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include <getopt.h> + #include <array> -#include <chrono> -#include <thread> +#include <chrono> // NOLINT -#include "cast/streaming/constants.h" -#include "cast/streaming/environment.h" -#include "cast/streaming/receiver.h" -#include "cast/streaming/receiver_packet_router.h" +#include "cast/common/public/service_info.h" +#include "cast/standalone_receiver/cast_agent.h" #include "cast/streaming/ssrc.h" +#include "discovery/common/config.h" +#include "discovery/common/reporting_client.h" +#include "discovery/public/dns_sd_service_factory.h" +#include "discovery/public/dns_sd_service_publisher.h" #include "platform/api/time.h" #include "platform/api/udp_socket.h" #include "platform/base/error.h" #include "platform/base/ip_address.h" #include "platform/impl/logging.h" +#include "platform/impl/network_interface.h" #include "platform/impl/platform_client_posix.h" #include "platform/impl/task_runner.h" +#include "platform/impl/text_trace_logging_platform.h" +#include "util/stringprintf.h" +#include "util/trace_logging.h" -#if defined(CAST_STREAMING_HAVE_EXTERNAL_LIBS_FOR_DEMO_APPS) -#include "cast/standalone_receiver/sdl_audio_player.h" -#include "cast/standalone_receiver/sdl_glue.h" -#include "cast/standalone_receiver/sdl_video_player.h" -#else -#include "cast/standalone_receiver/dummy_player.h" -#endif // defined(CAST_STREAMING_HAVE_EXTERNAL_LIBS_FOR_DEMO_APPS) - -using openscreen::IPEndpoint; -using openscreen::platform::Clock; -using openscreen::platform::TaskRunner; -using openscreen::platform::TaskRunnerImpl; - +namespace openscreen { namespace cast { -namespace streaming { namespace { -//////////////////////////////////////////////////////////////////////////////// -// Receiver Configuration -// -// The values defined here are constants that correspond to the Sender Demo app. -// In a production environment, these should ABSOLUTELY NOT be fixed! Instead a -// sender↔receiver OFFER/ANSWER exchange should establish them. - -const cast::streaming::SessionConfig kSampleAudioAnswerConfig{ - /* .sender_ssrc = */ 1, - /* .receiver_ssrc = */ 2, - - // In a production environment, this would be set to the sampling rate of - // the audio capture. - /* .rtp_timebase = */ 48000, - /* .channels = */ 2, - /* .aes_secret_key = */ - {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, - 0x0c, 0x0d, 0x0e, 0x0f}, - /* .aes_iv_mask = */ - {0xf0, 0xe0, 0xd0, 0xc0, 0xb0, 0xa0, 0x90, 0x80, 0x70, 0x60, 0x50, 0x40, - 0x30, 0x20, 0x10, 0x00}, +class DiscoveryReportingClient : public discovery::ReportingClient { + void OnFatalError(Error error) override { + OSP_LOG_FATAL << "Encountered fatal discovery error: " << error; + } + + void OnRecoverableError(Error error) override { + OSP_LOG_ERROR << "Encountered recoverable discovery error: " << error; + } }; -const cast::streaming::SessionConfig kSampleVideoAnswerConfig{ - /* .sender_ssrc = */ 50001, - /* .receiver_ssrc = */ 50002, - /* .rtp_timebase = */ static_cast<int>(kVideoTimebase::den), - /* .channels = */ 1, - /* .aes_secret_key = */ - {0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, - 0x1c, 0x1d, 0x1e, 0x1f}, - /* .aes_iv_mask = */ - {0xf1, 0xe1, 0xd1, 0xc1, 0xb1, 0xa1, 0x91, 0x81, 0x71, 0x61, 0x51, 0x41, - 0x31, 0x21, 0x11, 0x01}, +struct DiscoveryState { + SerialDeletePtr<discovery::DnsSdService> service; + std::unique_ptr<DiscoveryReportingClient> reporting_client; + std::unique_ptr<discovery::DnsSdServicePublisher<ServiceInfo>> publisher; }; -// In a production environment, this would start-out at some initial value -// appropriate to the networking environment, and then be adjusted by the -// application as: 1) the TYPE of the content changes (interactive, low-latency -// versus smooth, higher-latency buffered video watching); and 2) the networking -// environment reliability changes. -constexpr std::chrono::milliseconds kDemoTargetPlayoutDelay{400}; - -// The UDP socket port receiving packets from the Sender. -constexpr int kCastStreamingPort = 2344; - -// End of Receiver Configuration. -//////////////////////////////////////////////////////////////////////////////// - -void DemoMain(TaskRunnerImpl* task_runner) { - // Create the Environment that holds the required injected dependencies - // (clock, task runner) used throughout the system, and owns the UDP socket - // over which all communication occurs with the Sender. - const IPEndpoint receive_endpoint{openscreen::IPAddress(), - kCastStreamingPort}; - Environment env(&Clock::now, task_runner, receive_endpoint); - - // Create the packet router that allows both the Audio Receiver and the Video - // Receiver to share the same UDP socket. - ReceiverPacketRouter packet_router(&env); - - // Create the two Receivers. - Receiver audio_receiver(&env, &packet_router, kSampleAudioAnswerConfig, - kDemoTargetPlayoutDelay); - Receiver video_receiver(&env, &packet_router, kSampleVideoAnswerConfig, - kDemoTargetPlayoutDelay); - - OSP_LOG_INFO << "Awaiting first Cast Streaming packet at " - << env.GetBoundLocalEndpoint() << "..."; - -#if defined(CAST_STREAMING_HAVE_EXTERNAL_LIBS_FOR_DEMO_APPS) - - // Start the SDL event loop, using the task runner to poll/process events. - const ScopedSDLSubSystem<SDL_INIT_AUDIO> sdl_audio_sub_system; - const ScopedSDLSubSystem<SDL_INIT_VIDEO> sdl_video_sub_system; - const SDLEventLoopProcessor sdl_event_loop( - task_runner, [&] { task_runner->RequestStopSoon(); }); - - // Create/Initialize the Audio Player and Video Player, which are responsible - // for decoding and playing out the received media. - constexpr int kDefaultWindowWidth = 1280; - constexpr int kDefaultWindowHeight = 720; - const SDLWindowUniquePtr window = MakeUniqueSDLWindow( - "Cast Streaming Receiver Demo", - SDL_WINDOWPOS_UNDEFINED /* initial X position */, - SDL_WINDOWPOS_UNDEFINED /* initial Y position */, kDefaultWindowWidth, - kDefaultWindowHeight, SDL_WINDOW_RESIZABLE); - OSP_CHECK(window) << "Failed to create SDL window: " << SDL_GetError(); - const SDLRendererUniquePtr renderer = - MakeUniqueSDLRenderer(window.get(), -1, 0); - OSP_CHECK(renderer) << "Failed to create SDL renderer: " << SDL_GetError(); - - const SDLAudioPlayer audio_player( - &Clock::now, task_runner, &audio_receiver, [&] { - OSP_LOG_ERROR << audio_player.error_status().message(); - task_runner->RequestStopSoon(); - }); - const SDLVideoPlayer video_player( - &Clock::now, task_runner, &video_receiver, renderer.get(), [&] { - OSP_LOG_ERROR << video_player.error_status().message(); - task_runner->RequestStopSoon(); - }); - -#else - - const DummyPlayer audio_player(&audio_receiver); - const DummyPlayer video_player(&video_receiver); - -#endif // defined(CAST_STREAMING_HAVE_EXTERNAL_LIBS_FOR_DEMO_APPS) +ErrorOr<std::unique_ptr<DiscoveryState>> StartDiscovery( + TaskRunner* task_runner, + const InterfaceInfo& interface) { + discovery::Config config; + + config.interface = interface; + + auto state = std::make_unique<DiscoveryState>(); + state->reporting_client = std::make_unique<DiscoveryReportingClient>(); + state->service = discovery::CreateDnsSdService( + task_runner, state->reporting_client.get(), config); + + // TODO(jophba): update after ServiceInfo update patch lands. + ServiceInfo info; + info.port = kDefaultCastPort; + if (interface.GetIpAddressV4()) { + info.v4_address = interface.GetIpAddressV4(); + } + if (interface.GetIpAddressV6()) { + info.v6_address = interface.GetIpAddressV6(); + } + + OSP_CHECK(std::any_of(interface.hardware_address.begin(), + interface.hardware_address.end(), + [](int e) { return e > 0; })); + info.unique_id = HexEncode(interface.hardware_address); + + // TODO(jophba): add command line arguments to set these fields. + info.model_name = "cast_standalone_receiver"; + info.friendly_name = "Cast Standalone Receiver"; + + state->publisher = + std::make_unique<discovery::DnsSdServicePublisher<ServiceInfo>>( + state->service.get(), kCastV2ServiceId, ServiceInfoToDnsSdRecord); + + auto error = state->publisher->Register(info); + if (!error.ok()) { + return error; + } + return state; +} + +void RunStandaloneReceiver(TaskRunnerImpl* task_runner, + InterfaceInfo interface) { + CastAgent agent(task_runner, interface); + const auto error = agent.Start(); + if (!error.ok()) { + OSP_LOG_ERROR << "Error occurred while starting agent: " << error; + return; + } // Run the event loop until an exit is requested (e.g., the video player GUI - // window is closed, a SIGTERM is intercepted, or whatever other appropriate - // user indication that shutdown is requested). - task_runner->RunUntilStopped(); + // window is closed, a SIGINT or SIGTERM is received, or whatever other + // appropriate user indication that shutdown is requested). + task_runner->RunUntilSignaled(); } } // namespace -} // namespace streaming } // namespace cast +} // namespace openscreen + +namespace { + +void LogUsage(const char* argv0) { + constexpr char kExecutableTag[] = "argv[0]"; + constexpr char kUsageMessage[] = R"( + usage: argv[0] <options> <interface> -int main(int argc, const char* argv[]) { - using openscreen::platform::PlatformClientPosix; - - class PlatformClientExposingTaskRunner : public PlatformClientPosix { - public: - explicit PlatformClientExposingTaskRunner( - std::unique_ptr<TaskRunner> task_runner) - : PlatformClientPosix(Clock::duration{50}, - Clock::duration{50}, - std::move(task_runner)) { - SetInstance(this); + options: + <interface>: Specify the network interface to bind to. The interface is + looked up from the system interface registry. This argument is + mandatory, as it must be known for publishing discovery. + + -t, --tracing: Enable performance tracing logging. + + -h, --help: Show this help message. + )"; + std::string message = kUsageMessage; + message.replace(message.find(kExecutableTag), strlen(kExecutableTag), argv0); + OSP_LOG_INFO << message; +} + +} // namespace + +int main(int argc, char* argv[]) { + // TODO(jophba): refactor into separate method and make main a one-liner. + using openscreen::Clock; + using openscreen::ErrorOr; + using openscreen::InterfaceInfo; + using openscreen::IPAddress; + using openscreen::IPEndpoint; + using openscreen::PlatformClientPosix; + using openscreen::TaskRunnerImpl; + + openscreen::SetLogLevel(openscreen::LogLevel::kInfo); + + const struct option argument_options[] = { + {"tracing", no_argument, nullptr, 't'}, + {"help", no_argument, nullptr, 'h'}, + {nullptr, 0, nullptr, 0}}; + + InterfaceInfo interface_info; + std::unique_ptr<openscreen::TextTraceLoggingPlatform> trace_logger; + int ch = -1; + while ((ch = getopt_long(argc, argv, "th", argument_options, nullptr)) != + -1) { + switch (ch) { + case 't': + trace_logger = std::make_unique<openscreen::TextTraceLoggingPlatform>(); + break; + case 'h': + LogUsage(argv[0]); + return 1; } - }; + } + char* interface_argument = argv[optind]; + OSP_CHECK(interface_argument != nullptr) + << "Missing mandatory argument: interface."; + std::vector<InterfaceInfo> network_interfaces = + openscreen::GetNetworkInterfaces(); + for (auto& interface : network_interfaces) { + if (interface.name == interface_argument) { + interface_info = std::move(interface); + break; + } + } + OSP_CHECK(!interface_info.name.empty()) << "Invalid interface specified."; + + auto* const task_runner = new TaskRunnerImpl(&Clock::now); + PlatformClientPosix::Create(Clock::duration{50}, Clock::duration{50}, + std::unique_ptr<TaskRunnerImpl>(task_runner)); - openscreen::platform::SetLogLevel(openscreen::platform::LogLevel::kInfo); - auto* const platform_client = new PlatformClientExposingTaskRunner( - std::make_unique<TaskRunnerImpl>(&Clock::now)); + auto discovery_state = + openscreen::cast::StartDiscovery(task_runner, interface_info); + OSP_CHECK(discovery_state.is_value()) << "Failed to start discovery."; - cast::streaming::DemoMain(static_cast<TaskRunnerImpl*>( - PlatformClientPosix::GetInstance()->GetTaskRunner())); + // Runs until the process is interrupted. Safe to pass |task_runner| as it + // will not be destroyed by ShutDown() until this exits. + openscreen::cast::RunStandaloneReceiver(task_runner, interface_info); - platform_client->ShutDown(); // Deletes |platform_client|. + // The task runner must be deleted after all serial delete pointers, such + // as the one stored in the discovery state. + discovery_state.value().reset(); + PlatformClientPosix::ShutDown(); return 0; } diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/private.der b/chromium/third_party/openscreen/src/cast/standalone_receiver/private.der Binary files differnew file mode 100644 index 00000000000..48aa17b0151 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/private.der diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/private_key_der.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/private_key_der.h new file mode 100644 index 00000000000..0d9a328f01c --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/private_key_der.h @@ -0,0 +1,125 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STANDALONE_RECEIVER_PRIVATE_KEY_DER_H_ +#define CAST_STANDALONE_RECEIVER_PRIVATE_KEY_DER_H_ + +#include <array> + +namespace openscreen { +namespace cast { + +// Important note about private keys and security: For example usage purposes, +// we have checked in a default private key here. However, in a production +// environment keys should never be checked into source control. This is an +// example self-signed private key for TLS. +// +// Generated using the following command: +// $ xxd -i <path/to/private_key.der> +std::array<uint8_t, 1192> kPrivateKeyDer = { + 0x30, 0x82, 0x04, 0xa4, 0x02, 0x01, 0x00, 0x02, 0x82, 0x01, 0x01, 0x00, + 0xb8, 0x07, 0xbb, 0x3f, 0xde, 0x47, 0xba, 0xec, 0x7a, 0xc2, 0x6e, 0xe5, + 0x44, 0x4a, 0xa8, 0x50, 0xd9, 0xef, 0xc1, 0x67, 0xc1, 0x5e, 0x14, 0xc0, + 0x1a, 0x1f, 0x8e, 0x83, 0xb9, 0xb2, 0x5d, 0x2d, 0x74, 0x12, 0x4b, 0x43, + 0x3c, 0xa6, 0xfb, 0xc4, 0x6c, 0xb0, 0xab, 0xda, 0xfb, 0xfd, 0x51, 0x18, + 0xc1, 0xc1, 0x22, 0x05, 0xf5, 0x2b, 0x10, 0xe4, 0x84, 0x1b, 0xa7, 0xdc, + 0xe6, 0xd0, 0xf5, 0x64, 0xa7, 0xc7, 0x6f, 0x1a, 0x85, 0xf8, 0xc4, 0xfb, + 0xbf, 0x17, 0x54, 0xaa, 0x11, 0x29, 0xfe, 0xda, 0x33, 0x9e, 0xfa, 0xb1, + 0x86, 0x9b, 0xf4, 0xcd, 0xf5, 0xe1, 0x3f, 0x7b, 0x3b, 0xad, 0xd9, 0xe2, + 0xc7, 0x6e, 0x4f, 0x1e, 0xa8, 0x13, 0x22, 0xa2, 0x7a, 0xcf, 0xe1, 0x8a, + 0x06, 0xf3, 0x28, 0x3a, 0xdc, 0xd3, 0x8c, 0x24, 0xa6, 0xe0, 0xd3, 0x5a, + 0x23, 0x21, 0x53, 0x02, 0x7d, 0x08, 0x30, 0xcb, 0xf1, 0x21, 0xca, 0x72, + 0x69, 0x49, 0x6e, 0x0f, 0xbc, 0x03, 0x7e, 0x0e, 0x60, 0x5d, 0x92, 0x08, + 0x3f, 0x04, 0x76, 0x62, 0x2d, 0x4b, 0xeb, 0x61, 0xaa, 0xe6, 0xcd, 0x2f, + 0x28, 0x24, 0xb0, 0xe8, 0xa5, 0xfe, 0x89, 0x90, 0xb8, 0xa4, 0x60, 0x6e, + 0x4c, 0x8c, 0x2f, 0xa3, 0xae, 0x72, 0xf2, 0x42, 0xb9, 0xc9, 0xa2, 0x6f, + 0x91, 0xbc, 0x75, 0x3b, 0x35, 0xb7, 0xe6, 0x24, 0xcb, 0x80, 0x8a, 0x34, + 0xfa, 0x9d, 0xf1, 0x7c, 0x88, 0x98, 0x09, 0x7b, 0x50, 0x56, 0xa5, 0x84, + 0x9c, 0x5f, 0x6c, 0x6e, 0x10, 0xfa, 0x95, 0xb6, 0xbf, 0xe4, 0xb0, 0x55, + 0x29, 0x3d, 0xe6, 0x3d, 0x14, 0xd7, 0x70, 0x17, 0xd8, 0xd3, 0xaa, 0xaf, + 0x4f, 0x15, 0x99, 0x63, 0xd0, 0x74, 0xfc, 0xb0, 0x6b, 0x66, 0x28, 0x02, + 0xd1, 0xbb, 0x01, 0x57, 0x02, 0xfe, 0x52, 0xe2, 0x0b, 0xbd, 0x8c, 0x0a, + 0x87, 0x8b, 0x60, 0xe9, 0x02, 0x03, 0x01, 0x00, 0x01, 0x02, 0x82, 0x01, + 0x01, 0x00, 0xaa, 0x19, 0x7b, 0x5a, 0x6d, 0x7a, 0x9f, 0xac, 0x35, 0x4b, + 0xc2, 0x74, 0xe7, 0xca, 0x9a, 0x09, 0x21, 0x68, 0x1a, 0xbc, 0x6c, 0x5f, + 0x29, 0x8e, 0xe6, 0x96, 0x84, 0x83, 0xfd, 0x00, 0x80, 0x5f, 0xa3, 0x09, + 0xc5, 0xc7, 0x40, 0x28, 0x98, 0x4d, 0xd6, 0xa8, 0xf6, 0x30, 0x52, 0xfa, + 0xb2, 0x1a, 0xcf, 0xfc, 0x54, 0x16, 0x6d, 0xa6, 0x80, 0xd6, 0xb7, 0xc5, + 0x58, 0x43, 0x36, 0x95, 0xae, 0x3c, 0x7b, 0x58, 0x3b, 0xb9, 0xa8, 0x5b, + 0x68, 0xb7, 0xc8, 0xc9, 0x27, 0xd8, 0x8a, 0x44, 0xe6, 0xeb, 0x89, 0x0b, + 0x49, 0x6d, 0x0d, 0x9e, 0xd9, 0x88, 0x05, 0xdd, 0x4d, 0x6f, 0xfa, 0x99, + 0x96, 0xeb, 0xa6, 0xaa, 0xaf, 0x37, 0x06, 0xe3, 0xa8, 0xff, 0xc5, 0xc4, + 0xa0, 0x13, 0x94, 0x98, 0xec, 0x76, 0x7b, 0xe6, 0x8d, 0x82, 0xd3, 0x3c, + 0xbc, 0x1e, 0x74, 0x9a, 0x38, 0xbf, 0xf4, 0x11, 0xbe, 0x07, 0x32, 0x2d, + 0x16, 0x2c, 0xf2, 0x5d, 0x24, 0x38, 0x70, 0xfb, 0x90, 0x8a, 0x38, 0xd6, + 0x17, 0xe1, 0x66, 0x92, 0x38, 0x06, 0x97, 0xb3, 0x07, 0xfd, 0x77, 0xe2, + 0xe7, 0x49, 0xae, 0x5a, 0xbc, 0xe5, 0xa8, 0xca, 0xe1, 0x0f, 0xb6, 0x4c, + 0x05, 0x73, 0x3f, 0x11, 0xd0, 0xf9, 0x1e, 0xba, 0x53, 0x48, 0xf5, 0xaf, + 0x28, 0x5b, 0xea, 0x12, 0x63, 0xbc, 0x84, 0xa7, 0x5f, 0x2e, 0x1d, 0x3e, + 0x02, 0x54, 0x58, 0xed, 0x2b, 0x42, 0xf9, 0xc6, 0x0c, 0xd4, 0x24, 0x77, + 0x1a, 0x2c, 0xbf, 0x75, 0x92, 0xf7, 0xcb, 0xd4, 0x58, 0x2f, 0x88, 0x2d, + 0xe8, 0x16, 0xca, 0xe5, 0x25, 0xe8, 0x5b, 0xbd, 0x53, 0x26, 0x23, 0xe0, + 0xa9, 0x35, 0x4d, 0xdb, 0x51, 0x85, 0x63, 0x20, 0xad, 0x61, 0xd2, 0x6d, + 0xbf, 0x01, 0x7d, 0x04, 0x44, 0x02, 0x96, 0x92, 0x36, 0x19, 0xed, 0xd1, + 0xd8, 0x16, 0x86, 0x06, 0xd4, 0x81, 0x02, 0x81, 0x81, 0x00, 0xe1, 0xa6, + 0xca, 0xb3, 0xef, 0xfe, 0x9f, 0xd6, 0xac, 0x58, 0x5c, 0x17, 0x88, 0xaf, + 0x4d, 0x85, 0x29, 0x50, 0x1f, 0x66, 0x90, 0x9b, 0x81, 0xb6, 0x82, 0x0d, + 0xc3, 0x5a, 0xa8, 0x8a, 0x2b, 0x7f, 0x58, 0x9b, 0x07, 0xe6, 0x64, 0xf7, + 0x1c, 0x77, 0x9d, 0x53, 0x97, 0xa0, 0x33, 0x14, 0x6e, 0x77, 0x1e, 0xe3, + 0x00, 0x0f, 0xb2, 0xb1, 0x69, 0x25, 0x3d, 0x63, 0x3c, 0xe1, 0xbb, 0x41, + 0x74, 0x97, 0x2d, 0x5e, 0x14, 0x79, 0x93, 0x38, 0x15, 0xbe, 0x52, 0x74, + 0x64, 0xc0, 0xfd, 0x22, 0x8e, 0xd7, 0xc9, 0xfb, 0x66, 0x55, 0xce, 0x5b, + 0x6a, 0x6f, 0x00, 0xed, 0x03, 0x7e, 0x4b, 0x9c, 0x4b, 0x8b, 0x3a, 0x50, + 0x65, 0x0d, 0x70, 0x9b, 0xdb, 0xf7, 0x1f, 0xd7, 0x66, 0x7a, 0xd1, 0x1e, + 0xa0, 0x8f, 0xe6, 0x03, 0x12, 0x18, 0x52, 0x25, 0x41, 0xa7, 0xb9, 0x8e, + 0x75, 0x63, 0x11, 0xd2, 0x63, 0xd7, 0x02, 0x81, 0x81, 0x00, 0xd0, 0xc7, + 0xe9, 0x97, 0x38, 0x33, 0x95, 0xbd, 0x18, 0xa5, 0x0a, 0x68, 0xab, 0xba, + 0x5e, 0x3e, 0x1f, 0x16, 0x86, 0xc0, 0x50, 0x09, 0xab, 0x52, 0xb7, 0x62, + 0x4e, 0x34, 0xb1, 0xc1, 0xd3, 0xb5, 0xf4, 0xe0, 0x04, 0x30, 0xa6, 0xdd, + 0x4a, 0xba, 0x7c, 0x59, 0xed, 0xd7, 0x76, 0xd3, 0x02, 0xe7, 0x05, 0x18, + 0x00, 0xdb, 0x65, 0xf2, 0x82, 0xe4, 0xfa, 0xbf, 0x9d, 0xad, 0x1a, 0x56, + 0x7b, 0x5e, 0xef, 0xff, 0x9b, 0xe5, 0x2f, 0x7c, 0xdd, 0x50, 0x53, 0x2b, + 0x6b, 0xc0, 0xac, 0x7b, 0x21, 0x8d, 0xc3, 0x39, 0xfe, 0xd0, 0x1a, 0xed, + 0xd1, 0xb6, 0x56, 0xda, 0x9e, 0x87, 0x9a, 0x6a, 0x69, 0x81, 0x29, 0x81, + 0x75, 0x69, 0xa6, 0x25, 0xc2, 0xf7, 0x5a, 0x94, 0x97, 0x6a, 0x7a, 0xf9, + 0x6c, 0xbe, 0x43, 0x76, 0x34, 0xba, 0x0c, 0x50, 0x6d, 0x22, 0xe8, 0xa6, + 0x9c, 0x80, 0x62, 0x87, 0xc9, 0x3f, 0x02, 0x81, 0x80, 0x78, 0xaf, 0x47, + 0x1c, 0x63, 0x90, 0x30, 0x16, 0x95, 0x88, 0x90, 0x80, 0x79, 0xb7, 0x20, + 0x63, 0xc6, 0xcb, 0xb6, 0x6f, 0x99, 0x89, 0xc2, 0x1f, 0x45, 0x81, 0x6c, + 0xe9, 0x10, 0xd9, 0x0d, 0x18, 0x87, 0xe0, 0x2a, 0xa2, 0x7b, 0x7f, 0x7a, + 0x77, 0x32, 0xea, 0xa1, 0x5e, 0xa9, 0xd3, 0x14, 0x9d, 0x9b, 0x24, 0x57, + 0x45, 0x0e, 0x12, 0x3a, 0xa5, 0x13, 0x26, 0xff, 0x49, 0xcf, 0x67, 0xdb, + 0x9e, 0x7b, 0x42, 0x24, 0xfb, 0x3c, 0xd4, 0xb3, 0x34, 0x5e, 0x4f, 0x28, + 0x0f, 0xdb, 0x92, 0xdf, 0x08, 0xe4, 0x5b, 0x13, 0xc9, 0x72, 0x9b, 0x8b, + 0xda, 0x20, 0x89, 0xa2, 0xe3, 0xaa, 0x36, 0xc6, 0x64, 0x89, 0x64, 0xb4, + 0x17, 0x33, 0x11, 0xf8, 0xdc, 0x3b, 0xe8, 0x6d, 0x43, 0xe4, 0x92, 0x57, + 0xd7, 0x7e, 0x72, 0x47, 0xfc, 0x3f, 0xfa, 0xf3, 0x19, 0x6c, 0x71, 0x97, + 0xb0, 0xcb, 0xb8, 0x55, 0x73, 0x02, 0x81, 0x81, 0x00, 0x98, 0xf1, 0xfa, + 0x73, 0x67, 0x1e, 0x93, 0x11, 0x45, 0xde, 0x91, 0xb3, 0x80, 0x2a, 0x35, + 0x23, 0xf9, 0x0e, 0x3d, 0x84, 0xe0, 0x9d, 0x54, 0xbe, 0x71, 0xcd, 0x38, + 0x51, 0x6d, 0xee, 0xfa, 0x33, 0x0f, 0xc2, 0x94, 0x0f, 0x38, 0x0e, 0x60, + 0xd2, 0x20, 0x8a, 0x98, 0xac, 0x01, 0x46, 0x2f, 0x98, 0x21, 0xa9, 0x25, + 0xe7, 0x93, 0xd5, 0x86, 0x82, 0x4c, 0x16, 0xd7, 0x61, 0x9a, 0x2b, 0xc4, + 0x91, 0x15, 0xec, 0x00, 0xbe, 0x72, 0x7d, 0x5c, 0x7b, 0x9d, 0x91, 0xef, + 0x8b, 0xe4, 0x4f, 0x07, 0x93, 0x9c, 0x72, 0xfd, 0xf2, 0x61, 0xe7, 0xda, + 0x7b, 0x63, 0x41, 0x20, 0x65, 0x62, 0x7f, 0x95, 0xee, 0xa3, 0x03, 0x4d, + 0x8a, 0x29, 0xc6, 0xfb, 0xfe, 0xcc, 0x82, 0x92, 0x31, 0xd5, 0x08, 0xa7, + 0xda, 0xf1, 0xfc, 0xc4, 0x3f, 0x8f, 0x09, 0xd4, 0x09, 0x80, 0xb9, 0x9d, + 0x68, 0x87, 0xc5, 0xc5, 0x6d, 0x02, 0x81, 0x80, 0x1f, 0xd9, 0x20, 0xde, + 0xba, 0xcd, 0x63, 0x34, 0x4f, 0x9f, 0xbb, 0x05, 0x0a, 0x8d, 0x20, 0xe1, + 0x66, 0x41, 0x2f, 0xae, 0xc7, 0xfa, 0x5d, 0xfd, 0xb7, 0x2a, 0x0f, 0xa6, + 0x6d, 0xf3, 0xad, 0x65, 0x54, 0x75, 0x2c, 0x26, 0x1e, 0xac, 0x1f, 0x24, + 0x4c, 0x83, 0xe3, 0x28, 0x08, 0x60, 0x74, 0xfe, 0xa9, 0x53, 0x36, 0x1e, + 0xb3, 0x39, 0x9d, 0xe7, 0x49, 0x03, 0x66, 0x61, 0xe8, 0xd4, 0xf4, 0xd8, + 0x65, 0x57, 0x01, 0xed, 0xaa, 0x7b, 0x6b, 0x04, 0xa2, 0x5f, 0xe1, 0x67, + 0xe6, 0x06, 0x7c, 0x84, 0x2a, 0x7d, 0x53, 0x03, 0x1c, 0x9c, 0x82, 0x08, + 0x37, 0x07, 0xaf, 0x77, 0xe0, 0x99, 0x69, 0xce, 0x01, 0x5a, 0x85, 0x4b, + 0x27, 0xb9, 0xb2, 0x20, 0x8c, 0xa5, 0xb9, 0x42, 0x2f, 0xad, 0x56, 0xdd, + 0xb9, 0x0d, 0x23, 0x05, 0x53, 0x5a, 0x26, 0x3a, 0xe2, 0x17, 0x58, 0x79, + 0x96, 0x8c, 0x5a, 0x05}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STANDALONE_RECEIVER_PRIVATE_KEY_DER_H_ diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.cc index 8a58870eeea..d87bbdf396f 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.cc @@ -11,19 +11,14 @@ #include "cast/standalone_receiver/avcodec_glue.h" #include "util/big_endian.h" #include "util/logging.h" +#include "util/trace_logging.h" using std::chrono::duration_cast; using std::chrono::milliseconds; using std::chrono::seconds; -using openscreen::Error; -using openscreen::ErrorOr; -using openscreen::platform::Clock; -using openscreen::platform::ClockNowFunctionPtr; -using openscreen::platform::TaskRunner; - +namespace openscreen { namespace cast { -namespace streaming { namespace { @@ -44,6 +39,7 @@ void InterleaveAudioSamples(const uint8_t* const planes[], int num_channels, int num_samples, uint8_t* interleaved) { + TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver); // Note: This could be optimized with SIMD intrinsics for much better // performance. auto* dest = reinterpret_cast<Element*>(interleaved); @@ -61,10 +57,12 @@ void InterleaveAudioSamples(const uint8_t* const planes[], SDLAudioPlayer::SDLAudioPlayer(ClockNowFunctionPtr now_function, TaskRunner* task_runner, Receiver* receiver, + const std::string& codec_name, std::function<void()> error_callback) : SDLPlayerBase(now_function, task_runner, receiver, + codec_name, std::move(error_callback), kAudioMediaType) {} @@ -76,6 +74,7 @@ SDLAudioPlayer::~SDLAudioPlayer() { ErrorOr<Clock::time_point> SDLAudioPlayer::RenderNextFrame( const SDLPlayerBase::PresentableFrame& next_frame) { + TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver); OSP_DCHECK(next_frame.decoded_frame); const AVFrame& frame = *next_frame.decoded_frame; @@ -171,6 +170,7 @@ bool SDLAudioPlayer::RenderWhileIdle(const PresentableFrame* frame) { } void SDLAudioPlayer::Present() { + TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver); if (state() != kScheduledToPresent) { // In all other states, just do nothing. The SDL audio buffer will underrun // and result in silence. @@ -207,8 +207,6 @@ void SDLAudioPlayer::Present() { // static SDL_AudioFormat SDLAudioPlayer::GetSDLAudioFormat(AVSampleFormat format) { - using openscreen::IsBigEndianArchitecture; - switch (format) { case AV_SAMPLE_FMT_U8P: case AV_SAMPLE_FMT_U8: @@ -234,5 +232,5 @@ SDL_AudioFormat SDLAudioPlayer::GetSDLAudioFormat(AVSampleFormat format) { return kSDLAudioFormatUnknown; } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.h index d7a4b0ea3e7..8788d1f5c05 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.h +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.h @@ -5,27 +5,31 @@ #ifndef CAST_STANDALONE_RECEIVER_SDL_AUDIO_PLAYER_H_ #define CAST_STANDALONE_RECEIVER_SDL_AUDIO_PLAYER_H_ +#include <string> +#include <vector> + #include "cast/standalone_receiver/sdl_player_base.h" +namespace openscreen { namespace cast { -namespace streaming { // Consumes frames from a Receiver, decodes them, and renders them to an // internally-owned SDL audio device. -class SDLAudioPlayer : public SDLPlayerBase { +class SDLAudioPlayer final : public SDLPlayerBase { public: // |error_callback| is run only if a fatal error occurs, at which point the // player has halted and set |error_status()|. - SDLAudioPlayer(openscreen::platform::ClockNowFunctionPtr now_function, - openscreen::platform::TaskRunner* task_runner, + SDLAudioPlayer(ClockNowFunctionPtr now_function, + TaskRunner* task_runner, Receiver* receiver, + const std::string& codec_name, std::function<void()> error_callback); ~SDLAudioPlayer() final; private: // SDLPlayerBase implementation. - openscreen::ErrorOr<openscreen::platform::Clock::time_point> RenderNextFrame( + ErrorOr<Clock::time_point> RenderNextFrame( const SDLPlayerBase::PresentableFrame& frame) final; bool RenderWhileIdle(const SDLPlayerBase::PresentableFrame* frame) final; void Present() final; @@ -38,7 +42,7 @@ class SDLAudioPlayer : public SDLPlayerBase { // The amount of time before a target presentation time to call Present(), to // account for audio buffering (the latency until samples reach the hardware). - openscreen::platform::Clock::duration approximate_lead_time_{}; + Clock::duration approximate_lead_time_{}; // When the decoder provides planar data, this buffer is used for storing the // interleaved conversion. @@ -54,7 +58,7 @@ class SDLAudioPlayer : public SDLPlayerBase { SDL_AudioSpec device_spec_{}; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STANDALONE_RECEIVER_SDL_AUDIO_PLAYER_H_ diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.cc index 4d12b0579d0..a77c1bf29c9 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.cc @@ -8,18 +8,15 @@ #include "platform/api/time.h" #include "util/logging.h" -using openscreen::platform::Clock; -using openscreen::platform::TaskRunner; - +namespace openscreen { namespace cast { -namespace streaming { SDLEventLoopProcessor::SDLEventLoopProcessor( TaskRunner* task_runner, std::function<void()> quit_callback) : alarm_(&Clock::now, task_runner), quit_callback_(std::move(quit_callback)) { - alarm_.Schedule([this] { ProcessPendingEvents(); }, {}); + alarm_.Schedule([this] { ProcessPendingEvents(); }, Alarm::kImmediately); } SDLEventLoopProcessor::~SDLEventLoopProcessor() = default; @@ -38,9 +35,8 @@ void SDLEventLoopProcessor::ProcessPendingEvents() { // Schedule a task to come back and process more pending events. constexpr auto kEventPollPeriod = std::chrono::milliseconds(10); - alarm_.Schedule([this] { ProcessPendingEvents(); }, - Clock::now() + kEventPollPeriod); + alarm_.ScheduleFromNow([this] { ProcessPendingEvents(); }, kEventPollPeriod); } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.h index 2900bf57557..59a3a02066c 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.h +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.h @@ -18,13 +18,10 @@ #include "util/alarm.h" namespace openscreen { -namespace platform { + class TaskRunner; -} // namespace platform -} // namespace openscreen namespace cast { -namespace streaming { template <uint32_t subsystem> class ScopedSDLSubSystem { @@ -65,18 +62,18 @@ DEFINE_SDL_UNIQUE_PTR(Texture); // event is received. class SDLEventLoopProcessor { public: - SDLEventLoopProcessor(openscreen::platform::TaskRunner* task_runner, + SDLEventLoopProcessor(TaskRunner* task_runner, std::function<void()> quit_callback); ~SDLEventLoopProcessor(); private: void ProcessPendingEvents(); - openscreen::Alarm alarm_; + Alarm alarm_; std::function<void()> quit_callback_; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STANDALONE_RECEIVER_SDL_GLUE_H_ diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.cc index 0c5e6c9c787..2d2f7b61db8 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.cc @@ -12,28 +12,25 @@ #include "cast/streaming/encoded_frame.h" #include "util/big_endian.h" #include "util/logging.h" +#include "util/trace_logging.h" using std::chrono::duration_cast; using std::chrono::milliseconds; -using openscreen::Error; -using openscreen::ErrorOr; -using openscreen::platform::Clock; -using openscreen::platform::ClockNowFunctionPtr; -using openscreen::platform::TaskRunner; - +namespace openscreen { namespace cast { -namespace streaming { SDLPlayerBase::SDLPlayerBase(ClockNowFunctionPtr now_function, TaskRunner* task_runner, Receiver* receiver, + const std::string& codec_name, std::function<void()> error_callback, const char* media_type) : now_(now_function), receiver_(receiver), error_callback_(std::move(error_callback)), media_type_(media_type), + decoder_(codec_name), decode_alarm_(now_, task_runner), render_alarm_(now_, task_runner), presentation_alarm_(now_, task_runner) { @@ -69,35 +66,16 @@ void SDLPlayerBase::OnFatalError(std::string message) { } } -void SDLPlayerBase::OnFramesReady(int buffer_size) { - // Do not consume anything if there are too many frames in the pipeline - // already. - if (static_cast<int>(frames_to_render_.size()) > kMaxFramesInPipeline) { - return; - } - - // Consume the next frame. - const Clock::time_point start_time = now_(); - buffer_.Resize(buffer_size); - EncodedFrame frame = receiver_->ConsumeNextFrame(buffer_.GetSpan()); - - // Create the tracking state for the frame in the player pipeline. - OSP_DCHECK_EQ(frames_to_render_.count(frame.frame_id), 0); - PendingFrame& pending_frame = frames_to_render_[frame.frame_id]; - pending_frame.start_time = start_time; - - // Determine the presentation time of the frame. Ideally, this will occur - // based on the time progression of the media, given by the RTP timestamps. - // However, if this falls too far out-of-sync with the system reference clock, - // re-syrchronize, possibly causing user-visible "jank." +Clock::time_point SDLPlayerBase::ResyncAndDeterminePresentationTime( + const EncodedFrame& frame) { constexpr auto kMaxPlayoutDrift = milliseconds(100); const auto media_time_since_last_sync = (frame.rtp_timestamp - last_sync_rtp_timestamp_) .ToDuration<Clock::duration>(receiver_->rtp_timebase()); - pending_frame.presentation_time = + Clock::time_point presentation_time = last_sync_reference_time_ + media_time_since_last_sync; - const auto drift = duration_cast<milliseconds>( - frame.reference_time - pending_frame.presentation_time); + const auto drift = + duration_cast<milliseconds>(frame.reference_time - presentation_time); if (drift > kMaxPlayoutDrift || drift < -kMaxPlayoutDrift) { // Only log if not the very first frame. OSP_LOG_IF(INFO, frame.frame_id != FrameId::first()) @@ -109,8 +87,30 @@ void SDLPlayerBase::OnFramesReady(int buffer_size) { // back into sync over several frames. last_sync_rtp_timestamp_ = frame.rtp_timestamp; last_sync_reference_time_ = frame.reference_time; - pending_frame.presentation_time = frame.reference_time; + presentation_time = frame.reference_time; } + return presentation_time; +} + +void SDLPlayerBase::OnFramesReady(int buffer_size) { + TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver); + // Do not consume anything if there are too many frames in the pipeline + // already. + if (static_cast<int>(frames_to_render_.size()) > kMaxFramesInPipeline) { + return; + } + + // Consume the next frame. + const Clock::time_point start_time = now_(); + buffer_.Resize(buffer_size); + EncodedFrame frame = receiver_->ConsumeNextFrame(buffer_.GetSpan()); + + // Create the tracking state for the frame in the player pipeline. + OSP_DCHECK_EQ(frames_to_render_.count(frame.frame_id), 0); + PendingFrame& pending_frame = frames_to_render_[frame.frame_id]; + pending_frame.start_time = start_time; + + pending_frame.presentation_time = ResyncAndDeterminePresentationTime(frame); // Start decoding the frame. This call may synchronously call back into the // AVCodecDecoder::Client methods in this class. @@ -118,6 +118,7 @@ void SDLPlayerBase::OnFramesReady(int buffer_size) { } void SDLPlayerBase::OnFrameDecoded(FrameId frame_id, const AVFrame& frame) { + TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver); const auto it = frames_to_render_.find(frame_id); if (it == frames_to_render_.end()) { return; @@ -142,6 +143,7 @@ void SDLPlayerBase::OnDecodeError(FrameId frame_id, std::string message) { } void SDLPlayerBase::RenderAndSchedulePresentation() { + TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver); // If something has already been scheduled to present at an exact time point, // don't render anything new yet. if (state_ == kScheduledToPresent) { @@ -159,12 +161,12 @@ void SDLPlayerBase::RenderAndSchedulePresentation() { // The interval here, is "lengthy" from the program's perspective, but // reasonably "snappy" from the user's perspective. constexpr auto kIdlePresentInterval = milliseconds(250); - presentation_alarm_.Schedule( + presentation_alarm_.ScheduleFromNow( [this] { Present(); ResumeRendering(); }, - now_() + kIdlePresentInterval); + kIdlePresentInterval); } return; } @@ -227,11 +229,12 @@ void SDLPlayerBase::ResumeDecoding() { OnFramesReady(buffer_size); } }, - {}); + Alarm::kImmediately); } void SDLPlayerBase::ResumeRendering() { - render_alarm_.Schedule([this] { RenderAndSchedulePresentation(); }, {}); + render_alarm_.Schedule([this] { RenderAndSchedulePresentation(); }, + Alarm::kImmediately); } // static @@ -250,5 +253,5 @@ SDLPlayerBase::PendingFrame::PendingFrame(PendingFrame&&) noexcept = default; SDLPlayerBase::PendingFrame& SDLPlayerBase::PendingFrame::operator=( PendingFrame&&) noexcept = default; -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.h index 34fb278e258..7338edab1b0 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.h +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.h @@ -18,8 +18,8 @@ #include "platform/api/time.h" #include "platform/base/error.h" +namespace openscreen { namespace cast { -namespace streaming { // Common base class that consumes frames from a Receiver, decodes them, and // plays them out via the appropriate SDL subsystem. Subclasses implement the @@ -29,7 +29,7 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client { ~SDLPlayerBase() override; // Returns OK unless a fatal error has occurred. - const openscreen::Error& error_status() const { return error_status_; } + const Error& error_status() const { return error_status_; } protected: // Current player state, which is used to determine what to render/present, @@ -43,7 +43,7 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client { // A decoded frame and its target presentation time. struct PresentableFrame { - openscreen::platform::Clock::time_point presentation_time; + Clock::time_point presentation_time; AVFrameUniquePtr decoded_frame; PresentableFrame(); @@ -55,9 +55,10 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client { // |error_callback| is run only if a fatal error occurs, at which point the // player has halted and set |error_status()|. |media_type| should be "audio" // or "video" (only used when logging). - SDLPlayerBase(openscreen::platform::ClockNowFunctionPtr now_function, - openscreen::platform::TaskRunner* task_runner, + SDLPlayerBase(ClockNowFunctionPtr now_function, + TaskRunner* task_runner, Receiver* receiver, + const std::string& codec_name, std::function<void()> error_callback, const char* media_type); @@ -68,8 +69,8 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client { void OnFatalError(std::string message) final; // Renders the |frame| and returns its [possibly adjusted] presentation time. - virtual openscreen::ErrorOr<openscreen::platform::Clock::time_point> - RenderNextFrame(const PresentableFrame& frame) = 0; + virtual ErrorOr<Clock::time_point> RenderNextFrame( + const PresentableFrame& frame) = 0; // Called to render when the player has no new content, and returns true if a // Present() is necessary. |frame| may be null, if it is not available. This @@ -82,9 +83,25 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client { virtual void Present() = 0; private: + struct PendingFrame : public PresentableFrame { + Clock::time_point start_time; + + PendingFrame(); + ~PendingFrame(); + PendingFrame(PendingFrame&& other) noexcept; + PendingFrame& operator=(PendingFrame&& other) noexcept; + }; + // Receiver::Consumer implementation. void OnFramesReady(int next_frame_buffer_size) final; + // Determine the presentation time of the frame. Ideally, this will occur + // based on the time progression of the media, given by the RTP timestamps. + // However, if this falls too far out-of-sync with the system reference clock, + // re-synchronize, possibly causing user-visible "jank." + Clock::time_point ResyncAndDeterminePresentationTime( + const EncodedFrame& frame); + // AVCodecDecoder::Client implementation. These are called-back from // |decoder_| to provide results. void OnFrameDecoded(FrameId frame_id, const AVFrame& frame) final; @@ -108,13 +125,13 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client { // require rendering/presenting a different output. void ResumeRendering(); - const openscreen::platform::ClockNowFunctionPtr now_; + const ClockNowFunctionPtr now_; Receiver* const receiver_; std::function<void()> error_callback_; // Run once by OnFatalError(). const char* const media_type_; // For logging only. // Set to the error code that placed the player in a fatal error state. - openscreen::Error error_status_; + Error error_status_; // Current player state, which is used to determine what to render/present, // and how frequently. @@ -122,14 +139,7 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client { // Queue of frames currently being decoded and decoded frames awaiting // rendering. - struct PendingFrame : public PresentableFrame { - openscreen::platform::Clock::time_point start_time; - PendingFrame(); - ~PendingFrame(); - PendingFrame(PendingFrame&& other) noexcept; - PendingFrame& operator=(PendingFrame&& other) noexcept; - }; std::map<FrameId, PendingFrame> frames_to_render_; // Buffer for holding EncodedFrame::data. @@ -139,7 +149,7 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client { // whenever the media (RTP) timestamps drift too much away from the rate at // which the local clock ticks. This is important for A/V synchronization. RtpTimeTicks last_sync_rtp_timestamp_{}; - openscreen::platform::Clock::time_point last_sync_reference_time_{}; + Clock::time_point last_sync_reference_time_{}; Decoder decoder_; @@ -149,13 +159,13 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client { // A cumulative moving average of recent single-frame processing times // (consume + decode + render). This is passed to the Cast Receiver so that it // can determine when to drop late frames. - openscreen::platform::Clock::duration recent_processing_time_{}; + Clock::duration recent_processing_time_{}; // Alarms that execute the various stages of the player pipeline at certain // times. - openscreen::Alarm decode_alarm_; - openscreen::Alarm render_alarm_; - openscreen::Alarm presentation_alarm_; + Alarm decode_alarm_; + Alarm render_alarm_; + Alarm presentation_alarm_; // Maximum number of frames in the decode/render pipeline. This limit is about // making sure the player uses resources efficiently: It is better for frames @@ -164,7 +174,7 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client { static constexpr int kMaxFramesInPipeline = 8; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STANDALONE_RECEIVER_SDL_PLAYER_BASE_H_ diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.cc index 250a76c05fb..f0fff9fadf6 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.cc @@ -8,15 +8,10 @@ #include "cast/standalone_receiver/avcodec_glue.h" #include "util/logging.h" +#include "util/trace_logging.h" -using openscreen::Error; -using openscreen::ErrorOr; -using openscreen::platform::Clock; -using openscreen::platform::ClockNowFunctionPtr; -using openscreen::platform::TaskRunner; - +namespace openscreen { namespace cast { -namespace streaming { namespace { constexpr char kVideoMediaType[] = "video"; @@ -25,11 +20,13 @@ constexpr char kVideoMediaType[] = "video"; SDLVideoPlayer::SDLVideoPlayer(ClockNowFunctionPtr now_function, TaskRunner* task_runner, Receiver* receiver, + const std::string& codec_name, SDL_Renderer* renderer, std::function<void()> error_callback) : SDLPlayerBase(now_function, task_runner, receiver, + codec_name, std::move(error_callback), kVideoMediaType), renderer_(renderer) { @@ -40,6 +37,7 @@ SDLVideoPlayer::~SDLVideoPlayer() = default; bool SDLVideoPlayer::RenderWhileIdle( const SDLPlayerBase::PresentableFrame* frame) { + TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver); // Attempt to re-render the same content. if (state() == kPresented && frame) { const auto result = RenderNextFrame(*frame); @@ -69,6 +67,7 @@ bool SDLVideoPlayer::RenderWhileIdle( ErrorOr<Clock::time_point> SDLVideoPlayer::RenderNextFrame( const SDLPlayerBase::PresentableFrame& frame) { + TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver); OSP_DCHECK(frame.decoded_frame); const AVFrame& picture = *frame.decoded_frame; @@ -163,6 +162,7 @@ ErrorOr<Clock::time_point> SDLVideoPlayer::RenderNextFrame( } void SDLVideoPlayer::Present() { + TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver); SDL_RenderPresent(renderer_); } @@ -201,5 +201,5 @@ uint32_t SDLVideoPlayer::GetSDLPixelFormat(const AVFrame& picture) { return SDL_PIXELFORMAT_UNKNOWN; } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.h index 915881eadc2..24b3496ccc0 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.h +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.h @@ -5,20 +5,23 @@ #ifndef CAST_STANDALONE_RECEIVER_SDL_VIDEO_PLAYER_H_ #define CAST_STANDALONE_RECEIVER_SDL_VIDEO_PLAYER_H_ +#include <string> + #include "cast/standalone_receiver/sdl_player_base.h" +namespace openscreen { namespace cast { -namespace streaming { // Consumes frames from a Receiver, decodes them, and renders them to a // SDL_Renderer. -class SDLVideoPlayer : public SDLPlayerBase { +class SDLVideoPlayer final : public SDLPlayerBase { public: // |error_callback| is run only if a fatal error occurs, at which point the // player has halted and set |error_status()|. - SDLVideoPlayer(openscreen::platform::ClockNowFunctionPtr now_function, - openscreen::platform::TaskRunner* task_runner, + SDLVideoPlayer(ClockNowFunctionPtr now_function, + TaskRunner* task_runner, Receiver* receiver, + const std::string& codec_name, SDL_Renderer* renderer, std::function<void()> error_callback); @@ -32,7 +35,7 @@ class SDLVideoPlayer : public SDLPlayerBase { // Uploads the decoded picture in |frame| to a SDL texture and draws it using // the SDL |renderer_|. - openscreen::ErrorOr<openscreen::platform::Clock::time_point> RenderNextFrame( + ErrorOr<Clock::time_point> RenderNextFrame( const SDLPlayerBase::PresentableFrame& frame) final; // Makes whatever is currently drawn to the SDL |renderer_| be presented @@ -50,7 +53,7 @@ class SDLVideoPlayer : public SDLPlayerBase { SDLTextureUniquePtr texture_; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STANDALONE_RECEIVER_SDL_VIDEO_PLAYER_H_ diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/streaming_playback_controller.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/streaming_playback_controller.cc new file mode 100644 index 00000000000..cf0990cd5ae --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/streaming_playback_controller.cc @@ -0,0 +1,99 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/standalone_receiver/streaming_playback_controller.h" + +#if defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) +#include "cast/standalone_receiver/sdl_audio_player.h" +#include "cast/standalone_receiver/sdl_glue.h" +#include "cast/standalone_receiver/sdl_video_player.h" +#else +#include "cast/standalone_receiver/dummy_player.h" +#endif // defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) + +#include "util/trace_logging.h" + +namespace openscreen { +namespace cast { + +#if defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) +StreamingPlaybackController::StreamingPlaybackController( + TaskRunner* task_runner, + StreamingPlaybackController::Client* client) + : task_runner_(task_runner), + client_(client), + sdl_event_loop_(task_runner_, [this] { + client_->OnPlaybackError(this, + Error{Error::Code::kOperationCancelled, + std::string("SDL event loop closed.")}); + }) { + OSP_DCHECK(task_runner_ != nullptr); + OSP_DCHECK(client_ != nullptr); + constexpr int kDefaultWindowWidth = 1280; + constexpr int kDefaultWindowHeight = 720; + window_ = MakeUniqueSDLWindow( + "Cast Streaming Receiver Demo", + SDL_WINDOWPOS_UNDEFINED /* initial X position */, + SDL_WINDOWPOS_UNDEFINED /* initial Y position */, kDefaultWindowWidth, + kDefaultWindowHeight, SDL_WINDOW_RESIZABLE); + OSP_CHECK(window_) << "Failed to create SDL window: " << SDL_GetError(); + renderer_ = MakeUniqueSDLRenderer(window_.get(), -1, 0); + OSP_CHECK(renderer_) << "Failed to create SDL renderer: " << SDL_GetError(); +} +#else +StreamingPlaybackController::StreamingPlaybackController( + TaskRunner* task_runner, + StreamingPlaybackController::Client* client) + : task_runner_(task_runner), client_(client) { + OSP_DCHECK(task_runner_ != nullptr); + OSP_DCHECK(client_ != nullptr); +} +#endif // defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) + +// TODO(jophba): add async tracing to streaming implementation for exposing +// how long the OFFER/ANSWER and receiver startup takes. +void StreamingPlaybackController::OnNegotiated( + const ReceiverSession* session, + ReceiverSession::ConfiguredReceivers receivers) { + TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver); +#if defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) + if (receivers.audio) { + audio_player_ = std::make_unique<SDLAudioPlayer>( + &Clock::now, task_runner_, receivers.audio->receiver, + receivers.audio->selected_stream.stream.codec_name, [this] { + client_->OnPlaybackError(this, audio_player_->error_status()); + }); + } + if (receivers.video) { + video_player_ = std::make_unique<SDLVideoPlayer>( + &Clock::now, task_runner_, receivers.video->receiver, + receivers.video->selected_stream.stream.codec_name, renderer_.get(), + [this] { + client_->OnPlaybackError(this, video_player_->error_status()); + }); + } +#else + if (receivers.audio) { + audio_player_ = std::make_unique<DummyPlayer>(receivers.audio->receiver); + } + + if (receivers.video) { + video_player_ = std::make_unique<DummyPlayer>(receivers.video->receiver); + } +#endif // defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) +} + +void StreamingPlaybackController::OnConfiguredReceiversDestroyed( + const ReceiverSession* session) { + audio_player_.reset(); + video_player_.reset(); +} + +void StreamingPlaybackController::OnError(const ReceiverSession* session, + Error error) { + client_->OnPlaybackError(this, error); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/streaming_playback_controller.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/streaming_playback_controller.h new file mode 100644 index 00000000000..72f26a78c70 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/streaming_playback_controller.h @@ -0,0 +1,68 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STANDALONE_RECEIVER_STREAMING_PLAYBACK_CONTROLLER_H_ +#define CAST_STANDALONE_RECEIVER_STREAMING_PLAYBACK_CONTROLLER_H_ + +#include <memory> + +#include "cast/streaming/receiver_session.h" +#include "platform/impl/task_runner.h" + +#if defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) +#include "cast/standalone_receiver/sdl_audio_player.h" +#include "cast/standalone_receiver/sdl_glue.h" +#include "cast/standalone_receiver/sdl_video_player.h" +#else +#include "cast/standalone_receiver/dummy_player.h" +#endif // defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) + +namespace openscreen { +namespace cast { + +class StreamingPlaybackController final : public ReceiverSession::Client { + public: + class Client { + public: + virtual void OnPlaybackError(StreamingPlaybackController* controller, + Error error) = 0; + }; + + StreamingPlaybackController(TaskRunner* task_runner, + StreamingPlaybackController::Client* client); + + // ReceiverSession::Client overrides. + void OnNegotiated(const ReceiverSession* session, + ReceiverSession::ConfiguredReceivers receivers) override; + + void OnConfiguredReceiversDestroyed(const ReceiverSession* session) override; + + void OnError(const ReceiverSession* session, Error error) override; + + private: + TaskRunner* const task_runner_; + StreamingPlaybackController::Client* client_; + +#if defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) + // NOTE: member ordering is important, since the sub systems must be + // first-constructed, last-destroyed. Make sure any new SDL related + // members are added below the sub systems. + const ScopedSDLSubSystem<SDL_INIT_AUDIO> sdl_audio_sub_system_; + const ScopedSDLSubSystem<SDL_INIT_VIDEO> sdl_video_sub_system_; + const SDLEventLoopProcessor sdl_event_loop_; + + SDLWindowUniquePtr window_; + SDLRendererUniquePtr renderer_; + std::unique_ptr<SDLAudioPlayer> audio_player_; + std::unique_ptr<SDLVideoPlayer> video_player_; +#else + std::unique_ptr<DummyPlayer> audio_player_; + std::unique_ptr<DummyPlayer> video_player_; +#endif // defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STANDALONE_RECEIVER_STREAMING_PLAYBACK_CONTROLLER_H_ diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/BUILD.gn b/chromium/third_party/openscreen/src/cast/standalone_sender/BUILD.gn new file mode 100644 index 00000000000..85e0f1b047b --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_sender/BUILD.gn @@ -0,0 +1,44 @@ +# Copyright 2020 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +import("//build/config/external_libraries.gni") +import("//build_overrides/build.gni") + +# Define the executable target only when the build is configured to use the +# standalone platform implementation; since this is itself a standalone +# application. +if (!build_with_chromium) { + executable("cast_sender") { + sources = [ + "main.cc", + ] + deps = [ + "../../platform", + "../../util", + "../streaming:sender", + ] + + defines = [] + include_dirs = [] + lib_dirs = [] + libs = [] + if (have_ffmpeg && have_libopus && have_libvpx) { + defines += [ "CAST_STANDALONE_SENDER_HAVE_EXTERNAL_LIBS" ] + sources += [ + "ffmpeg_glue.cc", + "ffmpeg_glue.h", + "simulated_capturer.cc", + "simulated_capturer.h", + "streaming_opus_encoder.cc", + "streaming_opus_encoder.h", + "streaming_vp8_encoder.cc", + "streaming_vp8_encoder.h", + ] + include_dirs += + ffmpeg_include_dirs + libopus_include_dirs + libvpx_include_dirs + lib_dirs += ffmpeg_lib_dirs + libopus_lib_dirs + libvpx_lib_dirs + libs += ffmpeg_libs + libopus_libs + libvpx_libs + } + } +} diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/DEPS b/chromium/third_party/openscreen/src/cast/standalone_sender/DEPS new file mode 100644 index 00000000000..3074fec20e0 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_sender/DEPS @@ -0,0 +1,8 @@ +# Copyright 2020 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +include_rules = [ + '+cast', + '+platform/impl', +] diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/ffmpeg_glue.cc b/chromium/third_party/openscreen/src/cast/standalone_sender/ffmpeg_glue.cc new file mode 100644 index 00000000000..c271f80b188 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_sender/ffmpeg_glue.cc @@ -0,0 +1,32 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/standalone_sender/ffmpeg_glue.h" + +#include "util/logging.h" + +namespace openscreen { +namespace cast { +namespace internal { + +AVFormatContext* CreateAVFormatContextForFile(const char* path) { + AVFormatContext* format_context = nullptr; + int result = avformat_open_input(&format_context, path, nullptr, nullptr); + if (result < 0) { + OSP_LOG_ERROR << "Cannot open " << path << ": " << av_err2str(result); + return nullptr; + } + result = avformat_find_stream_info(format_context, nullptr); + if (result < 0) { + avformat_close_input(&format_context); + OSP_LOG_ERROR << "Cannot find stream info in " << path << ": " + << av_err2str(result); + return nullptr; + } + return format_context; +} + +} // namespace internal +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/ffmpeg_glue.h b/chromium/third_party/openscreen/src/cast/standalone_sender/ffmpeg_glue.h new file mode 100644 index 00000000000..28084c0d7d8 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_sender/ffmpeg_glue.h @@ -0,0 +1,71 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STANDALONE_SENDER_FFMPEG_GLUE_H_ +#define CAST_STANDALONE_SENDER_FFMPEG_GLUE_H_ + +extern "C" { +#include <libavcodec/avcodec.h> +#include <libavformat/avformat.h> +#include <libavutil/channel_layout.h> +#include <libavutil/common.h> +#include <libavutil/imgutils.h> +#include <libavutil/mathematics.h> +#include <libavutil/pixfmt.h> +#include <libavutil/samplefmt.h> +#include <libswresample/swresample.h> +} + +#include <memory> +#include <utility> + +namespace openscreen { +namespace cast { + +namespace internal { + +// Convenience allocator for a new AVFormatContext, given a file |path|. Returns +// nullptr on error. Note: MakeUniqueAVFormatContext() is the public API. +AVFormatContext* CreateAVFormatContextForFile(const char* path); + +} // namespace internal + +// Macro that, for an AVFoo, generates code for: +// +// using FooUniquePtr = std::unique_ptr<Foo, FooFreer>; +// FooUniquePtr MakeUniqueFoo(...args...); +#define DEFINE_AV_UNIQUE_PTR(name, create_func, free_func) \ + namespace internal { \ + struct name##Freer { \ + void operator()(name* obj) const { \ + if (obj) { \ + free_func(&obj); \ + } \ + } \ + }; \ + } \ + \ + using name##UniquePtr = std::unique_ptr<name, internal::name##Freer>; \ + \ + template <typename... Args> \ + name##UniquePtr MakeUnique##name(Args&&... args) { \ + return name##UniquePtr(create_func(std::forward<Args>(args)...)); \ + } + +DEFINE_AV_UNIQUE_PTR(AVFormatContext, + ::openscreen::cast::internal::CreateAVFormatContextForFile, + avformat_close_input); +DEFINE_AV_UNIQUE_PTR(AVCodecContext, + avcodec_alloc_context3, + avcodec_free_context); +DEFINE_AV_UNIQUE_PTR(AVPacket, av_packet_alloc, av_packet_free); +DEFINE_AV_UNIQUE_PTR(AVFrame, av_frame_alloc, av_frame_free); +DEFINE_AV_UNIQUE_PTR(SwrContext, swr_alloc, swr_free); + +#undef DEFINE_AV_UNIQUE_PTR + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STANDALONE_SENDER_FFMPEG_GLUE_H_ diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/main.cc b/chromium/third_party/openscreen/src/cast/standalone_sender/main.cc new file mode 100644 index 00000000000..5d6c1a2c007 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_sender/main.cc @@ -0,0 +1,480 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <getopt.h> + +#include <chrono> // NOLINT +#include <cinttypes> +#include <csignal> +#include <cstdio> +#include <cstring> +#include <sstream> + +#include "cast/streaming/constants.h" +#include "cast/streaming/environment.h" +#include "cast/streaming/sender.h" +#include "cast/streaming/sender_packet_router.h" +#include "cast/streaming/session_config.h" +#include "cast/streaming/ssrc.h" +#include "platform/api/time.h" +#include "platform/base/error.h" +#include "platform/base/ip_address.h" +#include "platform/impl/logging.h" +#include "platform/impl/platform_client_posix.h" +#include "platform/impl/task_runner.h" +#include "platform/impl/text_trace_logging_platform.h" +#include "util/alarm.h" + +#if defined(CAST_STANDALONE_SENDER_HAVE_EXTERNAL_LIBS) +#include "cast/standalone_sender/simulated_capturer.h" +#include "cast/standalone_sender/streaming_opus_encoder.h" +#include "cast/standalone_sender/streaming_vp8_encoder.h" +#endif + +namespace openscreen { +namespace cast { +namespace { + +using std::chrono::duration_cast; +using std::chrono::milliseconds; +using std::chrono::seconds; + +//////////////////////////////////////////////////////////////////////////////// +// Sender Configuration +// +// The values defined here are constants that correspond to the standalone Cast +// Receiver app. In a production environment, these should ABSOLUTELY NOT be +// fixed! Instead a sender↔receiver OFFER/ANSWER exchange should establish them. + +// In a production environment, this would start-out at some initial value +// appropriate to the networking environment, and then be adjusted by the +// application as: 1) the TYPE of the content changes (interactive, low-latency +// versus smooth, higher-latency buffered video watching); and 2) the networking +// environment reliability changes. +constexpr milliseconds kTargetPlayoutDelay = kDefaultTargetPlayoutDelay; + +const SessionConfig kSampleAudioAnswerConfig{ + /* .sender_ssrc = */ 1, + /* .receiver_ssrc = */ 2, + /* .rtp_timebase = */ 48000, + /* .channels = */ 2, + /* .target_playout_delay */ kTargetPlayoutDelay, + /* .aes_secret_key = */ + {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, + 0x0c, 0x0d, 0x0e, 0x0f}, + /* .aes_iv_mask = */ + {0xf0, 0xe0, 0xd0, 0xc0, 0xb0, 0xa0, 0x90, 0x80, 0x70, 0x60, 0x50, 0x40, + 0x30, 0x20, 0x10, 0x00}, +}; + +const SessionConfig kSampleVideoAnswerConfig{ + /* .sender_ssrc = */ 50001, + /* .receiver_ssrc = */ 50002, + /* .rtp_timebase = */ static_cast<int>(kVideoTimebase::den), + /* .channels = */ 1, + /* .target_playout_delay */ kTargetPlayoutDelay, + /* .aes_secret_key = */ + {0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f}, + /* .aes_iv_mask = */ + {0xf1, 0xe1, 0xd1, 0xc1, 0xb1, 0xa1, 0x91, 0x81, 0x71, 0x61, 0x51, 0x41, + 0x31, 0x21, 0x11, 0x01}, +}; + +// End of Sender Configuration. +//////////////////////////////////////////////////////////////////////////////// + +// What is the minimum amount of bandwidth required? +constexpr int kMinRequiredBitrate = 384 << 10; // 384 kbps. + +// What is the default maximum bitrate setting? +constexpr int kDefaultMaxBitrate = 5 << 20; // 5 Mbps. + +#if defined(CAST_STANDALONE_SENDER_HAVE_EXTERNAL_LIBS) + +// Above what available bandwidth should the high-quality audio bitrate be used? +constexpr int kHighBandwidthThreshold = 5 << 20; // 5 Mbps. + +// How often should the file position (media timestamp) be updated on the +// console? +constexpr milliseconds kConsoleUpdateInterval{100}; + +// How often should the congestion control logic re-evaluate the target encode +// bitrates? +constexpr milliseconds kCongestionCheckInterval{500}; + +// Plays the media file at a given path over and over again, transcoding and +// streaming its audio/video. +class LoopingFileSender final : public SimulatedAudioCapturer::Client, + public SimulatedVideoCapturer::Client { + public: + LoopingFileSender(TaskRunner* task_runner, + const char* path, + const IPEndpoint& remote_endpoint, + int max_bitrate, + bool use_android_rtp_hack) + : env_(&Clock::now, task_runner, IPEndpoint{IPAddress(), 0}), + path_(path), + packet_router_(&env_), + max_bitrate_(max_bitrate), + audio_sender_(&env_, + &packet_router_, + kSampleAudioAnswerConfig, + use_android_rtp_hack + ? RtpPayloadType::kAudioHackForAndroidTV + : RtpPayloadType::kAudioOpus), + video_sender_(&env_, + &packet_router_, + kSampleVideoAnswerConfig, + use_android_rtp_hack + ? RtpPayloadType::kVideoHackForAndroidTV + : RtpPayloadType::kVideoVp8), + audio_encoder_(kSampleAudioAnswerConfig.channels, + StreamingOpusEncoder::kDefaultCastAudioFramesPerSecond, + &audio_sender_), + video_encoder_(StreamingVp8Encoder::Parameters{}, + env_.task_runner(), + &video_sender_), + next_task_(env_.now_function(), env_.task_runner()), + console_update_task_(env_.now_function(), env_.task_runner()) { + env_.set_remote_endpoint(remote_endpoint); + OSP_LOG_INFO << "Streaming to " << remote_endpoint << "..."; + + if (use_android_rtp_hack) { + OSP_LOG_INFO << "Using RTP payload types for older Android TV receivers."; + } + + OSP_LOG_INFO << "Max allowed media bitrate (audio + video) will be " + << max_bitrate_; + bandwidth_being_utilized_ = max_bitrate_ / 2; + UpdateEncoderBitrates(); + + next_task_.Schedule([this] { SendFileAgain(); }, Alarm::kImmediately); + } + + ~LoopingFileSender() final = default; + + private: + void UpdateEncoderBitrates() { + if (bandwidth_being_utilized_ >= kHighBandwidthThreshold) { + audio_encoder_.UseHighQuality(); + } else { + audio_encoder_.UseStandardQuality(); + } + video_encoder_.SetTargetBitrate(bandwidth_being_utilized_ - + audio_encoder_.GetBitrate()); + } + + void ControlForNetworkCongestion() { + bandwidth_estimate_ = packet_router_.ComputeNetworkBandwidth(); + if (bandwidth_estimate_ > 0) { + // Don't ever try to use *all* of the network bandwidth! However, don't go + // below the absolute minimum requirement either. + constexpr double kGoodNetworkCitizenFactor = 0.8; + const int usable_bandwidth = std::max<int>( + kGoodNetworkCitizenFactor * bandwidth_estimate_, kMinRequiredBitrate); + + // See "congestion control" discussion in the class header comments for + // BandwidthEstimator. + if (usable_bandwidth > bandwidth_being_utilized_) { + constexpr double kConservativeIncrease = 1.1; + bandwidth_being_utilized_ = + std::min<int>(bandwidth_being_utilized_ * kConservativeIncrease, + usable_bandwidth); + } else { + bandwidth_being_utilized_ = usable_bandwidth; + } + + // Repsect the user's maximum bitrate setting. + bandwidth_being_utilized_ = + std::min(bandwidth_being_utilized_, max_bitrate_); + + UpdateEncoderBitrates(); + } else { + // There is no current bandwidth estimate. So, nothing should be adjusted. + } + + next_task_.ScheduleFromNow([this] { ControlForNetworkCongestion(); }, + kCongestionCheckInterval); + } + + void SendFileAgain() { + OSP_LOG_INFO << "Sending " << path_ << " (starts in one second)..."; + + OSP_DCHECK_EQ(num_capturers_running_, 0); + num_capturers_running_ = 2; + capture_start_time_ = latest_frame_time_ = env_.now() + seconds(1); + audio_capturer_.emplace(&env_, path_, audio_encoder_.num_channels(), + audio_encoder_.sample_rate(), capture_start_time_, + this); + video_capturer_.emplace(&env_, path_, capture_start_time_, this); + + next_task_.ScheduleFromNow([this] { ControlForNetworkCongestion(); }, + kCongestionCheckInterval); + console_update_task_.Schedule([this] { UpdateStatusOnConsole(); }, + capture_start_time_); + } + + void OnAudioData(const float* interleaved_samples, + int num_samples, + Clock::time_point capture_time) final { + latest_frame_time_ = std::max(capture_time, latest_frame_time_); + audio_encoder_.EncodeAndSend(interleaved_samples, num_samples, + capture_time); + } + + void OnVideoFrame(const AVFrame& av_frame, + Clock::time_point capture_time) final { + latest_frame_time_ = std::max(capture_time, latest_frame_time_); + StreamingVp8Encoder::VideoFrame frame{}; + frame.width = av_frame.width - av_frame.crop_left - av_frame.crop_right; + frame.height = av_frame.height - av_frame.crop_top - av_frame.crop_bottom; + frame.yuv_planes[0] = av_frame.data[0] + av_frame.crop_left + + av_frame.linesize[0] * av_frame.crop_top; + frame.yuv_planes[1] = av_frame.data[1] + av_frame.crop_left / 2 + + av_frame.linesize[1] * av_frame.crop_top / 2; + frame.yuv_planes[2] = av_frame.data[2] + av_frame.crop_left / 2 + + av_frame.linesize[2] * av_frame.crop_top / 2; + for (int i = 0; i < 3; ++i) { + frame.yuv_strides[i] = av_frame.linesize[i]; + } + // TODO(miu): Add performance metrics visual overlay (based on Stats + // callback). + video_encoder_.EncodeAndSend(frame, capture_time, {}); + } + + void UpdateStatusOnConsole() { + const Clock::duration elapsed = latest_frame_time_ - capture_start_time_; + const auto seconds_part = duration_cast<seconds>(elapsed); + const auto millis_part = + duration_cast<milliseconds>(elapsed - seconds_part); + // The control codes here attempt to erase the current line the cursor is + // on, and then print out the updated status text. If the terminal does not + // support simple ANSI escape codes, the following will still work, but + // there might sometimes be old status lines not getting erased (i.e., just + // partially overwritten). + fprintf(stdout, + "\r\x1b[2K\rAt %01" PRId64 + ".%03ds in file (est. network bandwidth: %d kbps). ", + static_cast<int64_t>(seconds_part.count()), + static_cast<int>(millis_part.count()), bandwidth_estimate_ / 1024); + fflush(stdout); + + console_update_task_.ScheduleFromNow([this] { UpdateStatusOnConsole(); }, + kConsoleUpdateInterval); + } + + void OnEndOfFile(SimulatedCapturer* capturer) final { + OSP_LOG_INFO << "The " << ToTrackName(capturer) + << " capturer has reached the end of the media stream."; + --num_capturers_running_; + if (num_capturers_running_ == 0) { + console_update_task_.Cancel(); + next_task_.Schedule([this] { SendFileAgain(); }, Alarm::kImmediately); + } + } + + void OnError(SimulatedCapturer* capturer, std::string message) final { + OSP_LOG_ERROR << "The " << ToTrackName(capturer) + << " has failed: " << message; + --num_capturers_running_; + // If both fail, the application just pauses. This accounts for things like + // "file not found" errors. However, if only one track fails, then keep + // going. + } + + const char* ToTrackName(SimulatedCapturer* capturer) const { + const char* which; + if (capturer == &*audio_capturer_) { + which = "audio"; + } else if (capturer == &*video_capturer_) { + which = "video"; + } else { + OSP_NOTREACHED(); + which = ""; + } + return which; + } + + // Holds the required injected dependencies (clock, task runner) used for Cast + // Streaming, and owns the UDP socket over which all communications occur with + // the remote's Receivers. + Environment env_; + + // The path to the media file to stream over and over. + const char* const path_; + + // The packet router allows both the Audio Sender and the Video Sender to + // share the same UDP socket. + SenderPacketRouter packet_router_; + + const int max_bitrate_; // Passed by the user on the command line. + int bandwidth_estimate_ = 0; + int bandwidth_being_utilized_; + + Sender audio_sender_; + Sender video_sender_; + + StreamingOpusEncoder audio_encoder_; + StreamingVp8Encoder video_encoder_; + + int num_capturers_running_ = 0; + Clock::time_point capture_start_time_{}; + Clock::time_point latest_frame_time_{}; + absl::optional<SimulatedAudioCapturer> audio_capturer_; + absl::optional<SimulatedVideoCapturer> video_capturer_; + + Alarm next_task_; + Alarm console_update_task_; +}; + +#endif // defined(CAST_STANDALONE_SENDER_HAVE_EXTERNAL_LIBS) + +IPEndpoint GetDefaultEndpoint() { + return IPEndpoint{IPAddress::kV4LoopbackAddress, kDefaultCastStreamingPort}; +} + +void LogUsage(const char* argv0) { + const char kUsageMessageFormat[] = R"( + usage: %s <options> <media_file> + + --remote=addr[:port] + Specify the destination (e.g., 192.168.1.22:9999 or [::1]:12345). + + Default if not set: %s + + --max-bitrate=N + Specifies the maximum bits per second for the media streams. + + Default if not set: %d. + + --android-hack: + Use the wrong RTP payload types, for compatibility with older Android + TV receivers. + + --tracing: Enable performance tracing logging. + )"; + // TODO(https://crbug.com/openscreen/122): absl::StreamFormat() would be much + // cleaner here. For example, all the code here could be replaced with: + // + // OSP_LOG_ERROR << absl::StreamFormat(kUsageMessageFormat, argv0, + // absl::FromatStreamed(endpoint), + // kDefaultMaxBitrate); + std::string endpoint; + { + std::ostringstream oss; + oss << GetDefaultEndpoint(); + endpoint = oss.str(); + } + const int formatted_length_with_nul = + snprintf(nullptr, 0, kUsageMessageFormat, argv0, endpoint.c_str(), + kDefaultMaxBitrate) + + 1; + const std::unique_ptr<char[]> usage_cstr(new char[formatted_length_with_nul]); + snprintf(usage_cstr.get(), formatted_length_with_nul, kUsageMessageFormat, + argv0, endpoint.c_str(), kDefaultMaxBitrate); + OSP_LOG_ERROR << usage_cstr.get(); +} + +int StandaloneSenderMain(int argc, char* argv[]) { + SetLogLevel(LogLevel::kInfo); + + const struct option argument_options[] = { + {"remote", required_argument, nullptr, 'r'}, + {"max-bitrate", required_argument, nullptr, 'm'}, + {"android-hack", no_argument, nullptr, 'a'}, + {"tracing", no_argument, nullptr, 't'}, + {"help", no_argument, nullptr, 'h'}, + {nullptr, 0, nullptr, 0}}; + + IPEndpoint remote_endpoint = GetDefaultEndpoint(); + [[maybe_unused]] bool use_android_rtp_hack = false; + [[maybe_unused]] int max_bitrate = kDefaultMaxBitrate; + std::unique_ptr<TextTraceLoggingPlatform> trace_logger; + int ch = -1; + while ((ch = getopt_long(argc, argv, "r:ath", argument_options, nullptr)) != + -1) { + switch (ch) { + case 'r': { + const ErrorOr<IPEndpoint> parsed_endpoint = IPEndpoint::Parse(optarg); + if (parsed_endpoint.is_value()) { + remote_endpoint = parsed_endpoint.value(); + } else { + const ErrorOr<IPAddress> parsed_address = IPAddress::Parse(optarg); + if (parsed_address.is_value()) { + remote_endpoint.address = parsed_address.value(); + } else { + OSP_LOG_ERROR << "Invalid --remote specified: " << optarg; + LogUsage(argv[0]); + return 1; + } + } + break; + } + case 'm': + max_bitrate = atoi(optarg); + if (max_bitrate < kMinRequiredBitrate) { + OSP_LOG_ERROR << "Invalid --max-bitrate specified: " << optarg + << " is less than " << kMinRequiredBitrate; + LogUsage(argv[0]); + return 1; + } + break; + case 'a': + use_android_rtp_hack = true; + break; + case 't': + trace_logger = std::make_unique<TextTraceLoggingPlatform>(); + break; + case 'h': + LogUsage(argv[0]); + return 1; + } + } + + // The last command line argument must be the path to the file. + const char* path = nullptr; + if (optind == (argc - 1)) { + path = argv[optind]; + } + + if (!path || !remote_endpoint.port) { + LogUsage(argv[0]); + return 1; + } + +#if defined(CAST_STANDALONE_SENDER_HAVE_EXTERNAL_LIBS) + + auto* const task_runner = new TaskRunnerImpl(&Clock::now); + PlatformClientPosix::Create(Clock::duration{50}, Clock::duration{50}, + std::unique_ptr<TaskRunnerImpl>(task_runner)); + + { + LoopingFileSender file_sender(task_runner, path, remote_endpoint, + max_bitrate, use_android_rtp_hack); + // Run the event loop until SIGINT (e.g., CTRL-C at the console) or SIGTERM + // are signaled. + task_runner->RunUntilSignaled(); + } + + PlatformClientPosix::ShutDown(); + +#else + + OSP_LOG_INFO + << "It compiled! However, you need to configure the build to point to " + "external libraries in order to build a useful app."; + +#endif // defined(CAST_STANDALONE_SENDER_HAVE_EXTERNAL_LIBS) + + return 0; +} + +} // namespace +} // namespace cast +} // namespace openscreen + +int main(int argc, char* argv[]) { + return openscreen::cast::StandaloneSenderMain(argc, argv); +} diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.cc b/chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.cc new file mode 100644 index 00000000000..327c76f8376 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.cc @@ -0,0 +1,378 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/standalone_sender/simulated_capturer.h" + +#include <algorithm> +#include <chrono> // NOLINT +#include <ratio> // NOLINT +#include <sstream> +#include <thread> // NOLINT + +#include "cast/streaming/environment.h" +#include "util/logging.h" + +namespace openscreen { +namespace cast { + +using openscreen::operator<<; // To pretty-print chrono values. + +namespace { +// Threshold at which a warning about media pausing should be logged. +constexpr std::chrono::seconds kPauseWarningThreshold{3}; +} // namespace + +SimulatedCapturer::Observer::~Observer() = default; + +SimulatedCapturer::SimulatedCapturer(Environment* environment, + const char* path, + AVMediaType media_type, + Clock::time_point start_time, + Observer* observer) + : format_context_(MakeUniqueAVFormatContext(path)), + media_type_(media_type), + start_time_(start_time), + observer_(observer), + packet_(MakeUniqueAVPacket()), + decoded_frame_(MakeUniqueAVFrame()), + next_task_(environment->now_function(), environment->task_runner()) { + OSP_DCHECK(observer_); + + if (!format_context_) { + OnError("MakeUniqueAVFormatContext", AVERROR_UNKNOWN); + return; // Capturer is halted (unable to start). + } + + AVCodec* codec; + const int stream_result = av_find_best_stream(format_context_.get(), + media_type_, -1, -1, &codec, 0); + if (stream_result < 0) { + OnError("av_find_best_stream", stream_result); + return; // Capturer is halted (unable to start). + } + + stream_index_ = stream_result; + decoder_context_ = MakeUniqueAVCodecContext(codec); + if (!decoder_context_) { + OnError("MakeUniqueAVCodecContext", AVERROR_BUG); + return; // Capturer is halted (unable to start). + } + // This should also be 16 or less, since the encoder implementations emit + // warnings about too many encode threads. FFMPEG's VP8 implementation + // actually silently freezes if this is 10 or more. Thus, 8 is used for the + // max here, just to be safe. + decoder_context_->thread_count = + std::min(std::max<int>(std::thread::hardware_concurrency(), 1), 8); + const int params_result = avcodec_parameters_to_context( + decoder_context_.get(), + format_context_->streams[stream_index_]->codecpar); + if (params_result < 0) { + OnError("avcodec_parameters_to_context", params_result); + return; // Capturer is halted (unable to start). + } + SetAdditionalDecoderParameters(decoder_context_.get()); + + const int open_result = avcodec_open2(decoder_context_.get(), codec, nullptr); + if (open_result < 0) { + OnError("avcodec_open2", open_result); + return; // Capturer is halted (unable to start). + } + + next_task_.Schedule([this] { StartDecodingNextFrame(); }, + Alarm::kImmediately); +} + +SimulatedCapturer::~SimulatedCapturer() = default; + +void SimulatedCapturer::SetAdditionalDecoderParameters( + AVCodecContext* decoder_context) {} + +absl::optional<Clock::duration> SimulatedCapturer::ProcessDecodedFrame( + const AVFrame& frame) { + return Clock::duration::zero(); +} + +void SimulatedCapturer::OnError(const char* function_name, int av_errnum) { + // Make a human-readable string from the libavcodec error. + std::ostringstream error; + error << "For " << av_get_media_type_string(media_type_) << ", " + << function_name << " returned error: " << av_err2str(av_errnum); + + // Deliver the error notification in a separate task since this method might + // have been called from the constructor. + next_task_.Schedule( + [this, error_string = error.str()] { + observer_->OnError(this, error_string); + // Capturer is now halted. + }, + Alarm::kImmediately); +} + +// static +Clock::duration SimulatedCapturer::ToApproximateClockDuration( + int64_t ticks, + const AVRational& time_base) { + return Clock::duration(av_rescale_q( + ticks, time_base, + AVRational{Clock::duration::period::num, Clock::duration::period::den})); +} + +void SimulatedCapturer::StartDecodingNextFrame() { + const int read_frame_result = + av_read_frame(format_context_.get(), packet_.get()); + if (read_frame_result < 0) { + if (read_frame_result == AVERROR_EOF) { + // Insert a "flush request" into the decoder's pipeline, which will + // signal an EOF in ConsumeNextDecodedFrame() later. + avcodec_send_packet(decoder_context_.get(), nullptr); + next_task_.Schedule([this] { ConsumeNextDecodedFrame(); }, + Alarm::kImmediately); + } else { + // All other error codes are fatal. + OnError("av_read_frame", read_frame_result); + // Capturer is now halted. + } + return; + } + + if (packet_->stream_index != stream_index_) { + av_packet_unref(packet_.get()); + next_task_.Schedule([this] { StartDecodingNextFrame(); }, + Alarm::kImmediately); + return; + } + + const int send_packet_result = + avcodec_send_packet(decoder_context_.get(), packet_.get()); + av_packet_unref(packet_.get()); + if (send_packet_result < 0) { + // Note: AVERROR(EAGAIN) is also treated as fatal here because + // avcodec_receive_frame() will be called repeatedly until its result code + // indicates avcodec_send_packet() must be called again. + OnError("avcodec_send_packet", send_packet_result); + return; // Capturer is now halted. + } + + next_task_.Schedule([this] { ConsumeNextDecodedFrame(); }, + Alarm::kImmediately); +} + +void SimulatedCapturer::ConsumeNextDecodedFrame() { + const int receive_frame_result = + avcodec_receive_frame(decoder_context_.get(), decoded_frame_.get()); + if (receive_frame_result < 0) { + switch (receive_frame_result) { + case AVERROR(EAGAIN): + // This result code, according to libavcodec documentation, means more + // data should be fed into the decoder (e.g., interframe dependencies). + next_task_.Schedule([this] { StartDecodingNextFrame(); }, + Alarm::kImmediately); + return; + case AVERROR_EOF: + observer_->OnEndOfFile(this); + return; // Capturer is now halted. + default: + OnError("avcodec_receive_frame", receive_frame_result); + return; // Capturer is now halted. + } + } + + const Clock::duration frame_timestamp = ToApproximateClockDuration( + decoded_frame_->best_effort_timestamp, + format_context_->streams[stream_index_]->time_base); + if (last_frame_timestamp_) { + const Clock::duration delta = frame_timestamp - *last_frame_timestamp_; + if (delta <= Clock::duration::zero()) { + OSP_LOG_WARN << "Dropping " << av_get_media_type_string(media_type_) + << " frame with illegal timestamp (delta from last frame: " + << delta << "). Bad media file!"; + av_frame_unref(decoded_frame_.get()); + next_task_.Schedule([this] { ConsumeNextDecodedFrame(); }, + Alarm::kImmediately); + return; + } else if (delta >= kPauseWarningThreshold) { + OSP_LOG_INFO << "For " << av_get_media_type_string(media_type_) + << ", encountered a media pause (" << delta + << ") in the file."; + } + } + last_frame_timestamp_ = frame_timestamp; + + Clock::time_point capture_time = start_time_ + frame_timestamp; + const auto delay_adjustment_or_null = ProcessDecodedFrame(*decoded_frame_); + if (!delay_adjustment_or_null) { + av_frame_unref(decoded_frame_.get()); + return; // Stop. Fatal error occurred. + } + capture_time += *delay_adjustment_or_null; + + next_task_.Schedule( + [this, capture_time] { + DeliverDataToClient(*decoded_frame_, capture_time); + av_frame_unref(decoded_frame_.get()); + ConsumeNextDecodedFrame(); + }, + capture_time); +} + +SimulatedAudioCapturer::Client::~Client() = default; + +SimulatedAudioCapturer::SimulatedAudioCapturer(Environment* environment, + const char* path, + int num_channels, + int sample_rate, + Clock::time_point start_time, + Client* client) + : SimulatedCapturer(environment, + path, + AVMEDIA_TYPE_AUDIO, + start_time, + client), + num_channels_(num_channels), + sample_rate_(sample_rate), + client_(client), + resampler_(MakeUniqueSwrContext()) { + OSP_DCHECK_GT(num_channels_, 0); + OSP_DCHECK_GT(sample_rate_, 0); +} + +SimulatedAudioCapturer::~SimulatedAudioCapturer() { + if (swr_is_initialized(resampler_.get())) { + swr_close(resampler_.get()); + } +} + +bool SimulatedAudioCapturer::EnsureResamplerIsInitializedFor( + const AVFrame& frame) { + if (swr_is_initialized(resampler_.get())) { + if (input_sample_format_ == static_cast<AVSampleFormat>(frame.format) && + input_sample_rate_ == frame.sample_rate && + input_channel_layout_ == frame.channel_layout) { + return true; + } + + // Note: Usually, the resampler should be flushed before being destroyed. + // However, because of the way SimulatedAudioCapturer uses the API, only one + // audio sample should be dropped in the worst case. Log what's being + // dropped, just in case libswresample is behaving differently than + // expected. + const std::chrono::microseconds amount( + swr_get_delay(resampler_.get(), std::micro::den)); + OSP_LOG_INFO << "Discarding " << amount + << " of audio from the resampler before re-init."; + } + + input_sample_format_ = AV_SAMPLE_FMT_NONE; + + // Create a fake output frame to hold the output audio parameters, because the + // resampler API is weird that way. + const auto fake_output_frame = MakeUniqueAVFrame(); + fake_output_frame->channel_layout = + av_get_default_channel_layout(num_channels_); + fake_output_frame->format = AV_SAMPLE_FMT_FLT; + fake_output_frame->sample_rate = sample_rate_; + const int config_result = + swr_config_frame(resampler_.get(), fake_output_frame.get(), &frame); + if (config_result < 0) { + OnError("swr_config_frame", config_result); + return false; // Capturer is now halted. + } + + const int init_result = swr_init(resampler_.get()); + if (init_result < 0) { + OnError("swr_init", init_result); + return false; // Capturer is now halted. + } + + input_sample_format_ = static_cast<AVSampleFormat>(frame.format); + input_sample_rate_ = frame.sample_rate; + input_channel_layout_ = frame.channel_layout; + return true; +} + +absl::optional<Clock::duration> SimulatedAudioCapturer::ProcessDecodedFrame( + const AVFrame& frame) { + if (!EnsureResamplerIsInitializedFor(frame)) { + return absl::nullopt; + } + + const int64_t num_leftover_input_samples = + swr_get_delay(resampler_.get(), input_sample_rate_); + OSP_DCHECK_GE(num_leftover_input_samples, 0); + const Clock::duration capture_time_adjustment = -ToApproximateClockDuration( + num_leftover_input_samples, AVRational{1, input_sample_rate_}); + + const int64_t num_output_samples_desired = + av_rescale_rnd(num_leftover_input_samples + frame.nb_samples, + sample_rate_, input_sample_rate_, AV_ROUND_ZERO); + OSP_DCHECK_GE(num_output_samples_desired, 0); + resampled_audio_.resize(num_channels_ * num_output_samples_desired); + uint8_t* output_argument[1] = { + reinterpret_cast<uint8_t*>(resampled_audio_.data())}; + const int num_samples_converted_or_error = swr_convert( + resampler_.get(), output_argument, num_output_samples_desired, + const_cast<const uint8_t**>(frame.extended_data), frame.nb_samples); + if (num_samples_converted_or_error < 0) { + resampled_audio_.clear(); + swr_close(resampler_.get()); + OnError("swr_convert", num_samples_converted_or_error); + return absl::nullopt; // Capturer is now halted. + } + resampled_audio_.resize(num_channels_ * num_samples_converted_or_error); + + return capture_time_adjustment; +} + +void SimulatedAudioCapturer::DeliverDataToClient( + const AVFrame& unused, + Clock::time_point capture_time) { + if (resampled_audio_.empty()) { + return; + } + client_->OnAudioData(resampled_audio_.data(), + resampled_audio_.size() / num_channels_, capture_time); + resampled_audio_.clear(); +} + +SimulatedVideoCapturer::Client::~Client() = default; + +SimulatedVideoCapturer::SimulatedVideoCapturer(Environment* environment, + const char* path, + Clock::time_point start_time, + Client* client) + : SimulatedCapturer(environment, + path, + AVMEDIA_TYPE_VIDEO, + start_time, + client), + client_(client) {} + +SimulatedVideoCapturer::~SimulatedVideoCapturer() = default; + +void SimulatedVideoCapturer::SetAdditionalDecoderParameters( + AVCodecContext* decoder_context) { + // Require the I420 planar format for video. + decoder_context->get_format = [](struct AVCodecContext* s, + const enum AVPixelFormat* formats) { + // Return AV_PIX_FMT_YUV420P if it's in the provided list of supported + // formats. Otherwise, return AV_PIX_FMT_NONE. + // + // |formats| is a NONE-terminated array. + for (; *formats != AV_PIX_FMT_NONE; ++formats) { + if (*formats == AV_PIX_FMT_YUV420P) { + break; + } + } + return *formats; + }; +} + +void SimulatedVideoCapturer::DeliverDataToClient( + const AVFrame& frame, + Clock::time_point capture_time) { + client_->OnVideoFrame(frame, capture_time); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.h b/chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.h new file mode 100644 index 00000000000..8d32085aa73 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.h @@ -0,0 +1,205 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STANDALONE_SENDER_SIMULATED_CAPTURER_H_ +#define CAST_STANDALONE_SENDER_SIMULATED_CAPTURER_H_ + +#include <stdint.h> + +#include <string> +#include <vector> + +#include "absl/types/optional.h" +#include "cast/standalone_sender/ffmpeg_glue.h" +#include "platform/api/time.h" +#include "util/alarm.h" + +namespace openscreen { +namespace cast { + +class Environment; + +// Simulates live media capture by demuxing, decoding, and emitting a stream of +// frames from a file at normal (1X) speed. This is a base class containing +// common functionality. Typical usage: Instantiate one SimulatedAudioCapturer +// and one FileVideoStreamCapturer. +class SimulatedCapturer { + public: + // Interface for receiving end-of-stream and fatal error notifications. + class Observer { + public: + // Called once the end of the file has been reached and the |capturer| has + // halted. + virtual void OnEndOfFile(SimulatedCapturer* capturer) = 0; + + // Called if a non-recoverable error occurs and the |capturer| has halted. + virtual void OnError(SimulatedCapturer* capturer, std::string message) = 0; + + protected: + virtual ~Observer(); + }; + + protected: + SimulatedCapturer(Environment* environment, + const char* path, + AVMediaType media_type, + Clock::time_point start_time, + Observer* observer); + + virtual ~SimulatedCapturer(); + + // Optionally overridden, to apply additional decoder context settings before + // avcodec_open2() is called. + virtual void SetAdditionalDecoderParameters(AVCodecContext* decoder_context); + + // Performs any additional processing on the decoded frame (e.g., audio + // resampling), and returns any adjustments to the frame's capture time (e.g., + // to account for any buffering). If a fatal error occurs, absl::nullopt is + // returned. The default implementation does nothing. + // + // Mutating the |decoded_frame| is not allowed. If a subclass implementation + // wants to deliver different data (e.g., resampled audio), it must stash the + // data itself for the next DeliverDataToClient() call. + virtual absl::optional<Clock::duration> ProcessDecodedFrame( + const AVFrame& decoded_frame); + + // Delivers the decoded frame data to the client. + virtual void DeliverDataToClient(const AVFrame& decoded_frame, + Clock::time_point capture_time) = 0; + + // Called when any transient or fatal error occurs, generating an Error and + // scheduling a task to notify the Observer of it soon. + void OnError(const char* what, int av_errnum); + + // Converts the given FFMPEG tick count into an approximate Clock::duration. + static Clock::duration ToApproximateClockDuration( + int64_t ticks, + const AVRational& time_base); + + private: + // Reads the next frame from the file, sends it to the decoder, and schedules + // a future ConsumeNextDecodedFrame() call to continue processing. + void StartDecodingNextFrame(); + + // Receives the next decoded frame and schedules media delivery to the client, + // and/or calls Observer::OnEndOfFile() if there are no more frames in the + // file. + void ConsumeNextDecodedFrame(); + + const AVFormatContextUniquePtr format_context_; + const AVMediaType media_type_; // Audio or Video. + const Clock::time_point start_time_; + Observer* const observer_; + const AVPacketUniquePtr packet_; // Decoder input buffer. + const AVFrameUniquePtr decoded_frame_; // Decoder output frame. + int stream_index_ = -1; // Selected stream from the file. + AVCodecContextUniquePtr decoder_context_; + + // The last frame's stream timestamp. This is used to detect bad stream + // timestamps in the file. + absl::optional<Clock::duration> last_frame_timestamp_; + + // Used to schedule the next task to execute and when it should execute. There + // is only ever one task scheduled/running at any time. + Alarm next_task_; +}; + +// Emits the primary audio stream from a file. +class SimulatedAudioCapturer final : public SimulatedCapturer { + public: + class Client : public SimulatedCapturer::Observer { + public: + // Called to deliver more audio data as |interleaved_samples|, which + // contains |num_samples| tuples (i.e., multiply by the number of channels + // to determine the number of array elements). |capture_time| is used to + // synchronize the play-out of the first audio sample with respect to video + // frames. + virtual void OnAudioData(const float* interleaved_samples, + int num_samples, + Clock::time_point capture_time) = 0; + + protected: + ~Client() override; + }; + + // Constructor: |num_channels| and |sample_rate| specify the required audio + // format. If necessary, audio from the file will be resampled to match the + // required format. + SimulatedAudioCapturer(Environment* environment, + const char* path, + int num_channels, + int sample_rate, + Clock::time_point start_time, + Client* client); + + ~SimulatedAudioCapturer() final; + + private: + // Examines the audio format of the given |frame|, and ensures the resampler + // is initialized to take that as input. + bool EnsureResamplerIsInitializedFor(const AVFrame& frame); + + // Resamples the current |SimulatedCapturer::decoded_frame()| into the + // required output format/channels/rate. The result is stored in + // |resampled_audio_| for the next DeliverDataToClient() call. + absl::optional<Clock::duration> ProcessDecodedFrame( + const AVFrame& decoded_frame) final; + + // Called at the moment Client::OnAudioData() should be called to pass the + // |resampled_audio_|. + void DeliverDataToClient(const AVFrame& decoded_frame, + Clock::time_point capture_time) final; + + const int num_channels_; // Output number of channels. + const int sample_rate_; // Output sample rate. + Client* const client_; + + const SwrContextUniquePtr resampler_; + + // Current resampler input audio parameters. + AVSampleFormat input_sample_format_ = AV_SAMPLE_FMT_NONE; + int input_sample_rate_; + uint64_t input_channel_layout_; // Opaque value used by resampler library. + + std::vector<float> resampled_audio_; +}; + +// Emits the primary video stream from a file. +class SimulatedVideoCapturer final : public SimulatedCapturer { + public: + class Client : public SimulatedCapturer::Observer { + public: + // Called to deliver the next video |frame|, which is always in I420 format. + // |capture_time| is used to synchronize the play-out of the video frame + // with respect to the audio track. + virtual void OnVideoFrame(const AVFrame& frame, + Clock::time_point capture_time) = 0; + + protected: + ~Client() override; + }; + + SimulatedVideoCapturer(Environment* environment, + const char* path, + Clock::time_point start_time, + Client* client); + + ~SimulatedVideoCapturer() final; + + private: + Client* const client_; + + // Sets up the decoder to produce I420 format output. + void SetAdditionalDecoderParameters(AVCodecContext* decoder_context) final; + + // Called at the moment Client::OnVideoFrame() should be called to provide the + // next video frame. + void DeliverDataToClient(const AVFrame& decoded_frame, + Clock::time_point capture_time) final; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STANDALONE_SENDER_SIMULATED_CAPTURER_H_ diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.cc b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.cc new file mode 100644 index 00000000000..07aeaaefd13 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.cc @@ -0,0 +1,224 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/standalone_sender/streaming_opus_encoder.h" + +#include <opus/opus.h> + +#include <algorithm> +#include <chrono> // NOLINT + +namespace openscreen { +namespace cast { + +using std::chrono::duration_cast; +using std::chrono::microseconds; +using std::chrono::seconds; + +using openscreen::operator<<; // To pretty-print chrono values. + +namespace { + +// The bitrate at which virtually all stereo audio can be encoded and decoded +// without human-perceivable artifacts. Source: +// https://wiki.hydrogenaud.io/index.php?title=Opus#Bitrate_performance +constexpr opus_int32 kTransparentBitrate = 160000; + +// The maximum number of Cast audio frames the encoder may fall behind by before +// skipping-ahead the RTP timestamps to compensate. +constexpr int kMaxCastFramesBeforeSkip = 3; + +} // namespace + +StreamingOpusEncoder::StreamingOpusEncoder(int num_channels, + int cast_frames_per_second, + Sender* sender) + : num_channels_(num_channels), + sender_(sender), + samples_per_cast_frame_(sample_rate() / cast_frames_per_second), + approximate_cast_frame_duration_( + duration_cast<Clock::duration>(seconds(1)) / cast_frames_per_second), + encoder_storage_(new uint8_t[opus_encoder_get_size(num_channels_)]), + input_(new float[num_channels_ * samples_per_cast_frame_]), + output_(new uint8_t[kOpusMaxPayloadSize]) { + OSP_CHECK_GT(cast_frames_per_second, 0); + OSP_DCHECK(sender_); + OSP_CHECK_GT(samples_per_cast_frame_, 0); + OSP_CHECK_EQ(sample_rate() % cast_frames_per_second, 0); + OSP_CHECK(approximate_cast_frame_duration_ > Clock::duration::zero()); + + frame_.dependency = EncodedFrame::KEY_FRAME; + + const auto init_result = opus_encoder_init( + encoder(), sample_rate(), num_channels_, OPUS_APPLICATION_AUDIO); + OSP_CHECK_EQ(init_result, OPUS_OK); + + UseStandardQuality(); +} + +StreamingOpusEncoder::~StreamingOpusEncoder() = default; + +int StreamingOpusEncoder::GetBitrate() const { + opus_int32 bitrate; + const auto ctl_result = + opus_encoder_ctl(encoder(), OPUS_GET_BITRATE(&bitrate)); + OSP_CHECK_EQ(ctl_result, OPUS_OK); + return bitrate; +} + +void StreamingOpusEncoder::UseStandardQuality() { + const auto ctl_result = + opus_encoder_ctl(encoder(), OPUS_SET_BITRATE(OPUS_AUTO)); + OSP_CHECK_EQ(ctl_result, OPUS_OK); + UpdateCodecDelay(); +} + +void StreamingOpusEncoder::UseHighQuality() { + // kTransparentBitrate assumes stereo audio. Scale it by the actual number of + // channels. + const opus_int32 bitrate = kTransparentBitrate * num_channels_ / 2; + const auto ctl_result = + opus_encoder_ctl(encoder(), OPUS_SET_BITRATE(bitrate)); + OSP_CHECK_EQ(ctl_result, OPUS_OK); + UpdateCodecDelay(); +} + +void StreamingOpusEncoder::EncodeAndSend(const float* interleaved_samples, + int num_samples, + Clock::time_point reference_time) { + OSP_DCHECK(interleaved_samples); + OSP_DCHECK_GT(num_samples, 0); + + ResolveTimestampsAndMaybeSkip(reference_time); + + while (num_samples > 0) { + const int samples_copied = + FillInputBuffer(interleaved_samples, num_samples); + num_samples -= samples_copied; + interleaved_samples += num_channels_ * samples_copied; + + if (num_samples_queued_ < samples_per_cast_frame_) { + return; // Not enough yet for a full Cast audio frame. + } + + const opus_int32 packet_size_or_error = + opus_encode_float(encoder(), input_.get(), num_samples_queued_, + output_.get(), kOpusMaxPayloadSize); + num_samples_queued_ = 0; + if (packet_size_or_error < 0) { + OSP_LOG_FATAL << "AUDIO[" << sender_->ssrc() + << "] Error code from opus_encode_float(): " + << packet_size_or_error; + return; + } + + frame_.frame_id = sender_->GetNextFrameId(); + frame_.referenced_frame_id = frame_.frame_id; + // Note: It's possible for Opus to encode a zero byte packet. Send a Cast + // audio frame anyway, to represent the passage of silence and to send other + // stream metadata. + frame_.data = absl::Span<uint8_t>(output_.get(), packet_size_or_error); + last_sent_frame_reference_time_ = frame_.reference_time; + switch (sender_->EnqueueFrame(frame_)) { + case Sender::OK: + break; + case Sender::PAYLOAD_TOO_LARGE: + OSP_NOTREACHED(); // The Opus packet cannot possibly be too large. + break; + case Sender::REACHED_ID_SPAN_LIMIT: + OSP_LOG_WARN << "AUDIO[" << sender_->ssrc() + << "] Dropping: FrameId span limit reached."; + break; + case Sender::MAX_DURATION_IN_FLIGHT: + OSP_LOG_INFO << "AUDIO[" << sender_->ssrc() + << "] Dropping: In-flight duration would be too high."; + break; + } + + frame_.rtp_timestamp += RtpTimeDelta::FromTicks(samples_per_cast_frame_); + frame_.reference_time += approximate_cast_frame_duration_; + } +} + +void StreamingOpusEncoder::UpdateCodecDelay() { + opus_int32 lookahead = 0; + const auto ctl_result = + opus_encoder_ctl(encoder(), OPUS_GET_LOOKAHEAD(&lookahead)); + OSP_CHECK_EQ(ctl_result, OPUS_OK); + codec_delay_ = RtpTimeDelta::FromTicks(lookahead).ToDuration<Clock::duration>( + sample_rate()); +} + +void StreamingOpusEncoder::ResolveTimestampsAndMaybeSkip( + Clock::time_point reference_time) { + // Back-track the reference time to account for the audio delay introduced by + // the codec. + reference_time -= codec_delay_; + + // Special case: Nothing special for the first frame's timestamps. + if (start_time_ == Clock::time_point::min()) { + frame_.rtp_timestamp = RtpTimeTicks(); + frame_.reference_time = start_time_ = reference_time; + last_sent_frame_reference_time_ = + reference_time - approximate_cast_frame_duration_; + return; + } + + const RtpTimeTicks current_position = + frame_.rtp_timestamp + RtpTimeDelta::FromTicks(num_samples_queued_); + const RtpTimeTicks reference_position = RtpTimeTicks::FromTimeSinceOrigin( + reference_time - start_time_, sample_rate()); + const RtpTimeDelta rtp_advancement = reference_position - current_position; + const RtpTimeDelta skip_threshold = + RtpTimeDelta::FromTicks(samples_per_cast_frame_) * + kMaxCastFramesBeforeSkip; + if (rtp_advancement > skip_threshold) { + OSP_LOG_WARN << "Detected audio gap " + << rtp_advancement.ToDuration<microseconds>(sample_rate()) + << ", skipping ahead..."; + num_samples_queued_ = 0; + frame_.rtp_timestamp = reference_position; + } + + // Further back-track the reference time to account for the already-queued + // samples. + reference_time -= RtpTimeDelta::FromTicks(num_samples_queued_) + .ToDuration<Clock::duration>(sample_rate()); + + // Frame reference times must be monotonically increasing. A little noise in + // the negative direction is okay to cap-off. Log a warning if there's a + // bigger problem (at the source). + const Clock::time_point lower_bound = + last_sent_frame_reference_time_ + + RtpTimeDelta::FromTicks(1).ToDuration<Clock::duration>(sample_rate()); + if (reference_time < lower_bound) { + const Clock::duration backwards_amount = + last_sent_frame_reference_time_ - reference_time; + OSP_LOG_IF(WARN, backwards_amount >= approximate_cast_frame_duration_) + << "Reference time went *backwards* too much (" << backwards_amount + << " in wrong direction). A/V sync may suffer at the Receiver!"; + reference_time = lower_bound; + } + + frame_.reference_time = reference_time; +} + +int StreamingOpusEncoder::FillInputBuffer(const float* interleaved_samples, + int num_samples) { + const int samples_needed = samples_per_cast_frame_ - num_samples_queued_; + const int samples_to_copy = std::min(num_samples, samples_needed); + std::copy(interleaved_samples, + interleaved_samples + num_channels_ * samples_to_copy, + input_.get() + num_channels_ * num_samples_queued_); + num_samples_queued_ += samples_to_copy; + return samples_to_copy; +} + +// static +constexpr int StreamingOpusEncoder::kDefaultCastAudioFramesPerSecond; +// static +constexpr int StreamingOpusEncoder::kOpusMaxPayloadSize; + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.h b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.h new file mode 100644 index 00000000000..0620e3a4588 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.h @@ -0,0 +1,123 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STANDALONE_SENDER_STREAMING_OPUS_ENCODER_H_ +#define CAST_STANDALONE_SENDER_STREAMING_OPUS_ENCODER_H_ + +#include <stdint.h> + +#include <memory> + +#include "cast/streaming/encoded_frame.h" +#include "cast/streaming/sender.h" +#include "platform/api/time.h" + +extern "C" { +struct OpusEncoder; +} + +namespace openscreen { +namespace cast { + +// Wraps the libopus encoder so that the application can stream +// interleaved-floats audio samples to a Sender. Either mono or stereo sound is +// supported. +class StreamingOpusEncoder { + public: + // Constructs the encoder for mono or stereo sound, dividing the stream of + // audio samples up into chunks as determined by the given + // |cast_frames_per_second|, and for EncodedFrame output to the given + // |sender|. The sample rate of the audio is assumed to be the Sender's fixed + // |rtp_timebase()|. + StreamingOpusEncoder(int num_channels, + int cast_frames_per_second, + Sender* sender); + + ~StreamingOpusEncoder(); + + int num_channels() const { return num_channels_; } + int sample_rate() const { return sender_->rtp_timebase(); } + + int GetBitrate() const; + + // Sets the encoder back to its "AUTO" bitrate setting, for standard quality. + // This and UseHighQuality() may be called as often as needed as conditions + // change. + // + // Note: As of 2020-01-21, the encoder in "auto bitrate" mode would use a + // bitrate of 102kbps for 2-channel, 48 kHz audio and a 10 ms frame size. + void UseStandardQuality(); + + // Sets the encoder to use a high bitrate (virtually no artifacts), when + // plenty of network bandwidth is available. This and UseStandardQuality() may + // be called as often as needed as conditions change. + void UseHighQuality(); + + // Encode and send the given |interleaved_samples|, which contains + // |num_samples| tuples (i.e., multiply by the number of channels to determine + // the number of array elements). The audio is assumed to have been captured + // at the required |sample_rate()|. |reference_time| refers to the first + // sample. + void EncodeAndSend(const float* interleaved_samples, + int num_samples, + Clock::time_point reference_time); + + static constexpr int kDefaultCastAudioFramesPerSecond = + 100; // 10 ms frame duration. + + private: + OpusEncoder* encoder() const { + return reinterpret_cast<OpusEncoder*>(encoder_storage_.get()); + } + + // Updates the |codec_delay_| based on the current encoder settings. + void UpdateCodecDelay(); + + // Sets the next frame's reference time, accounting for codec buffering delay. + // Also, checks whether the reference time has drifted too far forwards, and + // skips if necessary. + void ResolveTimestampsAndMaybeSkip(Clock::time_point reference_time); + + // Fills the input buffer as much as possible from the given source data, and + // returns the number of samples copied into the buffer. + int FillInputBuffer(const float* interleaved_samples, int num_samples); + + const int num_channels_; + Sender* const sender_; + const int samples_per_cast_frame_; + const Clock::duration approximate_cast_frame_duration_; + const std::unique_ptr<uint8_t[]> encoder_storage_; + const std::unique_ptr<float[]> input_; // Interleaved audio samples. + const std::unique_ptr<uint8_t[]> output_; // Opus-encoded packet. + + // The audio delay introduced by the codec. + Clock::duration codec_delay_{}; + + // The number of mono/stereo tuples currently queued in the |input_| buffer. + // Multiply by |num_channels_| to get the number of array elements. + int num_samples_queued_ = 0; + + // The reference time of the first frame passed to EncodeAndSend(), offset by + // the codec delay. + Clock::time_point start_time_ = Clock::time_point::min(); + + // Initialized and used by EncodeAndSend() to hold the metadata and data + // pointer for each frame being sent. + EncodedFrame frame_; + + // The |reference_time| for the last sent frame. This is used to check that + // the reference times are monotonically increasing. If they have [illegally] + // gone backwards too much, warnings will be logged. + Clock::time_point last_sent_frame_reference_time_; + + // This is the recommended value, according to documentation in + // src/include/opus.h in libopus, so that the Opus encoder does not degrade + // the audio due to memory constraints. + static constexpr int kOpusMaxPayloadSize = 4000; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STANDALONE_SENDER_STREAMING_OPUS_ENCODER_H_ diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.cc b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.cc new file mode 100644 index 00000000000..178564d656d --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.cc @@ -0,0 +1,492 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/standalone_sender/streaming_vp8_encoder.h" + +#include <stdint.h> +#include <string.h> +#include <vpx/vp8cx.h> + +#include <cmath> +#include <utility> + +#include "cast/streaming/encoded_frame.h" +#include "cast/streaming/environment.h" +#include "cast/streaming/sender.h" +#include "util/logging.h" +#include "util/saturate_cast.h" + +namespace openscreen { +namespace cast { + +using std::chrono::duration_cast; +using std::chrono::milliseconds; +using std::chrono::seconds; + +// TODO(https://crbug.com/openscreen/123): Fix the declarations and then remove +// this: +using openscreen::operator<<; // For std::chrono::duration pretty-printing. + +namespace { + +constexpr int kBytesPerKilobyte = 1024; + +// Lower and upper bounds to the frame duration passed to vpx_codec_encode(), to +// ensure sanity. Note that the upper-bound is especially important in cases +// where the video paused for some lengthy amount of time. +constexpr Clock::duration kMinFrameDuration = milliseconds(1); +constexpr Clock::duration kMaxFrameDuration = milliseconds(125); + +// Highest/lowest allowed encoding speed set to the encoder. The valid range is +// [4, 16], but experiments show that with speed higher than 12, the saving of +// the encoding time is not worth the dropping of the quality. And, with speed +// lower than 6, the increasing amount of quality is not worth the increasing +// amount of encoding time. +constexpr int kHighestEncodingSpeed = 12; +constexpr int kLowestEncodingSpeed = 6; + +// This is the equivalent change in encoding speed per one quantizer step. +constexpr double kEquivalentEncodingSpeedStepPerQuantizerStep = 1 / 20.0; + +} // namespace + +StreamingVp8Encoder::StreamingVp8Encoder(const Parameters& params, + TaskRunner* task_runner, + Sender* sender) + : params_(params), + main_task_runner_(task_runner), + sender_(sender), + ideal_speed_setting_(kHighestEncodingSpeed), + encode_thread_([this] { ProcessWorkUnitsUntilTimeToQuit(); }) { + OSP_DCHECK_LE(1, params_.num_encode_threads); + OSP_DCHECK_LE(kMinQuantizer, params_.min_quantizer); + OSP_DCHECK_LE(params_.min_quantizer, params_.max_cpu_saver_quantizer); + OSP_DCHECK_LE(params_.max_cpu_saver_quantizer, params_.max_quantizer); + OSP_DCHECK_LE(params_.max_quantizer, kMaxQuantizer); + OSP_DCHECK_LT(0.0, params_.max_time_utilization); + OSP_DCHECK_LE(params_.max_time_utilization, 1.0); + OSP_DCHECK(main_task_runner_); + OSP_DCHECK(sender_); + + const auto result = + vpx_codec_enc_config_default(vpx_codec_vp8_cx(), &config_, 0); + OSP_CHECK_EQ(result, VPX_CODEC_OK); + + // This is set to non-zero in ConfigureForNewFrameSize() later, to flag that + // the encoder has been initialized. + config_.g_threads = 0; + + // Set the timebase to match that of openscreen::Clock::duration. + config_.g_timebase.num = Clock::duration::period::num; + config_.g_timebase.den = Clock::duration::period::den; + + // |g_pass| and |g_lag_in_frames| must be "one pass" and zero, respectively, + // because of the way the libvpx API is used. + config_.g_pass = VPX_RC_ONE_PASS; + config_.g_lag_in_frames = 0; + + // Rate control settings. + config_.rc_dropframe_thresh = 0; // The encoder may not drop any frames. + config_.rc_resize_allowed = 0; + config_.rc_end_usage = VPX_CBR; + config_.rc_target_bitrate = target_bitrate_ / kBytesPerKilobyte; + config_.rc_min_quantizer = params_.min_quantizer; + config_.rc_max_quantizer = params_.max_quantizer; + + // The reasons for the values chosen here (rc_*shoot_pct and rc_buf_*_sz) are + // lost in history. They were brought-over from the legacy Chrome Cast + // Streaming Sender implemenation. + config_.rc_undershoot_pct = 100; + config_.rc_overshoot_pct = 15; + config_.rc_buf_initial_sz = 500; + config_.rc_buf_optimal_sz = 600; + config_.rc_buf_sz = 1000; + + config_.kf_mode = VPX_KF_DISABLED; +} + +StreamingVp8Encoder::~StreamingVp8Encoder() { + { + std::unique_lock<std::mutex> lock(mutex_); + target_bitrate_ = 0; + cv_.notify_one(); + } + encode_thread_.join(); +} + +int StreamingVp8Encoder::GetTargetBitrate() const { + // Note: No need to lock the |mutex_| since this method should be called on + // the same thread as SetTargetBitrate(). + return target_bitrate_; +} + +void StreamingVp8Encoder::SetTargetBitrate(int new_bitrate) { + // Ensure that, when bps is converted to kbps downstream, that the encoder + // bitrate will not be zero. + new_bitrate = std::max(new_bitrate, kBytesPerKilobyte); + + std::unique_lock<std::mutex> lock(mutex_); + // Only assign the new target bitrate if |target_bitrate_| has not yet been + // used to signal the |encode_thread_| to end. + if (target_bitrate_ > 0) { + target_bitrate_ = new_bitrate; + } +} + +void StreamingVp8Encoder::EncodeAndSend( + const VideoFrame& frame, + Clock::time_point reference_time, + std::function<void(Stats)> stats_callback) { + WorkUnit work_unit; + + // TODO(miu): The |VideoFrame| struct should provide the media timestamp, + // instead of this code inferring it from the reference timestamps, since: 1) + // the video capturer's clock may tick at a different rate than the system + // clock; and 2) to reduce jitter. + if (start_time_ == Clock::time_point::min()) { + start_time_ = reference_time; + work_unit.rtp_timestamp = RtpTimeTicks(); + } else { + work_unit.rtp_timestamp = RtpTimeTicks::FromTimeSinceOrigin( + reference_time - start_time_, sender_->rtp_timebase()); + if (work_unit.rtp_timestamp <= last_enqueued_rtp_timestamp_) { + OSP_LOG_WARN << "VIDEO[" << sender_->ssrc() + << "] Dropping: RTP timestamp is not monotonically " + "increasing from last frame."; + return; + } + } + if (sender_->GetInFlightMediaDuration(work_unit.rtp_timestamp) > + sender_->GetMaxInFlightMediaDuration()) { + OSP_LOG_WARN << "VIDEO[" << sender_->ssrc() + << "] Dropping: In-flight media duration would be too high."; + return; + } + + Clock::duration frame_duration = frame.duration; + if (frame_duration <= Clock::duration::zero()) { + // The caller did not provide the frame duration in |frame|. + if (reference_time == start_time_) { + // Use the max for the first frame so libvpx will spend extra effort on + // its quality. + frame_duration = kMaxFrameDuration; + } else { + // Use the actual amount of time between the current and previous frame as + // a prediction for the next frame's duration. + frame_duration = + (work_unit.rtp_timestamp - last_enqueued_rtp_timestamp_) + .ToDuration<Clock::duration>(sender_->rtp_timebase()); + } + } + work_unit.duration = + std::max(std::min(frame_duration, kMaxFrameDuration), kMinFrameDuration); + + last_enqueued_rtp_timestamp_ = work_unit.rtp_timestamp; + + work_unit.image = CloneAsVpxImage(frame); + work_unit.reference_time = reference_time; + work_unit.stats_callback = std::move(stats_callback); + const bool force_key_frame = sender_->NeedsKeyFrame(); + { + std::unique_lock<std::mutex> lock(mutex_); + needs_key_frame_ |= force_key_frame; + encode_queue_.push(std::move(work_unit)); + cv_.notify_one(); + } +} + +void StreamingVp8Encoder::DestroyEncoder() { + OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); + + if (is_encoder_initialized()) { + vpx_codec_destroy(&encoder_); + // Flag that the encoder is not initialized. See header comments for + // is_encoder_initialized(). + config_.g_threads = 0; + } +} + +void StreamingVp8Encoder::ProcessWorkUnitsUntilTimeToQuit() { + OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); + + for (;;) { + WorkUnitWithResults work_unit{}; + bool force_key_frame; + int target_bitrate; + { + std::unique_lock<std::mutex> lock(mutex_); + if (target_bitrate_ <= 0) { + break; // Time to end this thread. + } + if (encode_queue_.empty()) { + cv_.wait(lock); + if (encode_queue_.empty()) { + continue; + } + } + static_cast<WorkUnit&>(work_unit) = std::move(encode_queue_.front()); + encode_queue_.pop(); + force_key_frame = needs_key_frame_; + target_bitrate = target_bitrate_; + } + + // Clock::now() is being called directly, instead of using a + // dependency-injected "now function," since actual wall time is being + // measured. + const Clock::time_point encode_start_time = Clock::now(); + PrepareEncoder(work_unit.image->d_w, work_unit.image->d_h, target_bitrate); + EncodeFrame(force_key_frame, &work_unit); + ComputeFrameEncodeStats(Clock::now() - encode_start_time, target_bitrate, + &work_unit); + UpdateSpeedSettingForNextFrame(work_unit.stats); + + main_task_runner_->PostTask( + [this, results = std::move(work_unit)]() mutable { + SendEncodedFrame(std::move(results)); + }); + } + + DestroyEncoder(); +} + +void StreamingVp8Encoder::PrepareEncoder(int width, + int height, + int target_bitrate) { + OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); + + const int target_kbps = target_bitrate / kBytesPerKilobyte; + + // Translate the |ideal_speed_setting_| into the VP8E_SET_CPUUSED setting and + // the minimum quantizer to use. + int speed; + int min_quantizer; + if (ideal_speed_setting_ > kHighestEncodingSpeed) { + speed = kHighestEncodingSpeed; + const double remainder = ideal_speed_setting_ - speed; + min_quantizer = rounded_saturate_cast<int>( + remainder / kEquivalentEncodingSpeedStepPerQuantizerStep + + params_.min_quantizer); + min_quantizer = std::min(min_quantizer, params_.max_cpu_saver_quantizer); + } else { + speed = std::max(rounded_saturate_cast<int>(ideal_speed_setting_), + kLowestEncodingSpeed); + min_quantizer = params_.min_quantizer; + } + + if (static_cast<int>(config_.g_w) != width || + static_cast<int>(config_.g_h) != height) { + DestroyEncoder(); + } + + if (!is_encoder_initialized()) { + config_.g_threads = params_.num_encode_threads; + config_.g_w = width; + config_.g_h = height; + config_.rc_target_bitrate = target_kbps; + config_.rc_min_quantizer = min_quantizer; + + encoder_ = {}; + const vpx_codec_flags_t flags = 0; + const auto init_result = + vpx_codec_enc_init(&encoder_, vpx_codec_vp8_cx(), &config_, flags); + OSP_CHECK_EQ(init_result, VPX_CODEC_OK); + + // Raise the threshold for considering macroblocks as static. The default is + // zero, so this setting makes the encoder less sensitive to motion. This + // lowers the probability of needing to utilize more CPU to search for + // motion vectors. + const auto ctl_result = + vpx_codec_control(&encoder_, VP8E_SET_STATIC_THRESHOLD, 1); + OSP_CHECK_EQ(ctl_result, VPX_CODEC_OK); + + // Ensure the speed will be set (below). + current_speed_setting_ = ~speed; + } else if (static_cast<int>(config_.rc_target_bitrate) != target_kbps || + static_cast<int>(config_.rc_min_quantizer) != min_quantizer) { + config_.rc_target_bitrate = target_kbps; + config_.rc_min_quantizer = min_quantizer; + const auto update_config_result = + vpx_codec_enc_config_set(&encoder_, &config_); + OSP_CHECK_EQ(update_config_result, VPX_CODEC_OK); + } + + if (current_speed_setting_ != speed) { + // Pass the |speed| as a negative value to turn off VP8's automatic speed + // selection logic and force the exact setting. + const auto ctl_result = + vpx_codec_control(&encoder_, VP8E_SET_CPUUSED, -speed); + OSP_CHECK_EQ(ctl_result, VPX_CODEC_OK); + current_speed_setting_ = speed; + } +} + +void StreamingVp8Encoder::EncodeFrame(bool force_key_frame, + WorkUnitWithResults* work_unit) { + OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); + + // The presentation timestamp argument here is fixed to zero to force the + // encoder to base its single-frame bandwidth calculations entirely on + // |frame_duration| and the target bitrate setting. + const vpx_codec_pts_t pts = 0; + const vpx_enc_frame_flags_t flags = force_key_frame ? VPX_EFLAG_FORCE_KF : 0; + const auto encode_result = + vpx_codec_encode(&encoder_, work_unit->image.get(), pts, + work_unit->duration.count(), flags, VPX_DL_REALTIME); + OSP_CHECK_EQ(encode_result, VPX_CODEC_OK); + + const vpx_codec_cx_pkt_t* pkt; + for (vpx_codec_iter_t iter = nullptr;;) { + pkt = vpx_codec_get_cx_data(&encoder_, &iter); + // vpx_codec_get_cx_data() returns null once the "iteration" is complete. + // However, that point should never be reached because a + // VPX_CODEC_CX_FRAME_PKT must be encountered before that. + OSP_CHECK(pkt); + if (pkt->kind == VPX_CODEC_CX_FRAME_PKT) { + break; + } + } + + // A copy of the payload data is being made here. That's okay since it has to + // be copied at some point anyway, to be passed back to the main thread. + auto* const begin = static_cast<const uint8_t*>(pkt->data.frame.buf); + auto* const end = begin + pkt->data.frame.sz; + work_unit->payload.assign(begin, end); + work_unit->is_key_frame = !!(pkt->data.frame.flags & VPX_FRAME_IS_KEY); +} + +void StreamingVp8Encoder::ComputeFrameEncodeStats( + Clock::duration encode_wall_time, + int target_bitrate, + WorkUnitWithResults* work_unit) { + OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); + + Stats& stats = work_unit->stats; + + // Note: stats.frame_id is set later, in SendEncodedFrame(). + stats.rtp_timestamp = work_unit->rtp_timestamp; + stats.encode_wall_time = encode_wall_time; + stats.frame_duration = work_unit->duration; + stats.encoded_size = work_unit->payload.size(); + + constexpr double kBytesPerBit = 1.0 / CHAR_BIT; + constexpr double kSecondsPerClockTick = + 1.0 / duration_cast<Clock::duration>(seconds(1)).count(); + const double target_bytes_per_clock_tick = + target_bitrate * (kBytesPerBit * kSecondsPerClockTick); + stats.target_size = target_bytes_per_clock_tick * work_unit->duration.count(); + + // The quantizer the encoder used. This is the result of the VP8 encoder + // taking a guess at what quantizer value would produce an encoded frame size + // as close to the target as possible. + const auto get_quantizer_result = vpx_codec_control( + &encoder_, VP8E_GET_LAST_QUANTIZER_64, &stats.quantizer); + OSP_CHECK_EQ(get_quantizer_result, VPX_CODEC_OK); + + // Now that the frame has been encoded and the number of bytes is known, the + // perfect quantizer value (i.e., the one that should have been used) can be + // determined. + stats.perfect_quantizer = stats.quantizer * stats.space_utilization(); +} + +void StreamingVp8Encoder::UpdateSpeedSettingForNextFrame(const Stats& stats) { + OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); + + // Combine the speed setting that was used to encode the last frame, and the + // quantizer the encoder chose into a single speed metric. + const double speed = current_speed_setting_ + + kEquivalentEncodingSpeedStepPerQuantizerStep * + std::max(0, stats.quantizer - params_.min_quantizer); + + // Like |Stats::perfect_quantizer|, this computes a "hindsight" speed setting + // for the last frame, one that may have potentially allowed for a + // better-quality quantizer choice by the encoder, while also keeping CPU + // utilization within budget. + const double perfect_speed = + speed * stats.time_utilization() / params_.max_time_utilization; + + // Update the ideal speed setting, to be used for the next frame. An + // exponentially-decaying weighted average is used here to smooth-out noise. + // The weight is based on the duration of the frame that was encoded. + constexpr Clock::duration kDecayHalfLife = milliseconds(120); + const double ticks = stats.frame_duration.count(); + const double weight = ticks / (ticks + kDecayHalfLife.count()); + ideal_speed_setting_ = + weight * perfect_speed + (1.0 - weight) * ideal_speed_setting_; + OSP_DCHECK(std::isfinite(ideal_speed_setting_)); +} + +void StreamingVp8Encoder::SendEncodedFrame(WorkUnitWithResults results) { + OSP_DCHECK(main_task_runner_->IsRunningOnTaskRunner()); + + EncodedFrame frame; + frame.frame_id = sender_->GetNextFrameId(); + if (results.is_key_frame) { + frame.dependency = EncodedFrame::KEY_FRAME; + frame.referenced_frame_id = frame.frame_id; + } else { + frame.dependency = EncodedFrame::DEPENDS_ON_ANOTHER; + frame.referenced_frame_id = frame.frame_id - 1; + } + frame.rtp_timestamp = results.rtp_timestamp; + frame.reference_time = results.reference_time; + frame.data = absl::Span<uint8_t>(results.payload); + + if (sender_->EnqueueFrame(frame) != Sender::OK) { + // Since the frame will not be sent, the encoder's frame dependency chain + // has been broken. Force a key frame for the next frame. + std::unique_lock<std::mutex> lock(mutex_); + needs_key_frame_ = true; + } + + if (results.stats_callback) { + results.stats.frame_id = frame.frame_id; + results.stats_callback(results.stats); + } +} + +namespace { +void CopyPlane(const uint8_t* src, + int src_stride, + int num_rows, + uint8_t* dst, + int dst_stride) { + if (src_stride == dst_stride) { + memcpy(dst, src, src_stride * num_rows); + return; + } + const int bytes_per_row = std::min(src_stride, dst_stride); + while (--num_rows >= 0) { + memcpy(dst, src, bytes_per_row); + dst += dst_stride; + src += src_stride; + } +} +} // namespace + +// static +StreamingVp8Encoder::VpxImageUniquePtr StreamingVp8Encoder::CloneAsVpxImage( + const VideoFrame& frame) { + OSP_DCHECK_GE(frame.width, 0); + OSP_DCHECK_GE(frame.height, 0); + OSP_DCHECK_GE(frame.yuv_strides[0], 0); + OSP_DCHECK_GE(frame.yuv_strides[1], 0); + OSP_DCHECK_GE(frame.yuv_strides[2], 0); + + constexpr int kAlignment = 32; + VpxImageUniquePtr image(vpx_img_alloc(nullptr, VPX_IMG_FMT_I420, frame.width, + frame.height, kAlignment)); + OSP_CHECK(image); + + CopyPlane(frame.yuv_planes[0], frame.yuv_strides[0], frame.height, + image->planes[VPX_PLANE_Y], image->stride[VPX_PLANE_Y]); + CopyPlane(frame.yuv_planes[1], frame.yuv_strides[1], (frame.height + 1) / 2, + image->planes[VPX_PLANE_U], image->stride[VPX_PLANE_U]); + CopyPlane(frame.yuv_planes[2], frame.yuv_strides[2], (frame.height + 1) / 2, + image->planes[VPX_PLANE_V], image->stride[VPX_PLANE_V]); + + return image; +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.h b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.h new file mode 100644 index 00000000000..1c64cafc5b0 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.h @@ -0,0 +1,302 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STANDALONE_SENDER_STREAMING_VP8_ENCODER_H_ +#define CAST_STANDALONE_SENDER_STREAMING_VP8_ENCODER_H_ + +#include <vpx/vpx_encoder.h> +#include <vpx/vpx_image.h> + +#include <algorithm> +#include <condition_variable> // NOLINT +#include <functional> +#include <memory> +#include <mutex> // NOLINT +#include <queue> +#include <thread> // NOLINT +#include <vector> + +#include "absl/base/thread_annotations.h" +#include "cast/streaming/frame_id.h" +#include "cast/streaming/rtp_time.h" +#include "platform/api/task_runner.h" +#include "platform/api/time.h" + +namespace openscreen { + +class TaskRunner; + +namespace cast { + +class Sender; + +// Uses libvpx to encode VP8 video and streams it to a Sender. Includes +// extensive logic for fine-tuning the encoder parameters in real-time, to +// provide the best quality results given external, uncontrollable factors: +// CPU/network availability, and the complexity of the video frame content. +// +// Internally, a separate encode thread is created and used to prevent blocking +// the main thread while frames are being encoded. All public API methods are +// assumed to be called on the same sequence/thread as the main TaskRunner +// (injected via the constructor). +// +// Usage: +// +// 1. EncodeAndSend() is used to queue-up video frames for encoding and sending, +// which will be done on a best-effort basis. +// +// 2. The client is expected to call SetTargetBitrate() frequently based on its +// own bandwidth estimates and congestion control logic. In addition, a client +// may provide a callback for each frame's encode statistics, which can be used +// to further optimize the user experience. For example, the stats can be used +// as a signal to reduce the data volume (i.e., resolution and/or frame rate) +// coming from the video capture source. +class StreamingVp8Encoder { + public: + // Configurable parameters passed to the StreamingVp8Encoder constructor. + struct Parameters { + // Number of threads to parallelize frame encoding. This should be set based + // on the number of CPU cores available for encoding, but no more than 8. + int num_encode_threads = + std::min(std::max<int>(std::thread::hardware_concurrency(), 1), 8); + + // Best-quality quantizer (lower is better quality). Range: [0,63] + int min_quantizer = 4; + + // Worst-quality quantizer (lower is better quality). Range: [0,63] + int max_quantizer = 63; + + // Worst-quality quantizer to use when the CPU is extremely constrained. + // Range: [min_quantizer,max_quantizer] + int max_cpu_saver_quantizer = 25; + + // Maximum amount of wall-time a frame's encode can take, relative to the + // frame's duration, before the CPU-saver logic is activated. The default + // (70%) is appropriate for systems with four or more cores, but should be + // reduced (e.g., 50%) for systems with fewer than three cores. + // + // Example: For 30 FPS (continuous) video, the frame duration is ~33.3ms, + // and a value of 0.5 here would mean that the CPU-saver logic starts + // sacrificing quality when frame encodes start taking longer than ~16.7ms. + double max_time_utilization = 0.7; + }; + + // Represents an input VideoFrame, passed to EncodeAndSend(). + struct VideoFrame { + // Image width and height. + int width; + int height; + + // I420 format image pointers and row strides (the number of bytes between + // the start of successive rows). The pointers only need to remain valid + // until the EncodeAndSend() call returns. + const uint8_t* yuv_planes[3]; + int yuv_strides[3]; + + // How long this frame will be held before the next frame will be displayed, + // or zero if unknown. The frame duration is passed to the VP8 codec, + // affecting a number of important behaviors, including: per-frame + // bandwidth, CPU time spent encoding, temporal quality trade-offs, and + // key/golden/alt-ref frame generation intervals. + Clock::duration duration; + }; + + // Performance statistics for a single frame's encode. + // + // For full details on how to use these stats in an end-to-end system, see: + // https://www.chromium.org/developers/design-documents/ + // auto-throttled-screen-capture-and-mirroring + // and https://source.chromium.org/chromium/chromium/src/+/master: + // media/cast/sender/performance_metrics_overlay.h + struct Stats { + // The Cast Streaming ID that was assigned to the frame. + FrameId frame_id; + + // The RTP timestamp of the frame. + RtpTimeTicks rtp_timestamp; + + // How long the frame took to encode. This is wall time, not CPU time or + // some other load metric. + Clock::duration encode_wall_time; + + // The frame's predicted duration; or, the actual duration if it was + // provided in the VideoFrame. + Clock::duration frame_duration; + + // The encoded frame's size in bytes. + int encoded_size; + + // The average size of an encoded frame in bytes, having this + // |frame_duration| and current target bitrate. + double target_size; + + // The actual quantizer the VP8 encoder used, in the range [0,63]. + int quantizer; + + // The "hindsight" quantizer value that would have produced the best quality + // encoding of the frame at the current target bitrate. The nominal range is + // [0.0,63.0]. If it is larger than 63.0, then it was impossible for VP8 to + // encode the frame within the current target bitrate (e.g., too much + // "entropy" in the image, or too low a target bitrate). + double perfect_quantizer; + + // Utilization feedback metrics. The nominal range for each of these is + // [0.0,1.0] where 1.0 means "the entire budget available for the frame was + // exhausted." Going above 1.0 is okay for one or a few frames, since it's + // the average over many frames that matters before the system is considered + // "redlining." + // + // The max of these three provides an overall utilization control signal. + // The usual approach is for upstream control logic to increase/decrease the + // data volume (e.g., video resolution and/or frame rate) to maintain a good + // target point. + double time_utilization() const { + return static_cast<double>(encode_wall_time.count()) / + frame_duration.count(); + } + double space_utilization() const { return encoded_size / target_size; } + double entropy_utilization() const { + return perfect_quantizer / kMaxQuantizer; + } + }; + + StreamingVp8Encoder(const Parameters& params, + TaskRunner* task_runner, + Sender* sender); + + ~StreamingVp8Encoder(); + + // Get/Set the target bitrate. This may be changed at any time, as frequently + // as desired, and it will take effect internally as soon as possible. + int GetTargetBitrate() const; + void SetTargetBitrate(int new_bitrate); + + // Encode |frame| using the VP8 encoder, assemble an EncodedFrame, and enqueue + // into the Sender. The frame may be dropped if too many frames are in-flight. + // If provided, the |stats_callback| is run after the frame is enqueued in the + // Sender (via the main TaskRunner). + void EncodeAndSend(const VideoFrame& frame, + Clock::time_point reference_time, + std::function<void(Stats)> stats_callback); + + static constexpr int kMinQuantizer = 0; + static constexpr int kMaxQuantizer = 63; + + private: + // Syntactic convenience to wrap the vpx_image_t alloc/free API in a smart + // pointer. + struct VpxImageDeleter { + void operator()(vpx_image_t* ptr) const { vpx_img_free(ptr); } + }; + using VpxImageUniquePtr = std::unique_ptr<vpx_image_t, VpxImageDeleter>; + + // Represents the state of one frame encode. This is created in + // EncodeAndSend(), and passed to the encode thread via the |encode_queue_|. + struct WorkUnit { + VpxImageUniquePtr image; + Clock::duration duration; + Clock::time_point reference_time; + RtpTimeTicks rtp_timestamp; + std::function<void(Stats)> stats_callback; + }; + + // Same as WorkUnit, but with additional fields to carry the encode results. + struct WorkUnitWithResults : public WorkUnit { + std::vector<uint8_t> payload; + bool is_key_frame; + Stats stats; + }; + + bool is_encoder_initialized() const { return config_.g_threads != 0; } + + // Destroys the VP8 encoder context if it has been initialized. + void DestroyEncoder(); + + // The procedure for the |encode_thread_| that loops, processing work units + // from the |encode_queue_| by calling Encode() until it's time to end the + // thread. + void ProcessWorkUnitsUntilTimeToQuit(); + + // If the |encoder_| is live, attempt reconfiguration to allow it to encode + // frames at a new frame size, target bitrate, or "CPU encoding speed." If + // reconfiguration is not possible, destroy the existing instance and + // re-create a new |encoder_| instance. + void PrepareEncoder(int width, int height, int target_bitrate); + + // Wraps the complex libvpx vpx_codec_encode() call using inputs from + // |work_unit| and populating results there. + void EncodeFrame(bool force_key_frame, WorkUnitWithResults* work_unit); + + // Computes and populates |work_unit.stats| after the last call to + // EncodeFrame(). + void ComputeFrameEncodeStats(Clock::duration encode_wall_time, + int target_bitrate, + WorkUnitWithResults* work_unit); + + // Updates the |ideal_speed_setting_|, to take effect with the next frame + // encode, based on the given performance |stats|. + void UpdateSpeedSettingForNextFrame(const Stats& stats); + + // Assembles and enqueues an EncodedFrame with the Sender on the main thread. + void SendEncodedFrame(WorkUnitWithResults results); + + // Allocates a vpx_image_t and copies the content from |frame| to it. + static VpxImageUniquePtr CloneAsVpxImage(const VideoFrame& frame); + + const Parameters params_; + TaskRunner* const main_task_runner_; + Sender* const sender_; + + // The reference time of the first frame passed to EncodeAndSend(). + Clock::time_point start_time_ = Clock::time_point::min(); + + // The RTP timestamp of the last frame that was pushed into the + // |encode_queue_| by EncodeAndSend(). This is used to check whether + // timestamps are monotonically increasing. + RtpTimeTicks last_enqueued_rtp_timestamp_; + + // Guards a few members shared by both the main and encode threads. + std::mutex mutex_; + + // Used by the encode thread to sleep until more work is available. + std::condition_variable cv_ ABSL_GUARDED_BY(mutex_); + + // These encode parameters not passed in the WorkUnit struct because it is + // desirable for them to be applied as soon as possible, with the very next + // WorkUnit popped from the |encode_queue_| on the encode thread, and not to + // wait until some later WorkUnit is processed. + bool needs_key_frame_ ABSL_GUARDED_BY(mutex_) = true; + int target_bitrate_ ABSL_GUARDED_BY(mutex_) = 2 << 20; // Default: 2 Mbps. + + // The queue of frame encodes. The size of this queue is implicitly bounded by + // EncodeAndSend(), where it checks for the total in-flight media duration and + // maybe drops a frame. + std::queue<WorkUnit> encode_queue_ ABSL_GUARDED_BY(mutex_); + + // Current VP8 encoder configuration. Most of the fields are unchanging, and + // are populated in the ctor; but thereafter, only the encode thread accesses + // this struct. + // + // The speed setting is controlled via a separate libvpx API (see members + // below). + vpx_codec_enc_cfg_t config_{}; + + // These represent the magnitude of the VP8 speed setting, where larger values + // (i.e., faster speed) request less CPU usage but will provide lower video + // quality. Only the encode thread accesses these. + double ideal_speed_setting_; // A time-weighted average, from measurements. + int current_speed_setting_; // Current |encoder_| speed setting. + + // libvpx VP8 encoder instance. Only the encode thread accesses this. + vpx_codec_ctx_t encoder_; + + // This member should be last in the class since the thread should not start + // until all above members have been initialized by the constructor. + std::thread encode_thread_; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STANDALONE_SENDER_STREAMING_VP8_ENCODER_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/BUILD.gn b/chromium/third_party/openscreen/src/cast/streaming/BUILD.gn index d45ee925a78..688a453f1ce 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/BUILD.gn +++ b/chromium/third_party/openscreen/src/cast/streaming/BUILD.gn @@ -3,6 +3,7 @@ # found in the LICENSE file. import("//build_overrides/build.gni") +import("../../testing/libfuzzer/fuzzer_test.gni") source_set("common") { sources = [ @@ -51,10 +52,14 @@ source_set("common") { source_set("receiver") { sources = [ + "answer_messages.cc", + "answer_messages.h", "compound_rtcp_builder.cc", "compound_rtcp_builder.h", "frame_collector.cc", "frame_collector.h", + "offer_messages.cc", + "offer_messages.h", "packet_receive_stats_tracker.cc", "packet_receive_stats_tracker.h", "receiver.cc", @@ -72,14 +77,24 @@ source_set("receiver") { public_deps = [ ":common", ] + + deps = [ + "../../util", + ] } source_set("sender") { sources = [ + "bandwidth_estimator.cc", + "bandwidth_estimator.h", "compound_rtcp_parser.cc", "compound_rtcp_parser.h", "rtp_packetizer.cc", "rtp_packetizer.h", + "sender.cc", + "sender.h", + "sender_packet_router.cc", + "sender_packet_router.h", "sender_report_builder.cc", "sender_report_builder.h", ] @@ -93,21 +108,29 @@ source_set("unittests") { testonly = true sources = [ + "answer_messages_unittest.cc", + "bandwidth_estimator_unittest.cc", "compound_rtcp_builder_unittest.cc", "compound_rtcp_parser_unittest.cc", "expanded_value_base_unittest.cc", "frame_collector_unittest.cc", "frame_crypto_unittest.cc", "mock_compound_rtcp_parser_client.h", + "mock_environment.cc", + "mock_environment.h", "ntp_time_unittest.cc", + "offer_messages_unittest.cc", "packet_receive_stats_tracker_unittest.cc", "packet_util_unittest.cc", + "receiver_session_unittest.cc", "receiver_unittest.cc", "rtcp_common_unittest.cc", "rtp_packet_parser_unittest.cc", "rtp_packetizer_unittest.cc", "rtp_time_unittest.cc", + "sender_packet_router_unittest.cc", "sender_report_unittest.cc", + "sender_unittest.cc", "ssrc_unittest.cc", ] @@ -116,82 +139,54 @@ source_set("unittests") { ":sender", "../../third_party/googletest:gmock", "../../third_party/googletest:gtest", + "../../util", + ] +} + +openscreen_fuzzer_test("compound_rtcp_parser_fuzzer") { + sources = [ + "compound_rtcp_parser_fuzzer.cc", + ] + + deps = [ + ":sender", + "../../third_party/abseil", + ] + + seed_corpus = "compound_rtcp_parser_fuzzer_seeds" + + # Note: 1500 is approx. kMaxRtpPacketSize in rtp_defines.h. + libfuzzer_options = [ "max_len=1500" ] +} + +openscreen_fuzzer_test("rtp_packet_parser_fuzzer") { + sources = [ + "rtp_packet_parser_fuzzer.cc", ] + + deps = [ + ":receiver", + "../../third_party/abseil", + ] + + seed_corpus = "rtp_packet_parser_fuzzer_seeds" + + # Note: 1500 is approx. kMaxRtpPacketSize in rtp_defines.h. + libfuzzer_options = [ "max_len=1500" ] } -if (build_with_chromium) { - import("//testing/libfuzzer/fuzzer_test.gni") - - fuzzer_test("compound_rtcp_parser_fuzzer") { - sources = [ - "compound_rtcp_parser_fuzzer.cc", - ] - - deps = [ - ":sender", - "../../third_party/abseil", - ] - - seed_corpus = "compound_rtcp_parser_fuzzer_seeds" - - # Note: 1500 is approx. kMaxRtpPacketSize in rtp_defines.h. - libfuzzer_options = [ "max_len=1500" ] - } - - fuzzer_test("rtp_packet_parser_fuzzer") { - sources = [ - "rtp_packet_parser_fuzzer.cc", - ] - - deps = [ - ":receiver", - "../../third_party/abseil", - ] - - seed_corpus = "rtp_packet_parser_fuzzer_seeds" - - # Note: 1500 is approx. kMaxRtpPacketSize in rtp_defines.h. - libfuzzer_options = [ "max_len=1500" ] - } - - fuzzer_test("sender_report_parser_fuzzer") { - sources = [ - "sender_report_parser_fuzzer.cc", - ] - - deps = [ - ":receiver", - "../../third_party/abseil", - ] - - seed_corpus = "sender_report_parser_fuzzer_seeds" - - # Note: 1500 is approx. kMaxRtpPacketSize in rtp_defines.h. - libfuzzer_options = [ "max_len=1500" ] - } -} else { - # Note: The following is commented out because, as of this writing, the LLVM - # toolchain we pull does not include libclang_rt.fuzzer-x86_64.a, the - # libFuzzer library *with* a main() to drive everything. Thus, the only way to - # get things working is to specify an exact path to the fuzzer_no_main variant - # of the library that *is* avalable, and then provide our own main(). In - # summary, what you see below demonstrates how to get it working specifically - # for Clang 9.0.0 on Linux x86_64. One need only modify the "libs = [...]" for - # a different Clang, OS, or architecture. - # if (is_clang) { - # executable("rtp_packet_parser_fuzzer") { - # testonly = true - # defines = [ "NEEDS_MAIN_TO_CALL_FUZZER_DRIVER" ] - # sources = [ - # "rtp_packet_parser_fuzzer.cc", - # ] - # cflags_cc = [ "-fsanitize=address,fuzzer-no-link,undefined" ] - # ldflags = [ "-fsanitize=address,undefined" ] - # libs = [ "$clang_base_path/lib/clang/9.0.0/lib/linux/libclang_rt.fuzzer_no_main-x86_64.a" ] - # deps = [ - # ":receiver", - # "../../third_party/abseil", - # ] - # } - # } +openscreen_fuzzer_test("sender_report_parser_fuzzer") { + sources = [ + "sender_report_parser_fuzzer.cc", + ] + + deps = [ + ":receiver", + "../../third_party/abseil", + ] + + seed_corpus = "sender_report_parser_fuzzer_seeds" + + # Note: 1500 is approx. kMaxRtpPacketSize in rtp_defines.h. + libfuzzer_options = [ "max_len=1500" ] } diff --git a/chromium/third_party/openscreen/src/cast/streaming/DEPS b/chromium/third_party/openscreen/src/cast/streaming/DEPS index 03a77eeaa00..de4027bf0d5 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/DEPS +++ b/chromium/third_party/openscreen/src/cast/streaming/DEPS @@ -6,5 +6,6 @@ include_rules = [ '+cast/common', '+cast/receiver', '+cast/sender', - '+openssl' + '+openssl', + '+json', ] diff --git a/chromium/third_party/openscreen/src/cast/streaming/answer_messages.cc b/chromium/third_party/openscreen/src/cast/streaming/answer_messages.cc new file mode 100644 index 00000000000..bd60a58c11f --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/answer_messages.cc @@ -0,0 +1,208 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/answer_messages.h" + +#include <utility> + +#include "absl/strings/str_cat.h" +#include "cast/streaming/message_util.h" +#include "platform/base/error.h" +#include "util/logging.h" + +namespace openscreen { +namespace cast { + +namespace { + +static constexpr char kMessageKeyType[] = "type"; +static constexpr char kMessageTypeAnswer[] = "ANSWER"; + +// List of ANSWER message fields. +static constexpr char kAnswerMessageBody[] = "answer"; +static constexpr char kResult[] = "result"; +static constexpr char kResultOk[] = "ok"; +static constexpr char kResultError[] = "error"; +static constexpr char kErrorMessageBody[] = "error"; +static constexpr char kErrorCode[] = "code"; +static constexpr char kErrorDescription[] = "description"; + +Json::Value AspectRatioConstraintToJson(AspectRatioConstraint aspect_ratio) { + switch (aspect_ratio) { + case AspectRatioConstraint::kVariable: + return Json::Value("receiver"); + case AspectRatioConstraint::kFixed: + default: + return Json::Value("sender"); + } +} + +template <typename T> +Json::Value PrimitiveVectorToJson(const std::vector<T>& vec) { + Json::Value array(Json::ValueType::arrayValue); + array.resize(vec.size()); + + for (Json::Value::ArrayIndex i = 0; i < vec.size(); ++i) { + array[i] = Json::Value(vec[i]); + } + + return array; +} + +} // namespace + +ErrorOr<Json::Value> AudioConstraints::ToJson() const { + if (max_sample_rate <= 0 || max_channels <= 0 || min_bit_rate <= 0 || + max_bit_rate < min_bit_rate) { + return CreateParameterError("AudioConstraints"); + } + + Json::Value root; + root["maxSampleRate"] = max_sample_rate; + root["maxChannels"] = max_channels; + root["minBitRate"] = min_bit_rate; + root["maxBitRate"] = max_bit_rate; + root["maxDelay"] = Json::Value::Int64(max_delay.count()); + return root; +} + +ErrorOr<Json::Value> Dimensions::ToJson() const { + if (width <= 0 || height <= 0 || !frame_rate.is_defined() || + !frame_rate.is_positive()) { + return CreateParameterError("Dimensions"); + } + + Json::Value root; + root["width"] = width; + root["height"] = height; + root["frameRate"] = frame_rate.ToString(); + return root; +} + +ErrorOr<Json::Value> VideoConstraints::ToJson() const { + if (max_pixels_per_second <= 0 || min_bit_rate <= 0 || + max_bit_rate < min_bit_rate || max_delay.count() <= 0) { + return CreateParameterError("VideoConstraints"); + } + + auto error_or_min_dim = min_dimensions.ToJson(); + if (error_or_min_dim.is_error()) { + return error_or_min_dim.error(); + } + + auto error_or_max_dim = max_dimensions.ToJson(); + if (error_or_max_dim.is_error()) { + return error_or_max_dim.error(); + } + + Json::Value root; + root["maxPixelsPerSecond"] = max_pixels_per_second; + root["minDimensions"] = error_or_min_dim.value(); + root["maxDimensions"] = error_or_max_dim.value(); + root["minBitRate"] = min_bit_rate; + root["maxBitRate"] = max_bit_rate; + root["maxDelay"] = Json::Value::Int64(max_delay.count()); + return root; +} + +ErrorOr<Json::Value> Constraints::ToJson() const { + auto audio_or_error = audio.ToJson(); + if (audio_or_error.is_error()) { + return audio_or_error.error(); + } + + auto video_or_error = video.ToJson(); + if (video_or_error.is_error()) { + return video_or_error.error(); + } + + Json::Value root; + root["audio"] = audio_or_error.value(); + root["video"] = video_or_error.value(); + return root; +} + +ErrorOr<Json::Value> DisplayDescription::ToJson() const { + if (aspect_ratio.width < 1 || aspect_ratio.height < 1) { + return CreateParameterError("DisplayDescription"); + } + + auto dimensions_or_error = dimensions.ToJson(); + if (dimensions_or_error.is_error()) { + return dimensions_or_error.error(); + } + + Json::Value root; + root["dimensions"] = dimensions_or_error.value(); + root["aspectRatio"] = + absl::StrCat(aspect_ratio.width, ":", aspect_ratio.height); + root["scaling"] = AspectRatioConstraintToJson(aspect_ratio_constraint); + return root; +} + +ErrorOr<Json::Value> Answer::ToJson() const { + if (udp_port <= 0 || udp_port > 65535) { + return CreateParameterError("Answer - UDP Port number"); + } + + Json::Value root; + if (constraints) { + auto constraints_or_error = constraints.value().ToJson(); + if (constraints_or_error.is_error()) { + return constraints_or_error.error(); + } + root["constraints"] = constraints_or_error.value(); + } + + if (display) { + auto display_or_error = display.value().ToJson(); + if (display_or_error.is_error()) { + return display_or_error.error(); + } + root["display"] = display_or_error.value(); + } + + root["castMode"] = cast_mode.ToString(); + root["udpPort"] = udp_port; + root["receiverGetStatus"] = supports_wifi_status_reporting; + root["sendIndexes"] = PrimitiveVectorToJson(send_indexes); + root["ssrcs"] = PrimitiveVectorToJson(ssrcs); + if (!receiver_rtcp_event_log.empty()) { + root["receiverRtcpEventLog"] = + PrimitiveVectorToJson(receiver_rtcp_event_log); + } + if (!receiver_rtcp_dscp.empty()) { + root["receiverRtcpDscp"] = PrimitiveVectorToJson(receiver_rtcp_dscp); + } + if (!rtp_extensions.empty()) { + root["rtpExtensions"] = PrimitiveVectorToJson(rtp_extensions); + } + return root; +} + +Json::Value Answer::ToAnswerMessage() const { + auto json_or_error = ToJson(); + if (json_or_error.is_error()) { + return CreateInvalidAnswer(json_or_error.error()); + } + + Json::Value message_root; + message_root[kMessageKeyType] = kMessageTypeAnswer; + message_root[kAnswerMessageBody] = std::move(json_or_error.value()); + message_root[kResult] = kResultOk; + return message_root; +} + +Json::Value CreateInvalidAnswer(Error error) { + Json::Value message_root; + message_root[kMessageKeyType] = kMessageTypeAnswer; + message_root[kResult] = kResultError; + message_root[kErrorMessageBody][kErrorCode] = static_cast<int>(error.code()); + message_root[kErrorMessageBody][kErrorDescription] = error.message(); + + return message_root; +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/answer_messages.h b/chromium/third_party/openscreen/src/cast/streaming/answer_messages.h new file mode 100644 index 00000000000..60b9a49479a --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/answer_messages.h @@ -0,0 +1,119 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STREAMING_ANSWER_MESSAGES_H_ +#define CAST_STREAMING_ANSWER_MESSAGES_H_ + +#include <array> +#include <chrono> // NOLINT +#include <cstdint> +#include <initializer_list> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "cast/streaming/offer_messages.h" +#include "cast/streaming/ssrc.h" +#include "json/value.h" +#include "platform/base/error.h" +#include "util/simple_fraction.h" + +namespace openscreen { +namespace cast { + +struct AudioConstraints { + int max_sample_rate = 0; + int max_channels = 0; + // Technically optional, sender will assume 32kbps if omitted. + int min_bit_rate = 0; + int max_bit_rate = 0; + std::chrono::milliseconds max_delay = {}; + + ErrorOr<Json::Value> ToJson() const; +}; + +struct Dimensions { + int width = 0; + int height = 0; + SimpleFraction frame_rate; + + ErrorOr<Json::Value> ToJson() const; +}; + +struct VideoConstraints { + double max_pixels_per_second = {}; + Dimensions min_dimensions = {}; + Dimensions max_dimensions = {}; + // Technically optional, sender will assume 300kbps if omitted. + int min_bit_rate = 0; + int max_bit_rate = 0; + std::chrono::milliseconds max_delay = {}; + + ErrorOr<Json::Value> ToJson() const; +}; + +struct Constraints { + AudioConstraints audio; + VideoConstraints video; + + ErrorOr<Json::Value> ToJson() const; +}; + +// Decides whether the Sender scales and letterboxes content to 16:9, or if +// it may send video frames of any arbitrary size and the Receiver must +// handle the presentation details. +enum class AspectRatioConstraint : uint8_t { kVariable = 0, kFixed }; + +struct AspectRatio { + int width = 0; + int height = 0; +}; + +struct DisplayDescription { + // May exceed, be the same, or less than those mentioned in the + // video constraints. + Dimensions dimensions; + AspectRatio aspect_ratio = {}; + AspectRatioConstraint aspect_ratio_constraint = {}; + + ErrorOr<Json::Value> ToJson() const; +}; + +struct Answer { + CastMode cast_mode = {}; + int udp_port = 0; + std::vector<int> send_indexes; + std::vector<Ssrc> ssrcs; + + // Constraints and display descriptions are optional fields, and maybe null in + // the valid case. + absl::optional<Constraints> constraints; + absl::optional<DisplayDescription> display; + std::vector<int> receiver_rtcp_event_log; + std::vector<int> receiver_rtcp_dscp; + bool supports_wifi_status_reporting = false; + + // RTP extensions should be empty, but not null. + std::vector<std::string> rtp_extensions = {}; + + // ToJson performs a standard serialization, returning an error if this + // instance failed to serialize properly. + ErrorOr<Json::Value> ToJson() const; + + // In constrast to ToJson, ToAnswerMessage performs a successful serialization + // even if the answer object is malformed, by complying to the spec's + // error answer message format in this case. + Json::Value ToAnswerMessage() const; +}; + +// Helper method that creates an invalid Answer response. Exposed publicly +// here as it is called in ToAnswerMessage(), but can also be called by +// the receiver session. +Json::Value CreateInvalidAnswer(Error error); + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STREAMING_ANSWER_MESSAGES_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/answer_messages_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/answer_messages_unittest.cc new file mode 100644 index 00000000000..d1c708281d9 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/answer_messages_unittest.cc @@ -0,0 +1,180 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/answer_messages.h" + +#include <utility> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "util/json/json_serialization.h" + +namespace openscreen { +namespace cast { + +namespace { + +const Answer kValidAnswer{ + CastMode{CastMode::Type::kMirroring}, + 1234, // udp_port + std::vector<int>{1, 2, 3}, // send_indexes + std::vector<Ssrc>{123, 456}, // ssrcs + Constraints{ + AudioConstraints{ + 96000, // max_sample_rate + 7, // max_channels + 32000, // min_bit_rate + 96000, // max_bit_rate + std::chrono::milliseconds(2000) // max_delay + }, // audio + VideoConstraints{ + 40000.0, // max_pixels_per_second + Dimensions{ + 320, // width + 480, // height + SimpleFraction{15000, 101} // frame_rate + }, // min_dimensions + Dimensions{ + 1920, // width + 1080, // height + SimpleFraction{288, 2} // frame_rate + }, + 300000, // min_bit_rate + 144000000, // max_bit_rate + std::chrono::milliseconds(3000) // max_delay + } // video + }, // constraints + DisplayDescription{ + Dimensions{ + 640, // width + 480, // height + SimpleFraction{30, 1} // frame_rate + }, + AspectRatio{16, 9}, // aspect_ratio + AspectRatioConstraint::kFixed, // scaling + }, + std::vector<int>{7, 8, 9}, // receiver_rtcp_event_log + std::vector<int>{11, 12, 13}, // receiver_rtcp_dscp + true, // receiver_get_status + std::vector<std::string>{"foo", "bar"} // rtp_extensions +}; + +} // anonymous namespace + +TEST(AnswerMessagesTest, ProperlyPopulatedAnswerSerializesProperly) { + auto value_or_error = kValidAnswer.ToJson(); + EXPECT_TRUE(value_or_error.is_value()); + + Json::Value root = std::move(value_or_error.value()); + EXPECT_EQ(root["castMode"], "mirroring"); + EXPECT_EQ(root["udpPort"], 1234); + + Json::Value sendIndexes = std::move(root["sendIndexes"]); + EXPECT_EQ(sendIndexes.type(), Json::ValueType::arrayValue); + EXPECT_EQ(sendIndexes[0], 1); + EXPECT_EQ(sendIndexes[1], 2); + EXPECT_EQ(sendIndexes[2], 3); + + Json::Value ssrcs = std::move(root["ssrcs"]); + EXPECT_EQ(ssrcs.type(), Json::ValueType::arrayValue); + EXPECT_EQ(ssrcs[0], 123u); + EXPECT_EQ(ssrcs[1], 456u); + + Json::Value constraints = std::move(root["constraints"]); + Json::Value audio = std::move(constraints["audio"]); + EXPECT_EQ(audio.type(), Json::ValueType::objectValue); + EXPECT_EQ(audio["maxSampleRate"], 96000); + EXPECT_EQ(audio["maxChannels"], 7); + EXPECT_EQ(audio["minBitRate"], 32000); + EXPECT_EQ(audio["maxBitRate"], 96000); + EXPECT_EQ(audio["maxDelay"], 2000); + + Json::Value video = std::move(constraints["video"]); + EXPECT_EQ(video.type(), Json::ValueType::objectValue); + EXPECT_EQ(video["maxPixelsPerSecond"], 40000.0); + EXPECT_EQ(video["minBitRate"], 300000); + EXPECT_EQ(video["maxBitRate"], 144000000); + EXPECT_EQ(video["maxDelay"], 3000); + + Json::Value min_dimensions = std::move(video["minDimensions"]); + EXPECT_EQ(min_dimensions.type(), Json::ValueType::objectValue); + EXPECT_EQ(min_dimensions["width"], 320); + EXPECT_EQ(min_dimensions["height"], 480); + EXPECT_EQ(min_dimensions["frameRate"], "15000/101"); + + Json::Value max_dimensions = std::move(video["maxDimensions"]); + EXPECT_EQ(max_dimensions.type(), Json::ValueType::objectValue); + EXPECT_EQ(max_dimensions["width"], 1920); + EXPECT_EQ(max_dimensions["height"], 1080); + EXPECT_EQ(max_dimensions["frameRate"], "288/2"); + + Json::Value display = std::move(root["display"]); + EXPECT_EQ(display.type(), Json::ValueType::objectValue); + EXPECT_EQ(display["aspectRatio"], "16:9"); + EXPECT_EQ(display["scaling"], "sender"); + + Json::Value dimensions = std::move(display["dimensions"]); + EXPECT_EQ(dimensions.type(), Json::ValueType::objectValue); + EXPECT_EQ(dimensions["width"], 640); + EXPECT_EQ(dimensions["height"], 480); + EXPECT_EQ(dimensions["frameRate"], "30"); + + Json::Value receiver_rtcp_event_log = std::move(root["receiverRtcpEventLog"]); + EXPECT_EQ(receiver_rtcp_event_log.type(), Json::ValueType::arrayValue); + EXPECT_EQ(receiver_rtcp_event_log[0], 7); + EXPECT_EQ(receiver_rtcp_event_log[1], 8); + EXPECT_EQ(receiver_rtcp_event_log[2], 9); + + Json::Value receiver_rtcp_dscp = std::move(root["receiverRtcpDscp"]); + EXPECT_EQ(receiver_rtcp_dscp.type(), Json::ValueType::arrayValue); + EXPECT_EQ(receiver_rtcp_dscp[0], 11); + EXPECT_EQ(receiver_rtcp_dscp[1], 12); + EXPECT_EQ(receiver_rtcp_dscp[2], 13); + + EXPECT_EQ(root["receiverGetStatus"], true); + + Json::Value rtp_extensions = std::move(root["rtpExtensions"]); + EXPECT_EQ(rtp_extensions.type(), Json::ValueType::arrayValue); + EXPECT_EQ(rtp_extensions[0], "foo"); + EXPECT_EQ(rtp_extensions[1], "bar"); +} + +TEST(AnswerMessagesTest, InvalidDimensionsCauseError) { + Answer invalid_dimensions = kValidAnswer; + invalid_dimensions.display.value().dimensions.width = -1; + auto value_or_error = invalid_dimensions.ToJson(); + EXPECT_TRUE(value_or_error.is_error()); +} + +TEST(AnswerMessagesTest, InvalidAudioConstraintsCauseError) { + Answer invalid_audio = kValidAnswer; + invalid_audio.constraints.value().audio.max_bit_rate = + invalid_audio.constraints.value().audio.min_bit_rate - 1; + auto value_or_error = invalid_audio.ToJson(); + EXPECT_TRUE(value_or_error.is_error()); +} + +TEST(AnswerMessagesTest, InvalidVideoConstraintsCauseError) { + Answer invalid_video = kValidAnswer; + invalid_video.constraints.value().video.max_pixels_per_second = -1.0; + auto value_or_error = invalid_video.ToJson(); + EXPECT_TRUE(value_or_error.is_error()); +} + +TEST(AnswerMessagesTest, InvalidDisplayDescriptionsCauseError) { + Answer invalid_display = kValidAnswer; + invalid_display.display.value().aspect_ratio = {0, 0}; + auto value_or_error = invalid_display.ToJson(); + EXPECT_TRUE(value_or_error.is_error()); +} + +TEST(AnswerMessagesTest, InvalidUdpPortsCauseError) { + Answer invalid_port = kValidAnswer; + invalid_port.udp_port = 65536; + auto value_or_error = invalid_port.ToJson(); + EXPECT_TRUE(value_or_error.is_error()); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.cc b/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.cc new file mode 100644 index 00000000000..e3386b0071a --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.cc @@ -0,0 +1,157 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/bandwidth_estimator.h" + +#include <algorithm> + +#include "util/logging.h" +#include "util/saturate_cast.h" + +namespace openscreen { +namespace cast { + +using openscreen::operator<<; // For std::chrono::duration logging. + +namespace { + +// Converts units from |bytes| per |time_window| number of Clock ticks into +// bits-per-second. +int ToClampedBitsPerSecond(int32_t bytes, Clock::duration time_window) { + OSP_DCHECK_GT(time_window, Clock::duration::zero()); + + // Divide |bytes| by |time_window| and scale the units to bits per second. + constexpr int64_t kBitsPerByte = 8; + constexpr int64_t kClockTicksPerSecond = + std::chrono::duration_cast<Clock::duration>(std::chrono::seconds(1)) + .count(); + const int64_t bits = bytes * kBitsPerByte; + const int64_t bits_per_second = + (bits * kClockTicksPerSecond) / time_window.count(); + return saturate_cast<int>(bits_per_second); +} + +} // namespace + +BandwidthEstimator::BandwidthEstimator(int max_packets_per_timeslice, + Clock::duration timeslice_duration, + Clock::time_point start_time) + : max_packets_per_history_window_(max_packets_per_timeslice * + kNumTimeslices), + history_window_(timeslice_duration * kNumTimeslices), + burst_history_(timeslice_duration, start_time), + feedback_history_(timeslice_duration, start_time) { + OSP_DCHECK_GT(max_packets_per_timeslice, 0); + OSP_DCHECK_GT(timeslice_duration, Clock::duration::zero()); +} + +BandwidthEstimator::~BandwidthEstimator() = default; + +void BandwidthEstimator::OnBurstComplete(int num_packets_sent, + Clock::time_point when) { + OSP_DCHECK_GE(num_packets_sent, 0); + burst_history_.Accumulate(num_packets_sent, when); +} + +void BandwidthEstimator::OnRtcpReceived( + Clock::time_point arrival_time, + Clock::duration estimated_round_trip_time) { + OSP_DCHECK_GE(estimated_round_trip_time, Clock::duration::zero()); + // Move forward the feedback history tracking timeline to include the latest + // moment a packet could have left the Sender. + feedback_history_.AdvanceToIncludeTime(arrival_time - + estimated_round_trip_time); +} + +void BandwidthEstimator::OnPayloadReceived( + int payload_bytes_acknowledged, + Clock::time_point ack_arrival_time, + Clock::duration estimated_round_trip_time) { + OSP_DCHECK_GE(payload_bytes_acknowledged, 0); + OSP_DCHECK_GE(estimated_round_trip_time, Clock::duration::zero()); + // Track the bytes in terms of when the last packet was sent. + feedback_history_.Accumulate(payload_bytes_acknowledged, + ack_arrival_time - estimated_round_trip_time); +} + +int BandwidthEstimator::ComputeNetworkBandwidth() const { + // Determine whether the |burst_history_| time window overlaps with the + // |feedback_history_| time window by at least half. The time windows don't + // have to overlap entirely because the calculations are averaging all the + // measurements (i.e., recent typical behavior). Though, they should overlap + // by "enough" so that the measurements correlate "enough." + const Clock::time_point overlap_begin = + std::max(burst_history_.begin_time(), feedback_history_.begin_time()); + const Clock::time_point overlap_end = + std::min(burst_history_.end_time(), feedback_history_.end_time()); + if ((overlap_end - overlap_begin) < (history_window_ / 2)) { + return 0; + } + + const int32_t num_packets_transmitted = burst_history_.Sum(); + if (num_packets_transmitted <= 0) { + // Cannot estimate because there have been no transmissions recently. + return 0; + } + const Clock::duration transmit_duration = history_window_ * + num_packets_transmitted / + max_packets_per_history_window_; + const int32_t num_bytes_received = feedback_history_.Sum(); + return ToClampedBitsPerSecond(num_bytes_received, transmit_duration); +} + +// static +constexpr int BandwidthEstimator::kNumTimeslices; + +BandwidthEstimator::FlowTracker::FlowTracker(Clock::duration timeslice_duration, + Clock::time_point begin_time) + : timeslice_duration_(timeslice_duration), begin_time_(begin_time) {} + +BandwidthEstimator::FlowTracker::~FlowTracker() = default; + +void BandwidthEstimator::FlowTracker::AdvanceToIncludeTime( + Clock::time_point until) { + if (until < end_time()) { + return; // Not advancing. + } + + // Step forward in time, at timeslice granularity. + const int64_t num_periods = 1 + (until - end_time()) / timeslice_duration_; + begin_time_ += num_periods * timeslice_duration_; + + // Shift the ring elements, discarding N oldest timeslices, and creating N new + // ones initialized to zero. + const int shift_count = std::min<int64_t>(num_periods, kNumTimeslices); + for (int i = 0; i < shift_count; ++i) { + history_ring_[tail_++] = 0; + } +} + +void BandwidthEstimator::FlowTracker::Accumulate(int32_t amount, + Clock::time_point when) { + if (when < begin_time_) { + return; // Ignore a data point that is already too old. + } + + AdvanceToIncludeTime(when); + + // Because of the AdvanceToIncludeTime() call just made, the offset/index + // calculations here are guaranteed to point to a valid element in the + // |history_ring_|. + const int64_t offset_from_first = (when - begin_time_) / timeslice_duration_; + const index_mod_256_t ring_index = tail_ + offset_from_first; + int32_t& timeslice = history_ring_[ring_index]; + timeslice = saturate_cast<int32_t>(int64_t{timeslice} + amount); +} + +int32_t BandwidthEstimator::FlowTracker::Sum() const { + int64_t result = 0; + for (int32_t amount : history_ring_) { + result += amount; + } + return saturate_cast<int32_t>(result); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.h b/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.h new file mode 100644 index 00000000000..f8ce7af2088 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.h @@ -0,0 +1,170 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STREAMING_BANDWIDTH_ESTIMATOR_H_ +#define CAST_STREAMING_BANDWIDTH_ESTIMATOR_H_ + +#include <stdint.h> + +#include <limits> + +#include "platform/api/time.h" + +namespace openscreen { +namespace cast { + +// Tracks send attempts and successful receives, and then computes a total +// network bandwith estimate. +// +// Two metrics are tracked by the BandwidthEstimator, over a "recent history" +// time window: +// +// 1. The number of packets sent during bursts (see SenderPacketRouter for +// explanation of what a "burst" is). These track when the network was +// actually in-use for transmission and the magnitude of each burst. When +// computing bandwidth, the estimator assumes the timeslices where the +// network was not in-use could have been used to send even more bytes at +// the same rate. +// +// 2. Successful receipt of payload bytes over time, or a lack thereof. +// Packets that include acknowledgements from the Receivers are providing +// proof of the successful receipt of payload bytes. All other packets +// provide proof of network connectivity over time, and are used to +// identify periods of time where nothing was received. +// +// The BandwidthEstimator assumes a simplified model for streaming over the +// network. The model does not include any detailed knowledge about things like +// protocol overhead, packet re-transmits, parasitic bufferring, network +// reliability, etc. Instead, it automatically accounts for all such things by +// looking at what's actually leaving the Senders and what's actually making it +// to the Receivers. +// +// This simplified model does produce some known inaccuracies in the resulting +// estimations. If no data has recently been transmitted (or been received), +// estimations cannot be provided. If the transmission rate is near (or +// exceeding) the network's capacity, the estimations will be very accurate. In +// between those two extremes, the logic will tend to under-estimate the +// network's capacity. However, those under-estimates will still be far larger +// than the current transmission rate. +// +// Thus, these estimates can be used effectively as a control signal for +// congestion control in upstream code modules. The logic computing the media's +// encoding target bitrate should be adjusted in realtime using a TCP-like +// congestion control algorithm: +// +// 1. When the estimated bitrate is less than the current encoding target +// bitrate, aggressively and immediately decrease the encoding bitrate. +// +// 2. When the estimated bitrate is more than the current encoding target +// bitrate, gradually increase the encoding bitrate (up to the maximum +// that is reasonable for the application). +class BandwidthEstimator { + public: + // |max_packets_per_timeslice| and |timeslice_duration| should match the burst + // configuration in SenderPacketRouter. |start_time| should be a recent + // point-in-time before the first packet is sent. + BandwidthEstimator(int max_packets_per_timeslice, + Clock::duration timeslice_duration, + Clock::time_point start_time); + + ~BandwidthEstimator(); + + // Returns the duration of the fixed, recent-history time window over which + // data flows are being tracked. + Clock::duration history_window() const { return history_window_; } + + // Records |when| burst-sending was active or inactive. For the active case, + // |num_packets_sent| should include all network packets sent, including + // non-payload packets (since both affect the modeled utilization/capacity). + // For the inactive case, this method should be called with zero for + // |num_packets_sent|. + void OnBurstComplete(int num_packets_sent, Clock::time_point when); + + // Records when a RTCP packet was received. It's important for Senders to call + // this any time a packet comes in from the Receivers, even if no payload is + // being acknowledged, since the time windows of "nothing successfully + // received" is also important information to track. + void OnRtcpReceived(Clock::time_point arrival_time, + Clock::duration estimated_round_trip_time); + + // Records that some number of payload bytes has been acknowledged (i.e., + // successfully received). + void OnPayloadReceived(int payload_bytes_acknowledged, + Clock::time_point ack_arrival_time, + Clock::duration estimated_round_trip_time); + + // Computes the current network bandwith estimate. Returns 0 if this cannot be + // determined due to a lack of sufficiently-recent data. + int ComputeNetworkBandwidth() const; + + private: + // FlowTracker (below) manages a ring buffer of size 256. It simplifies the + // index calculations to use an integer data type where all arithmetic is mod + // 256. + using index_mod_256_t = uint8_t; + static constexpr int kNumTimeslices = + static_cast<int>(std::numeric_limits<index_mod_256_t>::max()) + 1; + + // Tracks volume (e.g., the total number of payload bytes) over a fixed + // recent-history time window. The time window is divided up into a number of + // identical timeslices, each of which represents the total number of bytes + // that flowed during a certain period of time. The data is accumulated in + // ring buffer elements so that old data points drop-off as newer ones (that + // move the history window forward) are added. + class FlowTracker { + public: + FlowTracker(Clock::duration timeslice_duration, + Clock::time_point begin_time); + ~FlowTracker(); + + Clock::time_point begin_time() const { return begin_time_; } + Clock::time_point end_time() const { + return begin_time_ + timeslice_duration_ * kNumTimeslices; + } + + // Advance the end of the time window being tracked such that the + // most-recent timeslice includes |until|. Too-old timeslices are dropped + // and new ones are initialized to a zero amount. + void AdvanceToIncludeTime(Clock::time_point until); + + // Accumulate the given |amount| into the timeslice that includes |when|. + void Accumulate(int32_t amount, Clock::time_point when); + + // Return the sum of all the amounts in recent history. This clamps to the + // valid range of int32_t, if necessary. + int32_t Sum() const; + + private: + const Clock::duration timeslice_duration_; + + // The beginning of the oldest timeslice in the recent-history time window, + // the one pointed to by |tail_|. + Clock::time_point begin_time_; + + // A ring buffer tracking the accumulated amount for each timeslice. + int32_t history_ring_[kNumTimeslices]{}; + + // The index of the oldest timeslice in the |history_ring_|. This can also + // be thought of, equivalently, as the index just after the most-recent + // timeslice. + index_mod_256_t tail_ = 0; + }; + + // The maximum number of packet sends that could possibly be attempted during + // the recent-history time window. + const int max_packets_per_history_window_; + + // The range of time being tracked. + const Clock::duration history_window_; + + // History tracking for send attempts, and success feeback. These timeseries + // are in terms of when packets have left the Senders. + FlowTracker burst_history_; + FlowTracker feedback_history_; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STREAMING_BANDWIDTH_ESTIMATOR_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator_unittest.cc new file mode 100644 index 00000000000..d6d8570a0a5 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator_unittest.cc @@ -0,0 +1,232 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/bandwidth_estimator.h" + +#include <limits> +#include <random> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "platform/api/time.h" + +namespace openscreen { +namespace cast { +namespace { + +using std::chrono::duration_cast; +using std::chrono::milliseconds; +using std::chrono::seconds; + +using openscreen::operator<<; // For std::chrono::duration gtest pretty-print. + +// BandwidthEstimator configuration common to all tests. +constexpr int kMaxPacketsPerTimeslice = 10; +constexpr Clock::duration kTimesliceDuration = milliseconds(10); +constexpr int kTimeslicesPerSecond = seconds(1) / kTimesliceDuration; + +// Use a fake, fixed start time. +constexpr Clock::time_point kStartTime = + Clock::time_point() + Clock::duration(1234567890); + +// Range of "fuzz" to add to timestamps in BandwidthEstimatorTest::AddFuzz(). +constexpr Clock::duration kMaxFuzzOffset = milliseconds(15); +constexpr int kFuzzLowerBoundClockTicks = (-kMaxFuzzOffset).count(); +constexpr int kFuzzUpperBoundClockTicks = kMaxFuzzOffset.count(); + +class BandwidthEstimatorTest : public testing::Test { + public: + BandwidthEstimatorTest() + : estimator_(kMaxPacketsPerTimeslice, kTimesliceDuration, kStartTime) {} + + BandwidthEstimator* estimator() { return &estimator_; } + + // Returns |t| plus or minus |kMaxFuzzOffset|. + Clock::time_point AddFuzz(Clock::time_point t) { + return t + Clock::duration(distribution_(rand_)); + } + + private: + BandwidthEstimator estimator_; + + // These are used to generate random values for AddFuzz(). + static constexpr std::minstd_rand::result_type kRandSeed = + kStartTime.time_since_epoch().count(); + std::minstd_rand rand_{kRandSeed}; + std::uniform_int_distribution<int> distribution_{kFuzzLowerBoundClockTicks, + kFuzzUpperBoundClockTicks}; +}; + +// Tests that, without any data, there won't be any estimates. +TEST_F(BandwidthEstimatorTest, DoesNotEstimateWithoutAnyInput) { + EXPECT_EQ(0, estimator()->ComputeNetworkBandwidth()); +} + +// Tests the case where packets are being sent, but the Receiver hasn't provided +// feedback (e.g., due to a network blackout). +TEST_F(BandwidthEstimatorTest, DoesNotEstimateWithoutFeedback) { + Clock::time_point now = kStartTime; + for (int i = 0; i < 3; ++i) { + const Clock::time_point end = now + estimator()->history_window(); + for (; now < end; now += kTimesliceDuration) { + estimator()->OnBurstComplete(i, now); + EXPECT_EQ(0, estimator()->ComputeNetworkBandwidth()); + } + now = end; + } +} + +// Tests the case where packets are being sent, and a connection to the Receiver +// has been confirmed (because RTCP packets are coming in), but the Receiver has +// not successfully received anything. +TEST_F(BandwidthEstimatorTest, DoesNotEstimateIfNothingSuccessfullyReceived) { + const Clock::duration kRoundTripTime = milliseconds(1); + + Clock::time_point now = kStartTime; + for (int i = 0; i < 3; ++i) { + const Clock::time_point end = now + estimator()->history_window(); + for (; now < end; now += kTimesliceDuration) { + estimator()->OnBurstComplete(i, now); + estimator()->OnRtcpReceived(now + kRoundTripTime, kRoundTripTime); + EXPECT_EQ(0, estimator()->ComputeNetworkBandwidth()); + } + now = end; + } +} + +// Tests that, when the Receiver successfully receives the payload bytes at a +// fixed rate, the network bandwidth estimates are based on the amount of time +// the Sender spent transmitting. +TEST_F(BandwidthEstimatorTest, EstimatesAtVariousBurstSaturations) { + // These determine how many packets to burst in the simulation below. + constexpr int kDivisors[] = { + 1, // Burst 100% of max possible packets. + 2, // Burst 50% of max possible packets. + 5, // Burst 20% of max possible packets. + }; + + const Clock::duration kRoundTripTime = milliseconds(1); + + constexpr int kReceivedBytesPerSecond = 256000; + constexpr int kReceivedBytesPerTimeslice = + kReceivedBytesPerSecond / kTimeslicesPerSecond; + static_assert(kReceivedBytesPerSecond % kTimeslicesPerSecond == 0, + "Test expectations won't account for rounding errors."); + + ASSERT_EQ(0, estimator()->ComputeNetworkBandwidth()); + + // Simulate bursting at various rates, and confirm the bandwidth estimate is + // increasing for each burst rate. The estimate should be increasing because + // the total time spent transmitting is decreasing (while the same number of + // bytes are being received). + Clock::time_point now = kStartTime; + for (int divisor : kDivisors) { + SCOPED_TRACE(testing::Message() << "divisor=" << divisor); + + const Clock::time_point end = now + estimator()->history_window(); + for (; now < end; now += kTimesliceDuration) { + estimator()->OnBurstComplete(kMaxPacketsPerTimeslice / divisor, now); + const Clock::time_point rtcp_arrival_time = now + kRoundTripTime; + estimator()->OnPayloadReceived(kReceivedBytesPerTimeslice, + rtcp_arrival_time, kRoundTripTime); + estimator()->OnRtcpReceived(rtcp_arrival_time, kRoundTripTime); + } + now = end; + + const int estimate = estimator()->ComputeNetworkBandwidth(); + EXPECT_EQ(divisor * kReceivedBytesPerSecond * CHAR_BIT, estimate); + } +} + +// Tests that magnitude of the network round trip times, as well as random +// variance in packet arrival times, do not have a significant effect on the +// bandwidth estimates. +TEST_F(BandwidthEstimatorTest, EstimatesIndependentOfFeedbackDelays) { + constexpr int kFactor = 2; + constexpr int kPacketsPerBurst = kMaxPacketsPerTimeslice / kFactor; + static_assert(kMaxPacketsPerTimeslice % kFactor == 0, "wanted exactly half"); + + constexpr milliseconds kRoundTripTimes[3] = {milliseconds(1), milliseconds(9), + milliseconds(42)}; + + constexpr int kReceivedBytesPerSecond = 2000000; + constexpr int kReceivedBytesPerTimeslice = + kReceivedBytesPerSecond / kTimeslicesPerSecond; + + // An arbitrary threshold. Sources of error include anything that would place + // byte flows outside the history window (e.g., AddFuzz(), or the 42ms round + // trip time). + constexpr int kMaxErrorPercent = 3; + + Clock::time_point now = kStartTime; + for (Clock::duration round_trip_time : kRoundTripTimes) { + SCOPED_TRACE(testing::Message() + << "round_trip_time=" << round_trip_time.count()); + + const Clock::time_point end = now + estimator()->history_window(); + for (; now < end; now += kTimesliceDuration) { + estimator()->OnBurstComplete(kPacketsPerBurst, now); + const Clock::time_point rtcp_arrival_time = + AddFuzz(now + round_trip_time); + estimator()->OnPayloadReceived(kReceivedBytesPerTimeslice, + rtcp_arrival_time, round_trip_time); + estimator()->OnRtcpReceived(rtcp_arrival_time, round_trip_time); + } + now = end; + + constexpr int kExpectedEstimate = + kFactor * kReceivedBytesPerSecond * CHAR_BIT; + constexpr int kMaxError = kExpectedEstimate * kMaxErrorPercent / 100; + EXPECT_NEAR(kExpectedEstimate, estimator()->ComputeNetworkBandwidth(), + kMaxError); + } +} + +// Tests that both the history tracking internal to BandwidthEstimator, as well +// as its computation of the bandwidth estimate, are both resistant to possible +// integer overflow cases. The internal implementation always clamps to the +// valid range of int. +TEST_F(BandwidthEstimatorTest, ClampsEstimateToMaxInt) { + constexpr int kPacketsPerBurst = kMaxPacketsPerTimeslice / 5; + static_assert(kMaxPacketsPerTimeslice % 5 == 0, "wanted exactly 20%"); + const Clock::duration kRoundTripTime = milliseconds(1); + + int last_estimate = estimator()->ComputeNetworkBandwidth(); + ASSERT_EQ(last_estimate, 0); + + // Simulate increasing numbers of bytes received per timeslice until it + // reaches values near INT_MAX. Along the way, the bandwidth estimates + // themselves should start clamping and, because of the fuzz added to RTCP + // arrival times, individual buckets in BandwidthEstimator::FlowTracker will + // occassionally be clamped too. + Clock::time_point now = kStartTime; + for (int bytes_received_per_timeslice = 1; + bytes_received_per_timeslice > 0 /* not overflowed past INT_MAX */; + bytes_received_per_timeslice *= 2) { + SCOPED_TRACE(testing::Message() << "bytes_received_per_timeslice=" + << bytes_received_per_timeslice); + + const Clock::time_point end = now + estimator()->history_window() / 4; + for (; now < end; now += kTimesliceDuration) { + estimator()->OnBurstComplete(kPacketsPerBurst, now); + const Clock::time_point rtcp_arrival_time = AddFuzz(now + kRoundTripTime); + estimator()->OnPayloadReceived(bytes_received_per_timeslice, + rtcp_arrival_time, kRoundTripTime); + estimator()->OnRtcpReceived(rtcp_arrival_time, kRoundTripTime); + } + now = end; + + const int estimate = estimator()->ComputeNetworkBandwidth(); + EXPECT_LE(last_estimate, estimate); + last_estimate = estimate; + } + + // Confirm there was a loop iteration at which the estimate reached INT_MAX + // and then stayed there for successive loop iterations. + EXPECT_EQ(std::numeric_limits<int>::max(), last_estimate); +} + +} // namespace +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.cc b/chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.cc index bb829578c53..92072733657 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.cc @@ -7,13 +7,13 @@ #include <cmath> #include "util/logging.h" +#include "util/saturate_cast.h" +namespace openscreen { namespace cast { -namespace streaming { namespace { -constexpr ClockDriftSmoother::Clock::time_point kNullTime = - ClockDriftSmoother::Clock::time_point::min(); +constexpr Clock::time_point kNullTime = Clock::time_point::min(); } ClockDriftSmoother::ClockDriftSmoother(Clock::duration time_constant) @@ -25,15 +25,10 @@ ClockDriftSmoother::ClockDriftSmoother(Clock::duration time_constant) ClockDriftSmoother::~ClockDriftSmoother() = default; -ClockDriftSmoother::Clock::duration ClockDriftSmoother::Current() const { +Clock::duration ClockDriftSmoother::Current() const { OSP_DCHECK(last_update_time_ != kNullTime); - const double rounded_estimate = std::round(estimated_tick_offset_); - if (rounded_estimate < Clock::duration::min().count()) { - return Clock::duration::min(); - } else if (rounded_estimate > Clock::duration::max().count()) { - return Clock::duration::max(); - } - return Clock::duration(static_cast<Clock::duration::rep>(rounded_estimate)); + return Clock::duration( + rounded_saturate_cast<Clock::duration::rep>(estimated_tick_offset_)); } void ClockDriftSmoother::Reset(Clock::time_point now, @@ -68,5 +63,5 @@ void ClockDriftSmoother::Update(Clock::time_point now, // static constexpr std::chrono::seconds ClockDriftSmoother::kDefaultTimeConstant; -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.h b/chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.h index d48d166c6ce..97fe1ec601b 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.h +++ b/chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.h @@ -9,16 +9,14 @@ #include "platform/api/time.h" +namespace openscreen { namespace cast { -namespace streaming { // Tracks the jitter and drift between clocks, providing a smoothed offset. // Internally, a Simple IIR filter is used to maintain a running average that // moves at a rate based on the passage of time. class ClockDriftSmoother { public: - using Clock = openscreen::platform::Clock; - // |time_constant| is the amount of time an impulse signal takes to decay by // ~62.6%. Interpretation: If the value passed to several Update() calls is // held constant for T seconds, then the running average will have moved @@ -54,7 +52,7 @@ class ClockDriftSmoother { double estimated_tick_offset_; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_CLOCK_DRIFT_SMOOTHER_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.cc b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.cc index b3cba68b109..37625d051af 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.cc @@ -14,11 +14,8 @@ #include "util/logging.h" #include "util/std_util.h" -using openscreen::AreElementsSortedAndUnique; -using openscreen::platform::Clock; - +namespace openscreen { namespace cast { -namespace streaming { CompoundRtcpBuilder::CompoundRtcpBuilder(RtcpSession* session) : session_(session) { @@ -293,7 +290,7 @@ void CompoundRtcpBuilder::AppendCastFeedbackAckFields( // Compute how many additional octets are needed. constexpr int kIncrement = sizeof(uint32_t); const int num_additional = - openscreen::DividePositivesRoundingUp( + DividePositivesRoundingUp( (octet_index + 1) - num_ack_bitvector_octets, kIncrement) * kIncrement; @@ -328,5 +325,5 @@ void CompoundRtcpBuilder::AppendCastFeedbackAckFields( acks_for_next_packet_.clear(); } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.h b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.h index 58bc62fba5b..787a6e5231f 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.h +++ b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.h @@ -16,8 +16,8 @@ #include "cast/streaming/rtcp_common.h" #include "cast/streaming/rtp_defines.h" +namespace openscreen { namespace cast { -namespace streaming { class RtcpSession; @@ -97,9 +97,8 @@ class CompoundRtcpBuilder { // should be monotonically increasing so the consuming side (the Sender) can // determine the chronological ordering of RTCP packets. The Sender might also // use this to estimate round-trip times over the network. - absl::Span<uint8_t> BuildPacket( - openscreen::platform::Clock::time_point send_time, - absl::Span<uint8_t> buffer); + absl::Span<uint8_t> BuildPacket(Clock::time_point send_time, + absl::Span<uint8_t> buffer); // The required buffer size to be provided to BuildPacket(). This accounts for // all the possible headers and report structures that might be included, @@ -111,9 +110,8 @@ class CompoundRtcpBuilder { // Helper methods called by BuildPacket() to append one RTCP packet to the // |buffer| that will ultimately contain a "compound RTCP packet." void AppendReceiverReportPacket(absl::Span<uint8_t>* buffer); - void AppendReceiverReferenceTimeReportPacket( - openscreen::platform::Clock::time_point send_time, - absl::Span<uint8_t>* buffer); + void AppendReceiverReferenceTimeReportPacket(Clock::time_point send_time, + absl::Span<uint8_t>* buffer); void AppendPictureLossIndicatorPacket(absl::Span<uint8_t>* buffer); void AppendCastFeedbackPacket(absl::Span<uint8_t>* buffer); int AppendCastFeedbackLossFields(absl::Span<uint8_t>* buffer); @@ -122,7 +120,7 @@ class CompoundRtcpBuilder { RtcpSession* const session_; // Data to include in the next built RTCP packet. - FrameId checkpoint_frame_id_ = FrameId::first() - 1; + FrameId checkpoint_frame_id_ = FrameId::leader(); std::chrono::milliseconds playout_delay_ = kDefaultTargetPlayoutDelay; absl::optional<RtcpReportBlock> receiver_report_for_next_packet_; std::vector<PacketNack> nacks_for_next_packet_; @@ -134,7 +132,7 @@ class CompoundRtcpBuilder { uint8_t feedback_count_ = 0; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_COMPOUND_RTCP_BUILDER_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder_unittest.cc index ab6f8c00eab..969056cfa40 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder_unittest.cc @@ -15,16 +15,14 @@ #include "gtest/gtest.h" #include "platform/api/time.h" -using openscreen::platform::Clock; - using testing::_; using testing::Invoke; using testing::Mock; using testing::SaveArg; using testing::StrictMock; +namespace openscreen { namespace cast { -namespace streaming { namespace { constexpr Ssrc kSenderSsrc{1}; @@ -369,5 +367,5 @@ TEST_F(CompoundRtcpBuilderTest, WithEverythingThatCanFit) { } } // namespace -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.cc b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.cc index af5336ea71c..260989b0cb1 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.cc @@ -11,10 +11,8 @@ #include "util/logging.h" #include "util/std_util.h" -using openscreen::platform::Clock; - +namespace openscreen { namespace cast { -namespace streaming { namespace { @@ -185,7 +183,7 @@ bool CompoundRtcpParser::Parse(absl::Span<const uint8_t> buffer, client_->OnReceiverCheckpoint(checkpoint_frame_id, target_playout_delay); } if (!received_frames.empty()) { - OSP_DCHECK(openscreen::AreElementsSortedAndUnique(received_frames)); + OSP_DCHECK(AreElementsSortedAndUnique(received_frames)); client_->OnReceiverHasFrames(std::move(received_frames)); } CanonicalizePacketNackVector(&packet_nacks); @@ -338,7 +336,7 @@ bool CompoundRtcpParser::ParseExtendedReports( return false; // Length field must always be 2 words. } *receiver_reference_time = session_->ntp_converter().ToLocalTime( - openscreen::ReadBigEndian<uint64_t>(in.data())); + ReadBigEndian<uint64_t>(in.data())); } else { // Ignore any other type of extended report. } @@ -377,5 +375,5 @@ void CompoundRtcpParser::Client::OnReceiverHasFrames( void CompoundRtcpParser::Client::OnReceiverIsMissingPackets( std::vector<PacketNack> nacks) {} -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.h b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.h index c9b71594a7a..c74bb3ec010 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.h +++ b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.h @@ -14,8 +14,8 @@ #include "cast/streaming/rtcp_common.h" #include "cast/streaming/rtp_defines.h" +namespace openscreen { namespace cast { -namespace streaming { class RtcpSession; @@ -41,7 +41,7 @@ class CompoundRtcpParser { // Called when a Receiver Reference Time Report has been parsed. virtual void OnReceiverReferenceTimeAdvanced( - openscreen::platform::Clock::time_point reference_time); + Clock::time_point reference_time); // Called when a Receiver Report with a Report Block has been parsed. virtual void OnReceiverReport(const RtcpReportBlock& receiver_report); @@ -101,9 +101,8 @@ class CompoundRtcpParser { std::chrono::milliseconds* target_playout_delay, std::vector<FrameId>* received_frames, std::vector<PacketNack>* packet_nacks); - bool ParseExtendedReports( - absl::Span<const uint8_t> in, - openscreen::platform::Clock::time_point* receiver_reference_time); + bool ParseExtendedReports(absl::Span<const uint8_t> in, + Clock::time_point* receiver_reference_time); bool ParsePictureLossIndicator(absl::Span<const uint8_t> in, bool* picture_loss_indicator); @@ -113,10 +112,10 @@ class CompoundRtcpParser { // Tracks the latest timestamp seen from any Receiver Reference Time Report, // and uses this to ignore stale RTCP packets that arrived out-of-order and/or // late from the network. - openscreen::platform::Clock::time_point latest_receiver_timestamp_; + Clock::time_point latest_receiver_timestamp_; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_COMPOUND_RTCP_PARSER_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_fuzzer.cc b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_fuzzer.cc index af5823df945..bb3dd179b4b 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_fuzzer.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_fuzzer.cc @@ -9,10 +9,10 @@ #include "cast/streaming/rtcp_session.h" extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { - using cast::streaming::CompoundRtcpParser; - using cast::streaming::FrameId; - using cast::streaming::RtcpSession; - using cast::streaming::Ssrc; + using openscreen::cast::CompoundRtcpParser; + using openscreen::cast::FrameId; + using openscreen::cast::RtcpSession; + using openscreen::cast::Ssrc; constexpr Ssrc kSenderSsrcInSeedCorpus = 1; constexpr Ssrc kReceiverSsrcInSeedCorpus = 2; @@ -25,7 +25,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wexit-time-destructors" static RtcpSession session(kSenderSsrcInSeedCorpus, kReceiverSsrcInSeedCorpus, - openscreen::platform::Clock::time_point{}); + openscreen::Clock::time_point{}); static CompoundRtcpParser::Client client_that_ignores_everything; static CompoundRtcpParser parser(&session, &client_that_ignores_everything); #pragma clang diagnostic pop diff --git a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_unittest.cc index 863953afad5..9f8e3c50988 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_unittest.cc @@ -18,8 +18,8 @@ using testing::Mock; using testing::SaveArg; using testing::StrictMock; +namespace openscreen { namespace cast { -namespace streaming { namespace { constexpr Ssrc kSenderSsrc{1}; @@ -32,8 +32,7 @@ class CompoundRtcpParserTest : public testing::Test { CompoundRtcpParser* parser() { return &parser_; } private: - RtcpSession session_{kSenderSsrc, kReceiverSsrc, - openscreen::platform::Clock::now()}; + RtcpSession session_{kSenderSsrc, kReceiverSsrc, Clock::now()}; StrictMock<MockCompoundRtcpParserClient> client_; CompoundRtcpParser parser_{&session_, &client_}; }; @@ -404,5 +403,5 @@ TEST_F(CompoundRtcpParserTest, ParsesFeedbackWithAcks) { } } // namespace -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/constants.h b/chromium/third_party/openscreen/src/cast/streaming/constants.h index d67ba3e9164..4a8d526e072 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/constants.h +++ b/chromium/third_party/openscreen/src/cast/streaming/constants.h @@ -14,13 +14,23 @@ #include <chrono> #include <ratio> +namespace openscreen { namespace cast { -namespace streaming { // Default target playout delay. The playout delay is the window of time between // capture from the source until presentation at the receiver. constexpr std::chrono::milliseconds kDefaultTargetPlayoutDelay(400); +// Default UDP port, bound at the Receiver, for Cast Streaming. An +// implementation is required to use the port specified by the Receiver in its +// ANSWER control message, which may or may not match this port number here. +constexpr int kDefaultCastStreamingPort = 2344; + +// Default TCP port, bound at the TLS server socket level, for Cast Streaming. +// An implementation must use the port specified in the DNS-SD published record +// for connecting over TLS, which may or may not match this port number here. +constexpr int kDefaultCastPort = 8010; + // Target number of milliseconds between the sending of RTCP reports. Both // senders and receivers regularly send RTCP reports to their peer. constexpr std::chrono::milliseconds kRtcpReportInterval(500); @@ -34,11 +44,14 @@ constexpr std::chrono::milliseconds kRtcpReportInterval(500); // logic can handle wrap around and compare two frame IDs meaningfully. constexpr int kMaxUnackedFrames = 120; +// The network must support a packet size of at least this many bytes. +constexpr int kRequiredNetworkPacketSize = 256; + // The spec declares RTP timestamps must always have a timebase of 90000 ticks // per second for video. using kVideoTimebase = std::ratio<1, 90000>; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_CONSTANTS_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/encoded_frame.cc b/chromium/third_party/openscreen/src/cast/streaming/encoded_frame.cc index ebd6cc753f9..4fb832bd1aa 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/encoded_frame.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/encoded_frame.cc @@ -4,8 +4,8 @@ #include "cast/streaming/encoded_frame.h" +namespace openscreen { namespace cast { -namespace streaming { EncodedFrame::EncodedFrame() = default; EncodedFrame::~EncodedFrame() = default; @@ -22,5 +22,5 @@ void EncodedFrame::CopyMetadataTo(EncodedFrame* dest) const { dest->new_playout_delay = this->new_playout_delay; } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/encoded_frame.h b/chromium/third_party/openscreen/src/cast/streaming/encoded_frame.h index b42b6340d88..40fbda81ca1 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/encoded_frame.h +++ b/chromium/third_party/openscreen/src/cast/streaming/encoded_frame.h @@ -16,8 +16,8 @@ #include "platform/api/time.h" #include "platform/base/macros.h" +namespace openscreen { namespace cast { -namespace streaming { // A combination of metadata and data for one encoded frame. This can contain // audio data or video data or other. @@ -78,7 +78,7 @@ struct EncodedFrame { // (see |rtp_timestamp|, above). It is also meant to be used to synchronize // the presentation of multiple streams (e.g., audio and video), commonly // known as "lip-sync." It is NOT meant to be a mandatory/exact playout time. - openscreen::platform::Clock::time_point reference_time; + Clock::time_point reference_time; // Playout delay for this and all future frames. Used by the Adaptive // Playout delay extension. Non-positive values means no change. @@ -93,7 +93,7 @@ struct EncodedFrame { OSP_DISALLOW_COPY_AND_ASSIGN(EncodedFrame); }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_ENCODED_FRAME_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/environment.cc b/chromium/third_party/openscreen/src/cast/streaming/environment.cc index 954a2e0c535..b3a6c7b1673 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/environment.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/environment.cc @@ -8,18 +8,8 @@ #include "platform/api/task_runner.h" #include "util/logging.h" -using openscreen::Error; -using openscreen::ErrorOr; -using openscreen::IPAddress; -using openscreen::IPEndpoint; -using openscreen::platform::Clock; -using openscreen::platform::ClockNowFunctionPtr; -using openscreen::platform::TaskRunner; -using openscreen::platform::UdpPacket; -using openscreen::platform::UdpSocket; - +namespace openscreen { namespace cast { -namespace streaming { Environment::Environment(ClockNowFunctionPtr now_function, TaskRunner* task_runner) @@ -124,10 +114,10 @@ void Environment::OnRead(UdpSocket* socket, // Ideally, the arrival time would come from the operating system's network // stack (e.g., by using the SO_TIMESTAMP sockopt on POSIX systems). However, // there would still be the problem of mapping the timestamp to a value in - // terms of platform::Clock. So, just sample the Clock here and call that the - // "arrival time." While this can add variance within the system, it should be - // minimal, assuming not too much time has elapsed between the actual packet - // receive event and the when this code here is executing. + // terms of Clock::time_point. So, just sample the Clock here and call that + // the "arrival time." While this can add variance within the system, it + // should be minimal, assuming not too much time has elapsed between the + // actual packet receive event and the when this code here is executing. const Clock::time_point arrival_time = now_function_(); UdpPacket packet = std::move(packet_or_error.value()); @@ -136,5 +126,5 @@ void Environment::OnRead(UdpSocket* socket, std::move(static_cast<std::vector<uint8_t>&>(packet))); } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/environment.h b/chromium/third_party/openscreen/src/cast/streaming/environment.h index 0bc99f2db58..278604b0d0f 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/environment.h +++ b/chromium/third_party/openscreen/src/cast/streaming/environment.h @@ -16,24 +16,17 @@ #include "platform/base/ip_address.h" namespace openscreen { -namespace platform { -class TaskRunner; -} // namespace platform -} // namespace openscreen - namespace cast { -namespace streaming { // Provides the common environment for operating system resources shared by // multiple components. -class Environment : public openscreen::platform::UdpSocket::Client { +class Environment : public UdpSocket::Client { public: class PacketConsumer { public: - virtual void OnReceivedPacket( - const openscreen::IPEndpoint& source, - openscreen::platform::Clock::time_point arrival_time, - std::vector<uint8_t> packet) = 0; + virtual void OnReceivedPacket(const IPEndpoint& source, + Clock::time_point arrival_time, + std::vector<uint8_t> packet) = 0; protected: virtual ~PacketConsumer(); @@ -42,35 +35,34 @@ class Environment : public openscreen::platform::UdpSocket::Client { // Construct with the given clock source and TaskRunner. Creates and // internally-owns a UdpSocket, and immediately binds it to the given // |local_endpoint|. - Environment(openscreen::platform::ClockNowFunctionPtr now_function, - openscreen::platform::TaskRunner* task_runner, - const openscreen::IPEndpoint& local_endpoint); + Environment(ClockNowFunctionPtr now_function, + TaskRunner* task_runner, + const IPEndpoint& local_endpoint); ~Environment() override; - openscreen::platform::ClockNowFunctionPtr now_function() const { - return now_function_; - } - openscreen::platform::TaskRunner* task_runner() const { return task_runner_; } + ClockNowFunctionPtr now_function() const { return now_function_; } + Clock::time_point now() const { return now_function_(); } + TaskRunner* task_runner() const { return task_runner_; } // Returns the local endpoint the socket is bound to, or the zero IPEndpoint // if socket creation/binding failed. - openscreen::IPEndpoint GetBoundLocalEndpoint() const; + // + // Note: This method is virtual to allow unit tests to fake that there really + // is a bound socket. + virtual IPEndpoint GetBoundLocalEndpoint() const; // Set a handler function to run whenever non-recoverable socket errors occur. // If never set, the default is to emit log messages at error priority. - void set_socket_error_handler( - std::function<void(openscreen::Error)> handler) { + void set_socket_error_handler(std::function<void(Error)> handler) { socket_error_handler_ = handler; } // Get/Set the remote endpoint. This is separate from the constructor because // the remote endpoint is, in some cases, discovered only after receiving a // packet. - const openscreen::IPEndpoint& remote_endpoint() const { - return remote_endpoint_; - } - void set_remote_endpoint(const openscreen::IPEndpoint& endpoint) { + const IPEndpoint& remote_endpoint() const { return remote_endpoint_; } + void set_remote_endpoint(const IPEndpoint& endpoint) { remote_endpoint_ = endpoint; } @@ -98,34 +90,29 @@ class Environment : public openscreen::platform::UdpSocket::Client { // Common constructor that just stores the injected dependencies and does not // create a socket. Subclasses use this to provide an alternative packet // receive/send mechanism (e.g., for testing). - Environment(openscreen::platform::ClockNowFunctionPtr now_function, - openscreen::platform::TaskRunner* task_runner); + Environment(ClockNowFunctionPtr now_function, TaskRunner* task_runner); private: - // openscreen::platform::UdpSocket::Client implementation. - void OnError(openscreen::platform::UdpSocket* socket, - openscreen::Error error) final; - void OnSendError(openscreen::platform::UdpSocket* socket, - openscreen::Error error) final; - void OnRead(openscreen::platform::UdpSocket* socket, - openscreen::ErrorOr<openscreen::platform::UdpPacket> - packet_or_error) final; - - const openscreen::platform::ClockNowFunctionPtr now_function_; - openscreen::platform::TaskRunner* const task_runner_; + // UdpSocket::Client implementation. + void OnError(UdpSocket* socket, Error error) final; + void OnSendError(UdpSocket* socket, Error error) final; + void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet_or_error) final; + + const ClockNowFunctionPtr now_function_; + TaskRunner* const task_runner_; // The UDP socket bound to the local endpoint that was passed into the // constructor, or null if socket creation failed. - const std::unique_ptr<openscreen::platform::UdpSocket> socket_; + const std::unique_ptr<UdpSocket> socket_; // These are externally set/cleared. Behaviors are described in getter/setter // method comments above. - std::function<void(openscreen::Error)> socket_error_handler_; - openscreen::IPEndpoint remote_endpoint_{}; + std::function<void(Error)> socket_error_handler_; + IPEndpoint remote_endpoint_{}; PacketConsumer* packet_consumer_ = nullptr; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_ENVIRONMENT_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/expanded_value_base.h b/chromium/third_party/openscreen/src/cast/streaming/expanded_value_base.h index 0ac0dbebc8d..d418df647df 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/expanded_value_base.h +++ b/chromium/third_party/openscreen/src/cast/streaming/expanded_value_base.h @@ -11,8 +11,8 @@ #include "util/logging.h" +namespace openscreen { namespace cast { -namespace streaming { // Abstract base template class for common "sequence value" data types such as // RtpTimeTicks, FrameId, or PacketId which generally increment/decrement in @@ -158,7 +158,7 @@ class ExpandedValueBase { FullWidthInteger value_; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_EXPANDED_VALUE_BASE_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/expanded_value_base_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/expanded_value_base_unittest.cc index 220af2215a5..bcbf0d9f8f3 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/expanded_value_base_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/expanded_value_base_unittest.cc @@ -6,8 +6,8 @@ #include "gtest/gtest.h" +namespace openscreen { namespace cast { -namespace streaming { namespace { @@ -104,5 +104,5 @@ TEST(ExpandedValueBaseTest, TruncationAndExpansion) { } } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/frame_collector.cc b/chromium/third_party/openscreen/src/cast/streaming/frame_collector.cc index aabbdd06738..dbd6fb4db9c 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/frame_collector.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/frame_collector.cc @@ -12,8 +12,8 @@ #include "cast/streaming/rtp_defines.h" #include "util/logging.h" +namespace openscreen { namespace cast { -namespace streaming { namespace { @@ -156,5 +156,5 @@ void FrameCollector::Reset() { FrameCollector::PayloadChunk::PayloadChunk() = default; FrameCollector::PayloadChunk::~PayloadChunk() = default; -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/frame_collector.h b/chromium/third_party/openscreen/src/cast/streaming/frame_collector.h index 7504d690958..ca68ab31b0b 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/frame_collector.h +++ b/chromium/third_party/openscreen/src/cast/streaming/frame_collector.h @@ -13,8 +13,8 @@ #include "cast/streaming/rtcp_common.h" #include "cast/streaming/rtp_packet_parser.h" +namespace openscreen { namespace cast { -namespace streaming { // Used by a Receiver to collect the parts of a frame, track what is // missing/complete, and assemble a complete frame. @@ -84,7 +84,7 @@ class FrameCollector { std::vector<PayloadChunk> chunks_; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_FRAME_COLLECTOR_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/frame_collector_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/frame_collector_unittest.cc index 53e13811464..34448c24af4 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/frame_collector_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/frame_collector_unittest.cc @@ -15,8 +15,8 @@ #include "cast/streaming/rtp_time.h" #include "gtest/gtest.h" +namespace openscreen { namespace cast { -namespace streaming { namespace { const FrameId kSomeFrameId = FrameId::first() + 39; @@ -207,5 +207,5 @@ TEST(FrameCollectorTest, RejectsInvalidParts) { } } // namespace -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/frame_crypto.cc b/chromium/third_party/openscreen/src/cast/streaming/frame_crypto.cc index 416c2df1235..4d506d5eb7e 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/frame_crypto.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/frame_crypto.cc @@ -12,8 +12,8 @@ #include "util/big_endian.h" #include "util/crypto/openssl_util.h" +namespace openscreen { namespace cast { -namespace streaming { EncryptedFrame::EncryptedFrame() { data = absl::Span<uint8_t>(owned_data_); @@ -91,8 +91,7 @@ void FrameCrypto::EncryptCommon(FrameId frame_id, std::array<uint8_t, 16> aes_nonce{/* zero initialized */}; static_assert(AES_BLOCK_SIZE == sizeof(aes_nonce), "AES_BLOCK_SIZE is not 16 bytes."); - openscreen::WriteBigEndian<uint32_t>(frame_id.lower_32_bits(), - aes_nonce.data() + 8); + WriteBigEndian<uint32_t>(frame_id.lower_32_bits(), aes_nonce.data() + 8); for (size_t i = 0; i < aes_nonce.size(); ++i) { aes_nonce[i] ^= cast_iv_mask_[i]; } @@ -116,5 +115,5 @@ std::array<uint8_t, 16> FrameCrypto::GenerateRandomBytes() { return result; } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/frame_crypto.h b/chromium/third_party/openscreen/src/cast/streaming/frame_crypto.h index 693385ad311..35ee9787cb1 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/frame_crypto.h +++ b/chromium/third_party/openscreen/src/cast/streaming/frame_crypto.h @@ -16,8 +16,8 @@ #include "openssl/aes.h" #include "platform/base/macros.h" +namespace openscreen { namespace cast { -namespace streaming { class FrameCollector; class FrameCrypto; @@ -91,7 +91,7 @@ class FrameCrypto { absl::Span<uint8_t> out) const; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_FRAME_CRYPTO_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/frame_crypto_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/frame_crypto_unittest.cc index c5fcd400f25..7ac8d5cf2e5 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/frame_crypto_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/frame_crypto_unittest.cc @@ -10,8 +10,8 @@ #include "gtest/gtest.h" +namespace openscreen { namespace cast { -namespace streaming { namespace { TEST(FrameCryptoTest, EncryptsAndDecryptsFrames) { @@ -76,5 +76,5 @@ TEST(FrameCryptoTest, EncryptsAndDecryptsFrames) { } } // namespace -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/frame_id.cc b/chromium/third_party/openscreen/src/cast/streaming/frame_id.cc index bf061fe8e12..3dd4951852b 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/frame_id.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/frame_id.cc @@ -4,8 +4,8 @@ #include "cast/streaming/frame_id.h" +namespace openscreen { namespace cast { -namespace streaming { std::ostream& operator<<(std::ostream& out, const FrameId rhs) { out << "F"; @@ -14,5 +14,5 @@ std::ostream& operator<<(std::ostream& out, const FrameId rhs) { return out << rhs.value(); } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/frame_id.h b/chromium/third_party/openscreen/src/cast/streaming/frame_id.h index 3625242257a..d741a289b6d 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/frame_id.h +++ b/chromium/third_party/openscreen/src/cast/streaming/frame_id.h @@ -11,8 +11,8 @@ #include "cast/streaming/expanded_value_base.h" +namespace openscreen { namespace cast { -namespace streaming { // Forward declaration (see below). class FrameId; @@ -94,6 +94,16 @@ class FrameId : public ExpandedValueBase<int64_t, FrameId> { // The identifier for the first frame in a stream. static constexpr FrameId first() { return FrameId(0); } + // A virtual identifier, representing the frame before the first. There should + // never actually be a frame streamed with this identifier. Instead, this is + // used in various components to represent a "not yet seen/processed the first + // frame" state. + // + // The name "leader" comes from the terminology used in tape reels, which + // refers to the non-data-carrying segment of tape before the recording + // begins. + static constexpr FrameId leader() { return FrameId(-1); } + private: friend class ExpandedValueBase<int64_t, FrameId>; friend std::ostream& operator<<(std::ostream& out, const FrameId rhs); @@ -104,7 +114,7 @@ class FrameId : public ExpandedValueBase<int64_t, FrameId> { constexpr int64_t value() const { return value_; } }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_FRAME_ID_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/message_port.h b/chromium/third_party/openscreen/src/cast/streaming/message_port.h new file mode 100644 index 00000000000..f44a808dbf1 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/message_port.h @@ -0,0 +1,38 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STREAMING_MESSAGE_PORT_H_ +#define CAST_STREAMING_MESSAGE_PORT_H_ + +#include "absl/strings/string_view.h" +#include "platform/base/error.h" + +namespace openscreen { +namespace cast { + +// This interface is intended to provide an abstraction for communicating +// cast messages across a pipe with guaranteed delivery. This is used to +// decouple the cast receiver session (and potentially other classes) from any +// network implementation. +class MessagePort { + public: + class Client { + public: + virtual void OnMessage(absl::string_view sender_id, + absl::string_view message_namespace, + absl::string_view message) = 0; + virtual void OnError(Error error) = 0; + }; + + virtual ~MessagePort() = default; + virtual void SetClient(Client* client) = 0; + virtual void PostMessage(absl::string_view sender_id, + absl::string_view message_namespace, + absl::string_view message) = 0; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STREAMING_MESSAGE_PORT_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/message_util.h b/chromium/third_party/openscreen/src/cast/streaming/message_util.h new file mode 100644 index 00000000000..fbb3b21199f --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/message_util.h @@ -0,0 +1,77 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STREAMING_MESSAGE_UTIL_H_ +#define CAST_STREAMING_MESSAGE_UTIL_H_ + +#include <vector> + +#include "absl/strings/string_view.h" +#include "json/value.h" +#include "platform/base/error.h" + +// This file contains helper methods that are used by both answer and offer +// messages, but should not be publicly exposed/consumed. +namespace openscreen { +namespace cast { + +inline Error CreateParseError(const std::string& type) { + return Error(Error::Code::kJsonParseError, "Failed to parse " + type); +} + +inline Error CreateParameterError(const std::string& type) { + return Error(Error::Code::kParameterInvalid, "Invalid parameter: " + type); +} + +inline ErrorOr<bool> ParseBool(const Json::Value& parent, + const std::string& field) { + const Json::Value& value = parent[field]; + if (!value.isBool()) { + return CreateParseError("bool field " + field); + } + return value.asBool(); +} + +inline ErrorOr<int> ParseInt(const Json::Value& parent, + const std::string& field) { + const Json::Value& value = parent[field]; + if (!value.isInt()) { + return CreateParseError("integer field: " + field); + } + return value.asInt(); +} + +inline ErrorOr<uint32_t> ParseUint(const Json::Value& parent, + const std::string& field) { + const Json::Value& value = parent[field]; + if (!value.isUInt()) { + return CreateParseError("unsigned integer field: " + field); + } + return value.asUInt(); +} + +inline ErrorOr<std::string> ParseString(const Json::Value& parent, + const std::string& field) { + const Json::Value& value = parent[field]; + if (!value.isString()) { + return CreateParseError("string field: " + field); + } + return value.asString(); +} + +// TODO(jophba): refactor to be on ErrorOr itself. +// Use this template for parsing only when there is a reasonable default +// for the type you are using, e.g. int or std::string. +template <typename T> +T ValueOrDefault(const ErrorOr<T>& value, T fallback = T{}) { + if (value.is_value()) { + return value.value(); + } + return fallback; +} + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STREAMING_MESSAGE_UTIL_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/mock_compound_rtcp_parser_client.h b/chromium/third_party/openscreen/src/cast/streaming/mock_compound_rtcp_parser_client.h index 0b16e599628..93d93f6762e 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/mock_compound_rtcp_parser_client.h +++ b/chromium/third_party/openscreen/src/cast/streaming/mock_compound_rtcp_parser_client.h @@ -8,13 +8,13 @@ #include "cast/streaming/compound_rtcp_parser.h" #include "gmock/gmock.h" +namespace openscreen { namespace cast { -namespace streaming { class MockCompoundRtcpParserClient : public CompoundRtcpParser::Client { public: MOCK_METHOD1(OnReceiverReferenceTimeAdvanced, - void(openscreen::platform::Clock::time_point reference_time)); + void(Clock::time_point reference_time)); MOCK_METHOD1(OnReceiverReport, void(const RtcpReportBlock& receiver_report)); MOCK_METHOD0(OnReceiverIndicatesPictureLoss, void()); MOCK_METHOD2(OnReceiverCheckpoint, @@ -23,7 +23,7 @@ class MockCompoundRtcpParserClient : public CompoundRtcpParser::Client { MOCK_METHOD1(OnReceiverIsMissingPackets, void(std::vector<PacketNack> nacks)); }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_MOCK_COMPOUND_RTCP_PARSER_CLIENT_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/mock_environment.cc b/chromium/third_party/openscreen/src/cast/streaming/mock_environment.cc new file mode 100644 index 00000000000..9d712537ccd --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/mock_environment.cc @@ -0,0 +1,17 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/mock_environment.h" + +namespace openscreen { +namespace cast { + +MockEnvironment::MockEnvironment(ClockNowFunctionPtr now_function, + TaskRunner* task_runner) + : Environment(now_function, task_runner) {} + +MockEnvironment::~MockEnvironment() = default; + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/mock_environment.h b/chromium/third_party/openscreen/src/cast/streaming/mock_environment.h new file mode 100644 index 00000000000..db66b141e31 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/mock_environment.h @@ -0,0 +1,30 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STREAMING_MOCK_ENVIRONMENT_H_ +#define CAST_STREAMING_MOCK_ENVIRONMENT_H_ + +#include "cast/streaming/environment.h" +#include "gmock/gmock.h" + +namespace openscreen { +namespace cast { + +// An Environment that can intercept all packet sends, for unit testing. +class MockEnvironment : public Environment { + public: + MockEnvironment(ClockNowFunctionPtr now_function, TaskRunner* task_runner); + ~MockEnvironment() override; + + // Used to return fake values, to simulate a bound socket for testing. + MOCK_METHOD(IPEndpoint, GetBoundLocalEndpoint, (), (const, override)); + + // Used for intercepting packet sends from the implementation under test. + MOCK_METHOD(void, SendPacket, (absl::Span<const uint8_t> packet), (override)); +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STREAMING_MOCK_ENVIRONMENT_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/ntp_time.cc b/chromium/third_party/openscreen/src/cast/streaming/ntp_time.cc index d7072baf03e..0d59c0a8d4b 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/ntp_time.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/ntp_time.cc @@ -6,11 +6,10 @@ #include "util/logging.h" -using openscreen::platform::Clock; using std::chrono::duration_cast; +namespace openscreen { namespace cast { -namespace streaming { namespace { @@ -54,5 +53,5 @@ Clock::time_point NtpTimeConverter::ToLocalTime(NtpTimestamp timestamp) const { return seconds_since_start + remainder; } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/ntp_time.h b/chromium/third_party/openscreen/src/cast/streaming/ntp_time.h index f76dae4abe9..6b13132f8a9 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/ntp_time.h +++ b/chromium/third_party/openscreen/src/cast/streaming/ntp_time.h @@ -9,8 +9,8 @@ #include "platform/api/time.h" +namespace openscreen { namespace cast { -namespace streaming { // NTP timestamps are 64-bit timestamps that consist of two 32-bit parts: 1) The // number of seconds since 1 January 1900; and 2) The fraction of the second, @@ -41,24 +41,21 @@ constexpr NtpTimestamp AssembleNtpTimestamp(NtpSeconds seconds, static_cast<uint32_t>(fraction.count()); } -// Converts between openscreen::platform::Clock::time_points and NtpTimestamps. -// The class is instantiated with the current openscreen::platform::Clock time -// and the current wall clock time, and these are used to determine a fixed -// origin reference point for all conversions. Thus, to avoid introducing -// unintended timing-related behaviors, only one NtpTimeConverter instance -// should be used for converting all the NTP timestamps in the same streaming -// session. +// Converts between Clock::time_points and NtpTimestamps. The class is +// instantiated with the current Clock time and the current wall clock time, and +// these are used to determine a fixed origin reference point for all +// conversions. Thus, to avoid introducing unintended timing-related behaviors, +// only one NtpTimeConverter instance should be used for converting all the NTP +// timestamps in the same streaming session. class NtpTimeConverter { public: - NtpTimeConverter(openscreen::platform::Clock::time_point now, - std::chrono::seconds since_unix_epoch = - openscreen::platform::GetWallTimeSinceUnixEpoch()); + NtpTimeConverter( + Clock::time_point now, + std::chrono::seconds since_unix_epoch = GetWallTimeSinceUnixEpoch()); ~NtpTimeConverter(); - NtpTimestamp ToNtpTimestamp( - openscreen::platform::Clock::time_point time_point) const; - openscreen::platform::Clock::time_point ToLocalTime( - NtpTimestamp timestamp) const; + NtpTimestamp ToNtpTimestamp(Clock::time_point time_point) const; + Clock::time_point ToLocalTime(NtpTimestamp timestamp) const; private: // The time point on the platform clock's timeline that corresponds to @@ -68,11 +65,11 @@ class NtpTimeConverter { // can be off (with respect to each other) by even a large amount; and all // that matters is that time ticks forward at a reasonable pace from some // initial point. - const openscreen::platform::Clock::time_point start_time_; + const Clock::time_point start_time_; const NtpSeconds since_ntp_epoch_; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_NTP_TIME_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/ntp_time_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/ntp_time_unittest.cc index 552673c03c8..1caa81ad475 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/ntp_time_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/ntp_time_unittest.cc @@ -6,13 +6,12 @@ #include "gtest/gtest.h" -using openscreen::platform::Clock; using std::chrono::duration_cast; using std::chrono::microseconds; using std::chrono::milliseconds; +namespace openscreen { namespace cast { -namespace streaming { TEST(NtpTimestampTest, SplitsIntoParts) { // 1 Jan 1900. @@ -73,8 +72,7 @@ TEST(NtpTimeConverterTest, ConvertsToNtpTimeAndBack) { // our core assumptions (or the design) about the time math are wrong and // should be looked into! const Clock::time_point steady_clock_start = Clock::now(); - const std::chrono::seconds wall_clock_start = - openscreen::platform::GetWallTimeSinceUnixEpoch(); + const std::chrono::seconds wall_clock_start = GetWallTimeSinceUnixEpoch(); SCOPED_TRACE(::testing::Message() << "steady_clock_start.time_since_epoch().count() is " << steady_clock_start.time_since_epoch().count() @@ -106,5 +104,5 @@ TEST(NtpTimeConverterTest, ConvertsToNtpTimeAndBack) { } } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/offer_messages.cc b/chromium/third_party/openscreen/src/cast/streaming/offer_messages.cc new file mode 100644 index 00000000000..f1da9c60eb3 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/offer_messages.cc @@ -0,0 +1,437 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/offer_messages.h" + +#include <inttypes.h> + +#include <string> +#include <utility> + +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#include "cast/streaming/constants.h" +#include "cast/streaming/message_util.h" +#include "cast/streaming/receiver_session.h" +#include "platform/base/error.h" +#include "util/big_endian.h" +#include "util/json/json_serialization.h" +#include "util/logging.h" +#include "util/stringprintf.h" + +namespace openscreen { +namespace cast { + +namespace { + +constexpr char kSupportedStreams[] = "supportedStreams"; +constexpr char kAudioSourceType[] = "audio_source"; +constexpr char kVideoSourceType[] = "video_source"; +constexpr char kStreamType[] = "type"; + +ErrorOr<RtpPayloadType> ParseRtpPayloadType(const Json::Value& parent, + const std::string& field) { + auto t = ParseInt(parent, field); + if (!t) { + return t.error(); + } + + uint8_t t_small = t.value(); + if (t_small != t.value() || !IsRtpPayloadType(t_small)) { + return Error(Error::Code::kParameterInvalid, + "Received invalid RTP Payload Type."); + } + + return static_cast<RtpPayloadType>(t_small); +} + +ErrorOr<int> ParseRtpTimebase(const Json::Value& parent, + const std::string& field) { + auto error_or_raw = ParseString(parent, field); + if (!error_or_raw) { + return error_or_raw.error(); + } + + const auto fraction = SimpleFraction::FromString(error_or_raw.value()); + if (fraction.is_error() || !fraction.value().is_positive()) { + return CreateParseError("RTP timebase"); + } + // The spec demands a leading 1, so this isn't really a fraction. + OSP_DCHECK(fraction.value().numerator == 1); + return fraction.value().denominator; +} + +// For a hex byte, the conversion is 4 bits to 1 character, e.g. +// 0b11110001 becomes F1, so 1 byte is two characters. +constexpr int kHexDigitsPerByte = 2; +constexpr int kAesBytesSize = 16; +constexpr int kAesStringLength = kAesBytesSize * kHexDigitsPerByte; +ErrorOr<std::array<uint8_t, kAesBytesSize>> ParseAesHexBytes( + const Json::Value& parent, + const std::string& field) { + auto hex_string = ParseString(parent, field); + if (!hex_string) { + return hex_string.error(); + } + + constexpr int kHexDigitsPerScanField = 16; + constexpr int kNumScanFields = kAesStringLength / kHexDigitsPerScanField; + uint64_t quads[kNumScanFields]; + int chars_scanned; + if (hex_string.value().size() == kAesStringLength && + sscanf(hex_string.value().c_str(), "%16" SCNx64 "%16" SCNx64 "%n", + &quads[0], &quads[1], &chars_scanned) == kNumScanFields && + chars_scanned == kAesStringLength && + std::none_of(hex_string.value().begin(), hex_string.value().end(), + [](char c) { return std::isspace(c); })) { + std::array<uint8_t, kAesBytesSize> bytes; + WriteBigEndian(quads[0], bytes.data()); + WriteBigEndian(quads[1], bytes.data() + 8); + return bytes; + } + return CreateParseError("AES hex string bytes"); +} + +ErrorOr<Stream> ParseStream(const Json::Value& value, Stream::Type type) { + auto index = ParseInt(value, "index"); + if (!index) { + return index.error(); + } + // If channel is omitted, the default value is used later. + auto channels = ParseInt(value, "channels"); + if (channels.is_value() && channels.value() <= 0) { + return CreateParameterError("channel"); + } + auto codec_name = ParseString(value, "codecName"); + if (!codec_name) { + return codec_name.error(); + } + auto rtp_profile = ParseString(value, "rtpProfile"); + if (!rtp_profile) { + return rtp_profile.error(); + } + auto rtp_payload_type = ParseRtpPayloadType(value, "rtpPayloadType"); + if (!rtp_payload_type) { + return rtp_payload_type.error(); + } + auto ssrc = ParseUint(value, "ssrc"); + if (!ssrc) { + return ssrc.error(); + } + auto aes_key = ParseAesHexBytes(value, "aesKey"); + if (!aes_key) { + return aes_key.error(); + } + auto aes_iv_mask = ParseAesHexBytes(value, "aesIvMask"); + if (!aes_iv_mask) { + return aes_iv_mask.error(); + } + auto rtp_timebase = ParseRtpTimebase(value, "timeBase"); + if (!rtp_timebase) { + return rtp_timebase.error(); + } + + auto target_delay = ParseInt(value, "targetDelay"); + std::chrono::milliseconds target_delay_ms = kDefaultTargetPlayoutDelay; + if (target_delay) { + auto d = std::chrono::milliseconds(target_delay.value()); + if (d >= kMinTargetPlayoutDelay && d <= kMaxTargetPlayoutDelay) { + target_delay_ms = d; + } else { + return CreateParameterError("target delay"); + } + } + + auto receiver_rtcp_event_log = ParseBool(value, "receiverRtcpEventLog"); + auto receiver_rtcp_dscp = ParseString(value, "receiverRtcpDscp"); + return Stream{index.value(), + type, + ValueOrDefault(channels, type == Stream::Type::kAudioSource + ? kDefaultNumAudioChannels + : kDefaultNumVideoChannels), + codec_name.value(), + rtp_payload_type.value(), + ssrc.value(), + target_delay_ms, + aes_key.value(), + aes_iv_mask.value(), + ValueOrDefault(receiver_rtcp_event_log), + ValueOrDefault(receiver_rtcp_dscp), + rtp_timebase.value()}; +} + +ErrorOr<AudioStream> ParseAudioStream(const Json::Value& value) { + auto stream = ParseStream(value, Stream::Type::kAudioSource); + if (!stream) { + return stream.error(); + } + auto bit_rate = ParseInt(value, "bitRate"); + if (!bit_rate) { + return bit_rate.error(); + } + // A bit rate of 0 is valid for some codec types, so we don't enforce here. + if (bit_rate.value() < 0) { + return CreateParameterError("bit rate"); + } + return AudioStream{stream.value(), bit_rate.value()}; +} + +ErrorOr<Resolution> ParseResolution(const Json::Value& value) { + auto width = ParseInt(value, "width"); + if (!width) { + return width.error(); + } + auto height = ParseInt(value, "height"); + if (!height) { + return height.error(); + } + if (width.value() <= 0 || height.value() <= 0) { + return CreateParameterError("resolution"); + } + return Resolution{width.value(), height.value()}; +} + +ErrorOr<std::vector<Resolution>> ParseResolutions(const Json::Value& parent, + const std::string& field) { + std::vector<Resolution> resolutions; + // Some legacy senders don't provide resolutions, so just return empty. + const Json::Value& value = parent[field]; + if (!value.isArray() || value.empty()) { + return resolutions; + } + + for (Json::ArrayIndex i = 0; i < value.size(); ++i) { + auto r = ParseResolution(value[i]); + if (!r) { + return r.error(); + } + resolutions.push_back(r.value()); + } + + return resolutions; +} + +ErrorOr<VideoStream> ParseVideoStream(const Json::Value& value) { + auto stream = ParseStream(value, Stream::Type::kVideoSource); + if (!stream) { + return stream.error(); + } + auto resolutions = ParseResolutions(value, "resolutions"); + if (!resolutions) { + return resolutions.error(); + } + + auto raw_max_frame_rate = ParseString(value, "maxFrameRate"); + SimpleFraction max_frame_rate{kDefaultMaxFrameRate, 1}; + if (raw_max_frame_rate.is_value()) { + auto parsed = SimpleFraction::FromString(raw_max_frame_rate.value()); + if (parsed.is_value() && parsed.value().is_positive()) { + max_frame_rate = parsed.value(); + } + } + + auto profile = ParseString(value, "profile"); + auto protection = ParseString(value, "protection"); + auto max_bit_rate = ParseInt(value, "maxBitRate"); + auto level = ParseString(value, "level"); + auto error_recovery_mode = ParseString(value, "errorRecoveryMode"); + return VideoStream{stream.value(), + max_frame_rate, + ValueOrDefault(max_bit_rate, 4 << 20), + ValueOrDefault(protection), + ValueOrDefault(profile), + ValueOrDefault(level), + resolutions.value(), + ValueOrDefault(error_recovery_mode)}; +} + +absl::string_view ToString(Stream::Type type) { + switch (type) { + case Stream::Type::kAudioSource: + return kAudioSourceType; + case Stream::Type::kVideoSource: + return kVideoSourceType; + default: { + OSP_NOTREACHED(); + return ""; + } + } +} + +} // namespace + +constexpr char kCastMirroring[] = "mirroring"; +constexpr char kCastRemoting[] = "remoting"; + +// static +CastMode CastMode::Parse(absl::string_view value) { + return (value == kCastRemoting) ? CastMode{CastMode::Type::kRemoting} + : CastMode{CastMode::Type::kMirroring}; +} + +ErrorOr<Json::Value> Stream::ToJson() const { + if (channels < 1 || index < 0 || codec_name.empty() || + target_delay.count() <= 0 || + target_delay.count() > std::numeric_limits<int>::max() || + rtp_timebase < 1) { + return CreateParameterError("Stream"); + } + + Json::Value root; + root["index"] = index; + root["type"] = std::string(ToString(type)); + root["channels"] = channels; + root["codecName"] = codec_name; + root["rtpPayloadType"] = static_cast<int>(rtp_payload_type); + // rtpProfile is technically required by the spec, although it is always set + // to cast. We set it here to be compliant with all spec implementers. + root["rtpProfile"] = "cast"; + static_assert(sizeof(ssrc) <= sizeof(Json::UInt), + "this code assumes Ssrc fits in a Json::UInt"); + root["ssrc"] = static_cast<Json::UInt>(ssrc); + root["targetDelay"] = static_cast<int>(target_delay.count()); + root["aesKey"] = HexEncode(aes_key); + root["aesIvMask"] = HexEncode(aes_iv_mask); + root["ReceiverRtcpEventLog"] = receiver_rtcp_event_log; + root["receiverRtcpDscp"] = receiver_rtcp_dscp; + root["timeBase"] = "1/" + std::to_string(rtp_timebase); + return root; +} + +std::string CastMode::ToString() const { + switch (type) { + case Type::kMirroring: + return kCastMirroring; + case Type::kRemoting: + return kCastRemoting; + default: + OSP_NOTREACHED(); + return ""; + } +} + +ErrorOr<Json::Value> AudioStream::ToJson() const { + // A bit rate of 0 is valid for some codec types, so we don't enforce here. + if (bit_rate < 0) { + return CreateParameterError("AudioStream"); + } + + auto error_or_stream = stream.ToJson(); + if (error_or_stream.is_error()) { + return error_or_stream; + } + + error_or_stream.value()["bitRate"] = bit_rate; + return error_or_stream; +} + +ErrorOr<Json::Value> Resolution::ToJson() const { + if (width <= 0 || height <= 0) { + return CreateParameterError("Resolution"); + } + + Json::Value root; + root["width"] = width; + root["height"] = height; + return root; +} + +ErrorOr<Json::Value> VideoStream::ToJson() const { + if (max_bit_rate <= 0 || !max_frame_rate.is_positive()) { + return CreateParameterError("VideoStream"); + } + + auto error_or_stream = stream.ToJson(); + if (error_or_stream.is_error()) { + return error_or_stream; + } + + auto& stream = error_or_stream.value(); + stream["maxFrameRate"] = max_frame_rate.ToString(); + stream["maxBitRate"] = max_bit_rate; + stream["protection"] = protection; + stream["profile"] = profile; + stream["level"] = level; + stream["errorRecoveryMode"] = error_recovery_mode; + + Json::Value rs; + for (auto resolution : resolutions) { + auto eoj = resolution.ToJson(); + if (eoj.is_error()) { + return eoj; + } + rs.append(eoj.value()); + } + stream["resolutions"] = std::move(rs); + return error_or_stream; +} + +// static +ErrorOr<Offer> Offer::Parse(const Json::Value& root) { + CastMode cast_mode = CastMode::Parse(root["castMode"].asString()); + + const ErrorOr<bool> get_status = ParseBool(root, "receiverGetStatus"); + + Json::Value supported_streams = root[kSupportedStreams]; + if (!supported_streams.isArray()) { + return CreateParseError("supported streams in offer"); + } + + std::vector<AudioStream> audio_streams; + std::vector<VideoStream> video_streams; + for (Json::ArrayIndex i = 0; i < supported_streams.size(); ++i) { + const Json::Value& fields = supported_streams[i]; + auto type = ParseString(fields, kStreamType); + if (!type) { + return type.error(); + } + + if (type.value() == kAudioSourceType) { + auto stream = ParseAudioStream(fields); + if (!stream) { + return stream.error(); + } + audio_streams.push_back(std::move(stream.value())); + } else if (type.value() == kVideoSourceType) { + auto stream = ParseVideoStream(fields); + if (!stream) { + return stream.error(); + } + video_streams.push_back(std::move(stream.value())); + } + } + + return Offer{cast_mode, ValueOrDefault(get_status), std::move(audio_streams), + std::move(video_streams)}; +} + +ErrorOr<Json::Value> Offer::ToJson() const { + Json::Value root; + root["castMode"] = cast_mode.ToString(); + root["receiverGetStatus"] = supports_wifi_status_reporting; + + Json::Value streams; + for (auto& as : audio_streams) { + auto eoj = as.ToJson(); + if (eoj.is_error()) { + return eoj; + } + streams.append(eoj.value()); + } + + for (auto& vs : video_streams) { + auto eoj = vs.ToJson(); + if (eoj.is_error()) { + return eoj; + } + streams.append(eoj.value()); + } + + root[kSupportedStreams] = std::move(streams); + return root; +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/offer_messages.h b/chromium/third_party/openscreen/src/cast/streaming/offer_messages.h new file mode 100644 index 00000000000..319145bc3cf --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/offer_messages.h @@ -0,0 +1,121 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STREAMING_OFFER_MESSAGES_H_ +#define CAST_STREAMING_OFFER_MESSAGES_H_ + +#include <chrono> // NOLINT +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "cast/streaming/rtp_defines.h" +#include "cast/streaming/session_config.h" +#include "json/value.h" +#include "platform/base/error.h" +#include "util/simple_fraction.h" + +// This file contains the implementation of the Cast V2 Mirroring Control +// Protocol offer object definition. +namespace openscreen { +namespace cast { + +// If the target delay provided by the sender is not bounded by +// [kMinTargetDelay, kMaxTargetDelay], it will be set to +// kDefaultTargetPlayoutDelay. +constexpr auto kMinTargetPlayoutDelay = std::chrono::milliseconds(0); +constexpr auto kMaxTargetPlayoutDelay = std::chrono::milliseconds(2000); + +// If the sender provides an invalid maximum frame rate, it will +// be set to kDefaultMaxFrameRate. +constexpr int kDefaultMaxFrameRate = 30; + +constexpr int kDefaultNumVideoChannels = 1; +constexpr int kDefaultNumAudioChannels = 2; + +// A stream, as detailed by the CastV2 protocol spec, is a segment of an +// offer message specifically representing a configuration object for +// a codec and its related fields, such as maximum bit rate, time base, +// and other fields. +// Composed classes include AudioStream and VideoStream, which contain +// fields specific to audio and video respectively. +struct Stream { + enum class Type : uint8_t { kAudioSource, kVideoSource }; + + ErrorOr<Json::Value> ToJson() const; + + int index = 0; + Type type = {}; + + // Default channel count is 1, e.g. for video. + int channels = 0; + std::string codec_name = {}; + RtpPayloadType rtp_payload_type = {}; + Ssrc ssrc = {}; + std::chrono::milliseconds target_delay = {}; + + // AES Key and IV mask format is very strict: a 32 digit hex string that + // must be converted to a 16 digit byte array. + std::array<uint8_t, 16> aes_key = {}; + std::array<uint8_t, 16> aes_iv_mask = {}; + bool receiver_rtcp_event_log = {}; + std::string receiver_rtcp_dscp = {}; + int rtp_timebase = 0; +}; + +struct AudioStream { + ErrorOr<Json::Value> ToJson() const; + + Stream stream = {}; + int bit_rate = 0; +}; + +struct Resolution { + ErrorOr<Json::Value> ToJson() const; + + int width = 0; + int height = 0; +}; + +struct VideoStream { + ErrorOr<Json::Value> ToJson() const; + + Stream stream = {}; + SimpleFraction max_frame_rate; + int max_bit_rate = 0; + std::string protection = {}; + std::string profile = {}; + std::string level = {}; + std::vector<Resolution> resolutions = {}; + std::string error_recovery_mode = {}; +}; + +struct CastMode { + public: + enum class Type : uint8_t { kMirroring, kRemoting }; + + static CastMode Parse(absl::string_view value); + std::string ToString() const; + + // Default cast mode is mirroring. + Type type = Type::kMirroring; +}; + +struct Offer { + static ErrorOr<Offer> Parse(const Json::Value& root); + ErrorOr<Json::Value> ToJson() const; + + CastMode cast_mode = {}; + // This field is poorly named in the spec (receiverGetStatus), so we use + // a more descriptive name here. + bool supports_wifi_status_reporting = {}; + std::vector<AudioStream> audio_streams = {}; + std::vector<VideoStream> video_streams = {}; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STREAMING_OFFER_MESSAGES_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/offer_messages_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/offer_messages_unittest.cc new file mode 100644 index 00000000000..4931927a5ae --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/offer_messages_unittest.cc @@ -0,0 +1,429 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/offer_messages.h" + +#include <utility> + +#include "cast/streaming/rtp_defines.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "util/json/json_serialization.h" + +using ::testing::ElementsAre; + +namespace openscreen { +namespace cast { + +namespace { + +constexpr char kValidOffer[] = R"({ + "castMode": "mirroring", + "receiverGetStatus": true, + "supportedStreams": [ + { + "index": 0, + "type": "video_source", + "codecName": "h264", + "rtpProfile": "cast", + "rtpPayloadType": 101, + "ssrc": 19088743, + "maxFrameRate": "60000/1000", + "timeBase": "1/90000", + "maxBitRate": 5000000, + "profile": "main", + "level": "4", + "targetDelay": 200, + "aesKey": "040d756791711fd3adb939066e6d8690", + "aesIvMask": "9ff0f022a959150e70a2d05a6c184aed", + "resolutions": [ + { + "width": 1280, + "height": 720 + }, + { + "width": 640, + "height": 360 + }, + { + "width": 640, + "height": 480 + } + ] + }, + { + "index": 1, + "type": "video_source", + "codecName": "vp8", + "rtpProfile": "cast", + "rtpPayloadType": 100, + "ssrc": 19088744, + "maxFrameRate": "30000/1001", + "targetDelay": 1000, + "timeBase": "1/90000", + "maxBitRate": 5000000, + "profile": "main", + "level": "5", + "aesKey": "bbf109bf84513b456b13a184453b66ce", + "aesIvMask": "edaf9e4536e2b66191f560d9c04b2a69" + }, + { + "index": 2, + "type": "audio_source", + "codecName": "opus", + "targetDelay": 300, + "rtpProfile": "cast", + "rtpPayloadType": 96, + "ssrc": 4294967295, + "bitRate": 124000, + "timeBase": "1/48000", + "channels": 2, + "aesKey": "51027e4e2347cbcb49d57ef10177aebc", + "aesIvMask": "7f12a19be62a36c04ae4116caaeff6d1" + } + ] +})"; + +void ExpectFailureOnParse(absl::string_view body) { + ErrorOr<Json::Value> root = json::Parse(body); + ASSERT_TRUE(root.is_value()); + EXPECT_TRUE(Offer::Parse(std::move(root.value())).is_error()); +} + +void ExpectEqualsValidOffer(const Offer& offer) { + EXPECT_EQ(CastMode::Type::kMirroring, offer.cast_mode.type); + EXPECT_EQ(true, offer.supports_wifi_status_reporting); + + // Verify list of video streams. + EXPECT_EQ(2u, offer.video_streams.size()); + const auto& video_streams = offer.video_streams; + + const bool flipped = video_streams[0].stream.index != 0; + const VideoStream& vs_one = flipped ? video_streams[1] : video_streams[0]; + const VideoStream& vs_two = flipped ? video_streams[0] : video_streams[1]; + + EXPECT_EQ(0, vs_one.stream.index); + EXPECT_EQ(1, vs_one.stream.channels); + EXPECT_EQ(Stream::Type::kVideoSource, vs_one.stream.type); + EXPECT_EQ("h264", vs_one.stream.codec_name); + EXPECT_EQ(RtpPayloadType::kVideoH264, vs_one.stream.rtp_payload_type); + EXPECT_EQ(19088743u, vs_one.stream.ssrc); + EXPECT_EQ((SimpleFraction{60000, 1000}), vs_one.max_frame_rate); + EXPECT_EQ(90000, vs_one.stream.rtp_timebase); + EXPECT_EQ(5000000, vs_one.max_bit_rate); + EXPECT_EQ("main", vs_one.profile); + EXPECT_EQ("4", vs_one.level); + EXPECT_THAT(vs_one.stream.aes_key, + ElementsAre(0x04, 0x0d, 0x75, 0x67, 0x91, 0x71, 0x1f, 0xd3, 0xad, + 0xb9, 0x39, 0x06, 0x6e, 0x6d, 0x86, 0x90)); + EXPECT_THAT(vs_one.stream.aes_iv_mask, + ElementsAre(0x9f, 0xf0, 0xf0, 0x22, 0xa9, 0x59, 0x15, 0x0e, 0x70, + 0xa2, 0xd0, 0x5a, 0x6c, 0x18, 0x4a, 0xed)); + + const auto& resolutions = vs_one.resolutions; + EXPECT_EQ(3u, resolutions.size()); + const Resolution& r_one = resolutions[0]; + EXPECT_EQ(1280, r_one.width); + EXPECT_EQ(720, r_one.height); + + const Resolution& r_two = resolutions[1]; + EXPECT_EQ(640, r_two.width); + EXPECT_EQ(360, r_two.height); + + const Resolution& r_three = resolutions[2]; + EXPECT_EQ(640, r_three.width); + EXPECT_EQ(480, r_three.height); + + EXPECT_EQ(1, vs_two.stream.index); + EXPECT_EQ(1, vs_two.stream.channels); + EXPECT_EQ(Stream::Type::kVideoSource, vs_two.stream.type); + EXPECT_EQ("vp8", vs_two.stream.codec_name); + EXPECT_EQ(RtpPayloadType::kVideoVp8, vs_two.stream.rtp_payload_type); + EXPECT_EQ(19088744u, vs_two.stream.ssrc); + EXPECT_EQ((SimpleFraction{30000, 1001}), vs_two.max_frame_rate); + EXPECT_EQ(90000, vs_two.stream.rtp_timebase); + EXPECT_EQ(5000000, vs_two.max_bit_rate); + EXPECT_EQ("main", vs_two.profile); + EXPECT_EQ("5", vs_two.level); + EXPECT_THAT(vs_two.stream.aes_key, + ElementsAre(0xbb, 0xf1, 0x09, 0xbf, 0x84, 0x51, 0x3b, 0x45, 0x6b, + 0x13, 0xa1, 0x84, 0x45, 0x3b, 0x66, 0xce)); + EXPECT_THAT(vs_two.stream.aes_iv_mask, + ElementsAre(0xed, 0xaf, 0x9e, 0x45, 0x36, 0xe2, 0xb6, 0x61, 0x91, + 0xf5, 0x60, 0xd9, 0xc0, 0x4b, 0x2a, 0x69)); + + const auto& resolutions_two = vs_two.resolutions; + EXPECT_EQ(0u, resolutions_two.size()); + + // Verify list of audio streams. + EXPECT_EQ(1u, offer.audio_streams.size()); + const AudioStream& as = offer.audio_streams[0]; + EXPECT_EQ(2, as.stream.index); + EXPECT_EQ(Stream::Type::kAudioSource, as.stream.type); + EXPECT_EQ("opus", as.stream.codec_name); + EXPECT_EQ(RtpPayloadType::kAudioOpus, as.stream.rtp_payload_type); + EXPECT_EQ(std::numeric_limits<Ssrc>::max(), as.stream.ssrc); + EXPECT_EQ(124000, as.bit_rate); + EXPECT_EQ(2, as.stream.channels); + + EXPECT_THAT(as.stream.aes_key, + ElementsAre(0x51, 0x02, 0x7e, 0x4e, 0x23, 0x47, 0xcb, 0xcb, 0x49, + 0xd5, 0x7e, 0xf1, 0x01, 0x77, 0xae, 0xbc)); + EXPECT_THAT(as.stream.aes_iv_mask, + ElementsAre(0x7f, 0x12, 0xa1, 0x9b, 0xe6, 0x2a, 0x36, 0xc0, 0x4a, + 0xe4, 0x11, 0x6c, 0xaa, 0xef, 0xf6, 0xd1)); +} + +} // namespace + +TEST(OfferTest, ErrorOnEmptyOffer) { + ExpectFailureOnParse("{}"); +} + +TEST(OfferTest, ErrorOnMissingMandatoryFields) { + // It's okay if castMode is omitted, but if supportedStreams isanne // + // omitted we should fail here. + ExpectFailureOnParse(R"({ + "castMode": "mirroring" + })"); +} + +TEST(OfferTest, CanParseValidButStreamlessOffer) { + ErrorOr<Json::Value> root = json::Parse(R"({ + "castMode": "mirroring", + "supportedStreams": [] + })"); + ASSERT_TRUE(root.is_value()); + EXPECT_TRUE(Offer::Parse(std::move(root.value())).is_value()); +} + +TEST(OfferTest, ErrorOnMissingAudioStreamMandatoryField) { + ExpectFailureOnParse(R"({ + "castMode": "mirroring", + "supportedStreams": [{ + "index": 2, + "codecName": "opus", + "rtpProfile": "cast", + "rtpPayloadType": 96, + "ssrc": 19088743, + "bitRate": 124000, + "timeBase": "1/48000", + "channels": 2 + }]})"); + + ExpectFailureOnParse(R"({ + "castMode": "mirroring", + "supportedStreams": [{ + "index": 2, + "type": "audio_source", + "codecName": "opus", + "rtpProfile": "cast", + "rtpPayloadType": 96, + "bitRate": 124000, + "timeBase": "1/48000", + "channels": 2 + }]})"); +} + +TEST(OfferTest, CanParseValidButMinimalAudioOffer) { + ErrorOr<Json::Value> root = json::Parse(R"({ + "castMode": "mirroring", + "supportedStreams": [{ + "index": 2, + "type": "audio_source", + "codecName": "opus", + "rtpProfile": "cast", + "rtpPayloadType": 96, + "ssrc": 19088743, + "bitRate": 124000, + "timeBase": "1/48000", + "channels": 2, + "aesKey": "51027e4e2347cbcb49d57ef10177aebc", + "aesIvMask": "7f12a19be62a36c04ae4116caaeff6d1" + }] + })"); + ASSERT_TRUE(root.is_value()); + EXPECT_TRUE(Offer::Parse(std::move(root.value())).is_value()); +} + +TEST(OfferTest, CanParseValidZeroBitRateAudioOffer) { + ErrorOr<Json::Value> root = json::Parse(R"({ + "castMode": "mirroring", + "supportedStreams": [{ + "index": 2, + "type": "audio_source", + "codecName": "opus", + "rtpProfile": "cast", + "rtpPayloadType": 96, + "ssrc": 19088743, + "bitRate": 0, + "timeBase": "1/96000", + "channels": 5, + "aesKey": "51029e4e2347cbcb49d57ef10177aebd", + "aesIvMask": "7f12a19be62a36c04ae4116caaeff5d2" + }] + })"); + ASSERT_TRUE(root.is_value()); + EXPECT_TRUE(Offer::Parse(std::move(root.value())).is_value()); +} + +TEST(OfferTest, ErrorOnMissingVideoStreamMandatoryField) { + ExpectFailureOnParse(R"({ + "castMode": "mirroring", + "supportedStreams": [{ + "index": 2, + "codecName": "video_source", + "rtpProfile": "h264", + "rtpPayloadType": 101, + "ssrc": 19088743, + "bitRate": 124000, + "timeBase": "1/48000" + }] + })"); + + ExpectFailureOnParse(R"({ + "castMode": "mirroring", + "supportedStreams": [{ + "index": 2, + "type": "video_source", + "codecName": "h264", + "rtpProfile": "cast", + "rtpPayloadType": 101, + "bitRate": 124000, + "timeBase": "1/48000", + "maxBitRate": 10000 + }] + })"); + + ExpectFailureOnParse(R"({ + "castMode": "mirroring", + "supportedStreams": [{ + "index": 2, + "type": "video_source", + "codecName": "vp8", + "rtpProfile": "cast", + "rtpPayloadType": 100, + "ssrc": 19088743, + "timeBase": "1/48000", + "resolutions": [], + "maxBitRate": 10000 + }] + })"); + + ExpectFailureOnParse(R"({ + "castMode": "mirroring", + "supportedStreams": [{ + "index": 2, + "type": "video_source", + "codecName": "vp8", + "rtpProfile": "cast", + "rtpPayloadType": 100, + "ssrc": 19088743, + "timeBase": "1/48000", + "resolutions": [], + "maxBitRate": 10000, + "aesKey": "51027e4e2347cbcb49d57ef10177aebc" + }] + })"); + + ExpectFailureOnParse(R"({ + "castMode": "mirroring", + "supportedStreams": [{ + "index": 2, + "type": "video_source", + "codecName": "vp8", + "rtpProfile": "cast", + "rtpPayloadType": 100, + "ssrc": 19088743, + "timeBase": "1/48000", + "resolutions": [], + "maxBitRate": 10000, + "aesIvMask": "7f12a19be62a36c04ae4116caaeff6d1" + }] + })"); +} + +TEST(OfferTest, CanParseValidButMinimalVideoOffer) { + ErrorOr<Json::Value> root = json::Parse(R"({ + "castMode": "mirroring", + "supportedStreams": [{ + "index": 2, + "type": "video_source", + "codecName": "vp8", + "rtpProfile": "cast", + "rtpPayloadType": 100, + "ssrc": 19088743, + "timeBase": "1/48000", + "resolutions": [], + "maxBitRate": 10000, + "aesKey": "51027e4e2347cbcb49d57ef10177aebc", + "aesIvMask": "7f12a19be62a36c04ae4116caaeff6d1" + }] + })"); + + ASSERT_TRUE(root.is_value()); + EXPECT_TRUE(Offer::Parse(std::move(root.value())).is_value()); +} + +TEST(OfferTest, CanParseValidOffer) { + ErrorOr<Json::Value> root = json::Parse(kValidOffer); + ASSERT_TRUE(root.is_value()); + ErrorOr<Offer> offer = Offer::Parse(std::move(root.value())); + + ExpectEqualsValidOffer(offer.value()); +} + +TEST(OfferTest, ParseAndToJsonResultsInSameOffer) { + ErrorOr<Json::Value> root = json::Parse(kValidOffer); + ASSERT_TRUE(root.is_value()); + ErrorOr<Offer> offer = Offer::Parse(std::move(root.value())); + + ExpectEqualsValidOffer(offer.value()); + + auto eoj = offer.value().ToJson(); + EXPECT_TRUE(eoj.is_value()); + ErrorOr<Offer> reparsed_offer = Offer::Parse(std::move(eoj.value())); + ExpectEqualsValidOffer(reparsed_offer.value()); +} + +// We don't want to enforce that a given offer must have both audio and +// video, so we don't assert on either. +TEST(OfferTest, ToJsonSucceedsWithMissingStreams) { + ErrorOr<Json::Value> root = json::Parse(kValidOffer); + ASSERT_TRUE(root.is_value()); + ErrorOr<Offer> offer = Offer::Parse(std::move(root.value())); + ExpectEqualsValidOffer(offer.value()); + const Offer valid_offer = std::move(offer.value()); + + Offer missing_audio_streams = valid_offer; + missing_audio_streams.audio_streams.clear(); + EXPECT_TRUE(missing_audio_streams.ToJson().is_value()); + + Offer missing_video_streams = valid_offer; + missing_video_streams.audio_streams.clear(); + EXPECT_TRUE(missing_video_streams.ToJson().is_value()); +} + +TEST(OfferTest, ToJsonFailsWithInvalidStreams) { + ErrorOr<Json::Value> root = json::Parse(kValidOffer); + ASSERT_TRUE(root.is_value()); + ErrorOr<Offer> offer = Offer::Parse(std::move(root.value())); + ExpectEqualsValidOffer(offer.value()); + const Offer valid_offer = std::move(offer.value()); + + Offer video_stream_invalid = valid_offer; + video_stream_invalid.video_streams[0].max_frame_rate.denominator = 0; + EXPECT_TRUE(video_stream_invalid.ToJson().is_error()); + + Offer audio_stream_invalid = valid_offer; + video_stream_invalid.audio_streams[0].bit_rate = 0; + EXPECT_TRUE(video_stream_invalid.ToJson().is_error()); + + Offer stream_invalid = valid_offer; + stream_invalid.video_streams[0].stream.codec_name = ""; + EXPECT_TRUE(stream_invalid.ToJson().is_error()); +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.cc b/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.cc index 0d7f268e724..23a5e715841 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.cc @@ -6,10 +6,8 @@ #include <algorithm> -using openscreen::platform::Clock; - +namespace openscreen { namespace cast { -namespace streaming { PacketReceiveStatsTracker::PacketReceiveStatsTracker(int rtp_timebase) : rtp_timebase_(rtp_timebase) {} @@ -77,5 +75,5 @@ void PacketReceiveStatsTracker::PopulateNextReport(RtcpReportBlock* report) { report->jitter = RtpTimeDelta::FromDuration(jitter_, rtp_timebase_); } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.h b/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.h index e579f76feeb..8d8e23c9fe2 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.h +++ b/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.h @@ -12,8 +12,8 @@ #include "cast/streaming/rtp_time.h" #include "platform/api/time.h" +namespace openscreen { namespace cast { -namespace streaming { // Maintains statistics for RTP packet arrival timing, jitter, and loss rates; // and then uses these to compute and set the related fields in a RTCP Receiver @@ -29,10 +29,9 @@ class PacketReceiveStatsTracker { // RtpPacketParser::ParseResult. |arrival_time| is when the packet was // received (i.e., right-off the network socket, before any // processing/parsing). - void OnReceivedValidRtpPacket( - uint16_t sequence_number, - RtpTimeTicks rtp_timestamp, - openscreen::platform::Clock::time_point arrival_time); + void OnReceivedValidRtpPacket(uint16_t sequence_number, + RtpTimeTicks rtp_timestamp, + Clock::time_point arrival_time); // Populates *only* those fields in the given |report| that pertain to packet // loss, jitter, and the latest-known RTP packet sequence number. @@ -89,7 +88,7 @@ class PacketReceiveStatsTracker { // The time the last RTP packet was received. This is used in the computation // that updates |jitter_|. - openscreen::platform::Clock::time_point last_rtp_packet_arrival_time_; + Clock::time_point last_rtp_packet_arrival_time_; // The RTP timestamp of the last RTP packet received. This is used in the // computation that updates |jitter_|. @@ -98,10 +97,10 @@ class PacketReceiveStatsTracker { // The interarrival jitter. See RFC 3550 spec, section 6.4.1. The Cast // Streaming spec diverges from the algorithm in the RFC spec in that it uses // different pieces of timing data to calculate this metric. - openscreen::platform::Clock::duration jitter_; + Clock::duration jitter_; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_PACKET_RECEIVE_STATS_TRACKER_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker_unittest.cc index 1532741f953..5146cc8d302 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker_unittest.cc @@ -9,10 +9,8 @@ #include "cast/streaming/constants.h" #include "gtest/gtest.h" -using openscreen::platform::Clock; - +namespace openscreen { namespace cast { -namespace streaming { namespace { constexpr int kSomeRtpTimebase = static_cast<int>(kVideoTimebase::den); @@ -206,5 +204,5 @@ TEST(PacketReceiveStatsTrackerTest, ComputesJitterCorrectly) { } } // namespace -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/packet_util.cc b/chromium/third_party/openscreen/src/cast/streaming/packet_util.cc index cf43ac14142..7e789808e0f 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/packet_util.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/packet_util.cc @@ -7,10 +7,8 @@ #include "cast/streaming/rtcp_common.h" #include "cast/streaming/rtp_defines.h" -using openscreen::ReadBigEndian; - +namespace openscreen { namespace cast { -namespace streaming { std::pair<ApparentPacketType, Ssrc> InspectPacketForRouting( absl::Span<const uint8_t> packet) { @@ -40,5 +38,5 @@ std::pair<ApparentPacketType, Ssrc> InspectPacketForRouting( return std::make_pair(ApparentPacketType::UNKNOWN, Ssrc{0}); } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/packet_util.h b/chromium/third_party/openscreen/src/cast/streaming/packet_util.h index 869b917eebf..ba84ff719e0 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/packet_util.h +++ b/chromium/third_party/openscreen/src/cast/streaming/packet_util.h @@ -11,14 +11,14 @@ #include "cast/streaming/ssrc.h" #include "util/big_endian.h" +namespace openscreen { namespace cast { -namespace streaming { // Reads a field from the start of the given span and advances the span to point // just after the field. template <typename Integer> inline Integer ConsumeField(absl::Span<const uint8_t>* in) { - const Integer result = openscreen::ReadBigEndian<Integer>(in->data()); + const Integer result = ReadBigEndian<Integer>(in->data()); in->remove_prefix(sizeof(Integer)); return result; } @@ -27,7 +27,7 @@ inline Integer ConsumeField(absl::Span<const uint8_t>* in) { // just after the field. template <typename Integer> inline void AppendField(Integer value, absl::Span<uint8_t>* out) { - openscreen::WriteBigEndian<Integer>(value, out->data()); + WriteBigEndian<Integer>(value, out->data()); out->remove_prefix(sizeof(Integer)); } @@ -56,7 +56,7 @@ enum class ApparentPacketType { UNKNOWN, RTP, RTCP }; std::pair<ApparentPacketType, Ssrc> InspectPacketForRouting( absl::Span<const uint8_t> packet); -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_PACKET_UTIL_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/packet_util_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/packet_util_unittest.cc index 1f1f095d76c..82e2a40871b 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/packet_util_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/packet_util_unittest.cc @@ -7,8 +7,8 @@ #include "absl/types/span.h" #include "gtest/gtest.h" +namespace openscreen { namespace cast { -namespace streaming { namespace { // Tests that a simple RTCP packet containing only a Sender Report can be @@ -181,5 +181,5 @@ TEST(PacketUtilTest, InspectsGarbagePacket) { } } // namespace -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver.cc b/chromium/third_party/openscreen/src/cast/streaming/receiver.cc index 73a37eb24e3..b24b61eec3d 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/receiver.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/receiver.cc @@ -5,22 +5,20 @@ #include "cast/streaming/receiver.h" #include <algorithm> -#include <functional> #include "absl/types/span.h" #include "cast/streaming/constants.h" #include "cast/streaming/receiver_packet_router.h" +#include "cast/streaming/session_config.h" #include "util/logging.h" #include "util/std_util.h" -using openscreen::platform::Clock; - using std::chrono::duration_cast; using std::chrono::microseconds; using std::chrono::milliseconds; +namespace openscreen { namespace cast { -namespace streaming { // Conveniences for ensuring logging output includes the SSRC of the Receiver, // to help distinguish one out of multiple instances in a Cast Streaming @@ -33,8 +31,7 @@ namespace streaming { Receiver::Receiver(Environment* environment, ReceiverPacketRouter* packet_router, - const cast::streaming::SessionConfig& config, - std::chrono::milliseconds initial_target_playout_delay) + const SessionConfig& config) : now_(environment->now_function()), packet_router_(packet_router), rtcp_session_(config.sender_ssrc, config.receiver_ssrc, now_()), @@ -51,13 +48,13 @@ Receiver::Receiver(Environment* environment, consumption_alarm_(environment->now_function(), environment->task_runner()) { OSP_DCHECK(packet_router_); - OSP_DCHECK_EQ(checkpoint_frame(), FrameId::first() - 1); + OSP_DCHECK_EQ(checkpoint_frame(), FrameId::leader()); OSP_CHECK_GT(rtcp_buffer_capacity_, 0); OSP_CHECK(rtcp_buffer_); - rtcp_builder_.SetPlayoutDelay(initial_target_playout_delay); - playout_delay_changes_.emplace_back(FrameId::first() - 1, - initial_target_playout_delay); + rtcp_builder_.SetPlayoutDelay(config.target_playout_delay); + playout_delay_changes_.emplace_back(FrameId::leader(), + config.target_playout_delay); packet_router_->OnReceiverCreated(rtcp_session_.sender_ssrc(), this); } @@ -355,8 +352,7 @@ void Receiver::SendRtcp() { // When there are no incomplete frames, use a longer "keepalive" interval. const Clock::duration interval = (no_nacks ? kRtcpReportInterval : kNackFeedbackInterval); - rtcp_alarm_.Schedule(std::bind(&Receiver::SendRtcp, this), - last_rtcp_send_time_ + interval); + rtcp_alarm_.Schedule([this] { SendRtcp(); }, last_rtcp_send_time_ + interval); } const Receiver::PendingFrame& Receiver::GetQueueEntry(FrameId frame_id) const { @@ -389,7 +385,7 @@ void Receiver::RecordNewTargetPlayoutDelay(FrameId as_of_frame, [&](const auto& entry) { return entry.first > as_of_frame; }); playout_delay_changes_.emplace(insert_it, as_of_frame, delay); - OSP_DCHECK(openscreen::AreElementsSortedAndUnique(playout_delay_changes_)); + OSP_DCHECK(AreElementsSortedAndUnique(playout_delay_changes_)); } milliseconds Receiver::ResolveTargetPlayoutDelay(FrameId frame_id) const { @@ -481,5 +477,5 @@ constexpr milliseconds Receiver::kDefaultPlayerProcessingTime; constexpr int Receiver::kNoFramesReady; constexpr milliseconds Receiver::kNackFeedbackInterval; -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver.h b/chromium/third_party/openscreen/src/cast/streaming/receiver.h index d1ea8b0c08a..e63a9e4ee1f 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/receiver.h +++ b/chromium/third_party/openscreen/src/cast/streaming/receiver.h @@ -24,16 +24,16 @@ #include "cast/streaming/rtcp_session.h" #include "cast/streaming/rtp_packet_parser.h" #include "cast/streaming/sender_report_parser.h" -#include "cast/streaming/session_config.h" #include "cast/streaming/ssrc.h" #include "platform/api/time.h" #include "util/alarm.h" +namespace openscreen { namespace cast { -namespace streaming { struct EncodedFrame; class ReceiverPacketRouter; +struct SessionConfig; // The Cast Streaming Receiver, a peer corresponding to some Cast Streaming // Sender at the other end of a network link. @@ -56,7 +56,7 @@ class ReceiverPacketRouter; // decoder, and the resulting decoded media is played out. Also, here is a // general usage example: // -// class MyPlayer : public cast::streaming::Receiver::Consumer { +// class MyPlayer : public openscreen::cast::Receiver::Consumer { // public: // explicit MyPlayer(Receiver* receiver) : receiver_(receiver) { // recevier_->SetPlayerProcessingTime(std::chrono::milliseconds(10)); @@ -72,7 +72,7 @@ class ReceiverPacketRouter; // void OnFramesReady(int next_frame_buffer_size) override { // std::vector<uint8_t> buffer; // buffer.resize(next_frame_buffer_size); -// cast::streaming::EncodedFrame encoded_frame = +// openscreen::cast::EncodedFrame encoded_frame = // receiver_->ConsumeNextFrame(absl::Span<uint8_t>(buffer)); // // display_.RenderFrame(decoder_.DecodeFrame(encoded_frame.data)); @@ -124,8 +124,7 @@ class Receiver { // is started). Receiver(Environment* environment, ReceiverPacketRouter* packet_router, - const cast::streaming::SessionConfig& config, - std::chrono::milliseconds initial_target_playout_delay); + const SessionConfig& config); ~Receiver(); Ssrc ssrc() const { return rtcp_session_.receiver_ssrc(); } @@ -143,8 +142,7 @@ class Receiver { // based on changing environmental conditions. // // Default setting: kDefaultPlayerProcessingTime - void SetPlayerProcessingTime( - openscreen::platform::Clock::duration needed_time); + void SetPlayerProcessingTime(Clock::duration needed_time); // Propagates a "picture loss indicator" notification to the Sender, // requesting a key frame so that decode/playout can recover. It is safe to @@ -185,11 +183,10 @@ class Receiver { // Called by ReceiverPacketRouter to provide this Receiver with what looks // like a RTP/RTCP packet meant for it specifically (among other Receivers). - void OnReceivedRtpPacket(openscreen::platform::Clock::time_point arrival_time, + void OnReceivedRtpPacket(Clock::time_point arrival_time, std::vector<uint8_t> packet); - void OnReceivedRtcpPacket( - openscreen::platform::Clock::time_point arrival_time, - std::vector<uint8_t> packet); + void OnReceivedRtcpPacket(Clock::time_point arrival_time, + std::vector<uint8_t> packet); private: // An entry in the circular queue (see |pending_frames_|). @@ -200,8 +197,7 @@ class Receiver { // at the Sender. This is computed and assigned when the RTP packet with ID // 0 is processed. Add the target playout delay to this to get the target // playout time. - absl::optional<openscreen::platform::Clock::time_point> - estimated_capture_time; + absl::optional<Clock::time_point> estimated_capture_time; PendingFrame(); ~PendingFrame(); @@ -256,10 +252,9 @@ class Receiver { // Sets the |consumption_alarm_| to check whether any frames are ready, // including possibly skipping over late frames in order to make not-yet-late // frames become ready. The default argument value means "without delay." - void ScheduleFrameReadyCheck( - openscreen::platform::Clock::time_point when = {}); + void ScheduleFrameReadyCheck(Clock::time_point when = Alarm::kImmediately); - const openscreen::platform::ClockNowFunctionPtr now_; + const ClockNowFunctionPtr now_; ReceiverPacketRouter* const packet_router_; RtcpSession rtcp_session_; SenderReportParser rtcp_parser_; @@ -276,9 +271,8 @@ class Receiver { // Schedules tasks to ensure RTCP reports are sent within a bounded interval. // Not scheduled until after this Receiver has processed the first packet from // the Sender. - openscreen::Alarm rtcp_alarm_; - openscreen::platform::Clock::time_point last_rtcp_send_time_ = - openscreen::platform::Clock::time_point::min(); + Alarm rtcp_alarm_; + Clock::time_point last_rtcp_send_time_ = Clock::time_point::min(); // The last Sender Report received and when the packet containing it had // arrived. This contains lip-sync timestamps used as part of the calculation @@ -286,7 +280,7 @@ class Receiver { // back to the Sender in the Receiver Reports. It is nullopt until the first // parseable Sender Report is received. absl::optional<SenderReportParser::SenderReportWithId> last_sender_report_; - openscreen::platform::Clock::time_point last_sender_report_arrival_time_; + Clock::time_point last_sender_report_arrival_time_; // Tracks the offset between the Receiver's [local] clock and the Sender's // clock. This is invalid until the first Sender Report has been successfully @@ -295,12 +289,12 @@ class Receiver { // The ID of the latest frame whose existence is known to this Receiver. This // value must always be greater than or equal to |checkpoint_frame()|. - FrameId latest_frame_expected_ = FrameId::first() - 1; + FrameId latest_frame_expected_ = FrameId::leader(); // The ID of the last frame consumed. This value must always be less than or // equal to |checkpoint_frame()|, since it's impossible to consume incomplete // frames! - FrameId last_frame_consumed_ = FrameId::first() - 1; + FrameId last_frame_consumed_ = FrameId::leader(); // The ID of the latest key frame known to be in-flight. This is used by // RequestKeyFrame() to ensure the PLI condition doesn't get set again until @@ -333,19 +327,21 @@ class Receiver { // The additional time needed to decode/play-out each frame after being // consumed from this Receiver. - openscreen::platform::Clock::duration player_processing_time_ = - kDefaultPlayerProcessingTime; + Clock::duration player_processing_time_ = kDefaultPlayerProcessingTime; // Scheduled to check whether there are frames ready and, if there are, to // notify the Consumer via OnFramesReady(). - openscreen::Alarm consumption_alarm_; + Alarm consumption_alarm_; // The interval between sending ACK/NACK feedback RTCP messages while // incomplete frames exist in the queue. + // + // TODO(miu): This should be a function of the current target playout delay, + // similar to the Sender's kickstart interval logic. static constexpr std::chrono::milliseconds kNackFeedbackInterval{30}; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_RECEIVER_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.cc b/chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.cc index 5e3d17bd229..7a5c41aacc3 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.cc @@ -5,17 +5,14 @@ #include "cast/streaming/receiver_packet_router.h" #include <algorithm> -#include <iomanip> #include "cast/streaming/packet_util.h" #include "cast/streaming/receiver.h" #include "util/logging.h" +#include "util/stringprintf.h" -using openscreen::IPEndpoint; -using openscreen::platform::Clock; - +namespace openscreen { namespace cast { -namespace streaming { ReceiverPacketRouter::ReceiverPacketRouter(Environment* environment) : environment_(environment) { @@ -78,16 +75,11 @@ void ReceiverPacketRouter::OnReceivedPacket(const IPEndpoint& source, const std::pair<ApparentPacketType, Ssrc> seems_like = InspectPacketForRouting(packet); if (seems_like.first == ApparentPacketType::UNKNOWN) { - // If the packet type is unknown, log a warning containing a hex dump. - constexpr int kMaxDumpSize = 96; - std::ostringstream hex_dump; - hex_dump << std::setfill('0') << std::hex; - for (int i = 0, len = std::min<int>(packet.size(), kMaxDumpSize); i < len; - ++i) { - hex_dump << std::setw(2) << static_cast<int>(packet[i]); - } + constexpr int kMaxPartiaHexDumpSize = 96; OSP_LOG_WARN << "UNKNOWN packet of " << packet.size() - << " bytes. Partial hex dump: " << hex_dump.str(); + << " bytes. Partial hex dump: " + << HexEncode(absl::Span<const uint8_t>(packet).subspan( + 0, kMaxPartiaHexDumpSize)); return; } const auto it = FindEntry(seems_like.second); @@ -115,5 +107,5 @@ ReceiverPacketRouter::ReceiverEntries::iterator ReceiverPacketRouter::FindEntry( }); } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.h b/chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.h index 88db08a817a..d90e533596b 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.h +++ b/chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.h @@ -14,8 +14,8 @@ #include "cast/streaming/environment.h" #include "cast/streaming/ssrc.h" +namespace openscreen { namespace cast { -namespace streaming { class Receiver; @@ -47,8 +47,8 @@ class ReceiverPacketRouter final : public Environment::PacketConsumer { using ReceiverEntries = std::vector<std::pair<Ssrc, Receiver*>>; // Environment::PacketConsumer implementation. - void OnReceivedPacket(const openscreen::IPEndpoint& source, - openscreen::platform::Clock::time_point arrival_time, + void OnReceivedPacket(const IPEndpoint& source, + Clock::time_point arrival_time, std::vector<uint8_t> packet) final; // Helper to return an iterator pointing to the entry corresponding to the @@ -60,7 +60,7 @@ class ReceiverPacketRouter final : public Environment::PacketConsumer { ReceiverEntries receivers_; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_RECEIVER_PACKET_ROUTER_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver_session.cc b/chromium/third_party/openscreen/src/cast/streaming/receiver_session.cc index 4ad3a3b675c..325a44509a0 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/receiver_session.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/receiver_session.cc @@ -4,43 +4,299 @@ #include "cast/streaming/receiver_session.h" +#include <chrono> // NOLINT +#include <string> #include <utility> +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "cast/streaming/environment.h" +#include "cast/streaming/message_port.h" +#include "cast/streaming/message_util.h" +#include "cast/streaming/offer_messages.h" +#include "cast/streaming/receiver.h" #include "util/logging.h" +namespace openscreen { namespace cast { -namespace streaming { - -ReceiverSession::ConfiguredReceivers::ConfiguredReceivers( - std::unique_ptr<Receiver> audio_receiver, - absl::optional<SessionConfig> audio_receiver_config, - std::unique_ptr<Receiver> video_receiver, - absl::optional<SessionConfig> video_receiver_config) - : audio_receiver_(std::move(audio_receiver)), - audio_receiver_config_(std::move(audio_receiver_config)), - video_receiver_(std::move(video_receiver)), - video_receiver_config_(std::move(video_receiver_config)) {} - -ReceiverSession::ConfiguredReceivers::ConfiguredReceivers( - ConfiguredReceivers&&) noexcept = default; -ReceiverSession::ConfiguredReceivers& ReceiverSession::ConfiguredReceivers:: -operator=(ConfiguredReceivers&&) noexcept = default; -ReceiverSession::ConfiguredReceivers::~ConfiguredReceivers() = default; - -ReceiverSession::ReceiverSession(Client* client, ReceiverPacketRouter* router) - : client_(client), router_(router) { + +// JSON message field values specific to the Receiver Session. +static constexpr char kMessageTypeOffer[] = "OFFER"; + +// List of OFFER message fields. +static constexpr char kOfferMessageBody[] = "offer"; +static constexpr char kKeyType[] = "type"; +static constexpr char kSequenceNumber[] = "seqNum"; + +// Using statements for constructor readability. +using Preferences = ReceiverSession::Preferences; +using ConfiguredReceivers = ReceiverSession::ConfiguredReceivers; + +namespace { + +std::string GetCodecName(ReceiverSession::AudioCodec codec) { + switch (codec) { + case ReceiverSession::AudioCodec::kAac: + return "aac"; + case ReceiverSession::AudioCodec::kOpus: + return "opus"; + } + + OSP_NOTREACHED() << "Codec not accounted for in switch statement."; + return {}; +} + +std::string GetCodecName(ReceiverSession::VideoCodec codec) { + switch (codec) { + case ReceiverSession::VideoCodec::kH264: + return "h264"; + case ReceiverSession::VideoCodec::kVp8: + return "vp8"; + case ReceiverSession::VideoCodec::kHevc: + return "hevc"; + case ReceiverSession::VideoCodec::kVp9: + return "vp9"; + } + + OSP_NOTREACHED() << "Codec not accounted for in switch statement."; + return {}; +} + +template <typename Stream, typename Codec> +const Stream* SelectStream(const std::vector<Codec>& preferred_codecs, + const std::vector<Stream>& offered_streams) { + for (Codec codec : preferred_codecs) { + const std::string codec_name = GetCodecName(codec); + for (const Stream& offered_stream : offered_streams) { + if (offered_stream.stream.codec_name == codec_name) { + OSP_VLOG << "Selected " << codec_name << " as codec for streaming."; + return &offered_stream; + } + } + } + return nullptr; +} + +} // namespace + +Preferences::Preferences() = default; +Preferences::Preferences(std::vector<VideoCodec> video_codecs, + std::vector<AudioCodec> audio_codecs) + : Preferences(video_codecs, audio_codecs, nullptr, nullptr) {} + +Preferences::Preferences(std::vector<VideoCodec> video_codecs, + std::vector<AudioCodec> audio_codecs, + std::unique_ptr<Constraints> constraints, + std::unique_ptr<DisplayDescription> description) + : video_codecs(std::move(video_codecs)), + audio_codecs(std::move(audio_codecs)), + constraints(std::move(constraints)), + display_description(std::move(description)) {} + +Preferences::Preferences(Preferences&&) noexcept = default; +Preferences& Preferences::operator=(Preferences&&) noexcept = default; + +ReceiverSession::ReceiverSession(Client* const client, + Environment* environment, + MessagePort* message_port, + Preferences preferences) + : client_(client), + environment_(environment), + message_port_(message_port), + preferences_(std::move(preferences)), + packet_router_(environment_) { OSP_DCHECK(client_); - OSP_DCHECK(router_); + OSP_DCHECK(message_port_); + OSP_DCHECK(environment_); + + message_port_->SetClient(this); +} + +ReceiverSession::~ReceiverSession() { + ResetReceivers(); + message_port_->SetClient(nullptr); +} + +void ReceiverSession::OnMessage(absl::string_view sender_id, + absl::string_view message_namespace, + absl::string_view message) { + ErrorOr<Json::Value> message_json = json::Parse(message); + + if (!message_json) { + client_->OnError(this, Error::Code::kJsonParseError); + OSP_LOG_WARN << "Received an invalid message: " << message; + return; + } + + // TODO(jophba): add sender connected/disconnected messaging. + auto sequence_number = ParseInt(message_json.value(), kSequenceNumber); + if (!sequence_number) { + OSP_LOG_WARN << "Invalid message sequence number"; + return; + } + + auto key_or_error = ParseString(message_json.value(), kKeyType); + if (!key_or_error) { + OSP_LOG_WARN << "Invalid message key"; + return; + } + + Message parsed_message{sender_id.data(), message_namespace.data(), + sequence_number.value()}; + if (key_or_error.value() == kMessageTypeOffer) { + parsed_message.body = std::move(message_json.value()[kOfferMessageBody]); + if (parsed_message.body.isNull()) { + OSP_LOG_WARN << "Invalid message offer body"; + return; + } + OnOffer(&parsed_message); + } +} + +void ReceiverSession::OnError(Error error) { + OSP_LOG_WARN << "ReceiverSession's MessagePump encountered an error:" + << error; +} + +void ReceiverSession::OnOffer(Message* message) { + ErrorOr<Offer> offer = Offer::Parse(std::move(message->body)); + if (!offer) { + client_->OnError(this, offer.error()); + OSP_LOG_WARN << "Could not parse offer" << offer.error(); + return; + } + + const AudioStream* selected_audio_stream = nullptr; + if (!offer.value().audio_streams.empty() && + !preferences_.audio_codecs.empty()) { + selected_audio_stream = + SelectStream(preferences_.audio_codecs, offer.value().audio_streams); + } + + const VideoStream* selected_video_stream = nullptr; + if (!offer.value().video_streams.empty() && + !preferences_.video_codecs.empty()) { + selected_video_stream = + SelectStream(preferences_.video_codecs, offer.value().video_streams); + } + + cast_mode_ = offer.value().cast_mode; + auto receivers = + TrySpawningReceivers(selected_audio_stream, selected_video_stream); + if (receivers) { + const Answer answer = + ConstructAnswer(message, selected_audio_stream, selected_video_stream); + client_->OnNegotiated(this, std::move(receivers.value())); + + message->body = answer.ToAnswerMessage(); + } else { + message->body = CreateInvalidAnswer(receivers.error()); + } + + SendMessage(message); +} + +std::pair<SessionConfig, std::unique_ptr<Receiver>> +ReceiverSession::ConstructReceiver(const Stream& stream) { + SessionConfig config = {stream.ssrc, stream.ssrc + 1, + stream.rtp_timebase, stream.channels, + stream.target_delay, stream.aes_key, + stream.aes_iv_mask}; + auto receiver = + std::make_unique<Receiver>(environment_, &packet_router_, config); + + return std::make_pair(std::move(config), std::move(receiver)); +} + +ErrorOr<ConfiguredReceivers> ReceiverSession::TrySpawningReceivers( + const AudioStream* audio, + const VideoStream* video) { + if (!audio && !video) { + return Error::Code::kParameterInvalid; + } + + ResetReceivers(); + + absl::optional<ConfiguredReceiver<AudioStream>> audio_receiver; + absl::optional<ConfiguredReceiver<VideoStream>> video_receiver; + + if (audio) { + auto audio_pair = ConstructReceiver(audio->stream); + current_audio_receiver_ = std::move(audio_pair.second); + audio_receiver.emplace(ConfiguredReceiver<AudioStream>{ + current_audio_receiver_.get(), std::move(audio_pair.first), *audio}); + } + + if (video) { + auto video_pair = ConstructReceiver(video->stream); + current_video_receiver_ = std::move(video_pair.second); + video_receiver.emplace(ConfiguredReceiver<VideoStream>{ + current_video_receiver_.get(), std::move(video_pair.first), *video}); + } + + return ConfiguredReceivers{std::move(audio_receiver), + std::move(video_receiver)}; +} + +void ReceiverSession::ResetReceivers() { + if (current_video_receiver_ || current_audio_receiver_) { + client_->OnConfiguredReceiversDestroyed(this); + current_audio_receiver_.reset(); + current_video_receiver_.reset(); + } +} + +Answer ReceiverSession::ConstructAnswer( + Message* message, + const AudioStream* selected_audio_stream, + const VideoStream* selected_video_stream) { + OSP_DCHECK(selected_audio_stream || selected_video_stream); + + std::vector<int> stream_indexes; + std::vector<Ssrc> stream_ssrcs; + if (selected_audio_stream) { + stream_indexes.push_back(selected_audio_stream->stream.index); + stream_ssrcs.push_back(selected_audio_stream->stream.ssrc + 1); + } + + if (selected_video_stream) { + stream_indexes.push_back(selected_video_stream->stream.index); + stream_ssrcs.push_back(selected_video_stream->stream.ssrc + 1); + } + + absl::optional<Constraints> constraints; + if (preferences_.constraints) { + constraints = *preferences_.constraints; + } + + absl::optional<DisplayDescription> display; + if (preferences_.display_description) { + display = *preferences_.display_description; + } + + return Answer{cast_mode_, + environment_->GetBoundLocalEndpoint().port, + std::move(stream_indexes), + std::move(stream_ssrcs), + constraints, + display, + std::vector<int>{}, // receiver_rtcp_event_log + std::vector<int>{}, // receiver_rtcp_dscp + supports_wifi_status_reporting_}; } -ReceiverSession::ReceiverSession(ReceiverSession&&) noexcept = default; -ReceiverSession& ReceiverSession::operator=(ReceiverSession&&) = default; -ReceiverSession::~ReceiverSession() = default; +void ReceiverSession::SendMessage(Message* message) { + // All messages have the sequence number embedded. + message->body[kSequenceNumber] = message->sequence_number; -void ReceiverSession::SelectOffer(const SessionConfig& selected_offer) { - // TODO(jophba): implement receiver session methods. - OSP_UNIMPLEMENTED(); + auto body_or_error = json::Stringify(message->body); + if (body_or_error.is_value()) { + message_port_->PostMessage(message->sender_id, message->message_namespace, + body_or_error.value()); + } else { + client_->OnError(this, body_or_error.error()); + } } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver_session.h b/chromium/third_party/openscreen/src/cast/streaming/receiver_session.h index 248bf8ad178..be59f4ec2c7 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/receiver_session.h +++ b/chromium/third_party/openscreen/src/cast/streaming/receiver_session.h @@ -6,73 +6,161 @@ #define CAST_STREAMING_RECEIVER_SESSION_H_ #include <memory> +#include <string> #include <vector> -#include "cast/streaming/receiver.h" +#include "cast/streaming/answer_messages.h" +#include "cast/streaming/message_port.h" +#include "cast/streaming/offer_messages.h" #include "cast/streaming/receiver_packet_router.h" #include "cast/streaming/session_config.h" +#include "util/json/json_serialization.h" +namespace openscreen { namespace cast { -namespace streaming { -class ReceiverSession { +class CastSocket; +class Environment; +class Receiver; +class VirtualConnectionRouter; +class VirtualConnection; + +class ReceiverSession final : public MessagePort::Client { public: - class ConfiguredReceivers { - public: + // A small helper struct that contains all of the information necessary for + // a configured receiver, including a receiver, its session config, and the + // stream selected from the OFFER message to instantiate the receiver. + template <typename T> + struct ConfiguredReceiver { + Receiver* receiver; + const SessionConfig receiver_config; + const T selected_stream; + }; + + // Upon successful negotiation, a set of configured receivers is constructed + // for handling audio and video. Note that either receiver may be null. + struct ConfiguredReceivers { // In practice, we may have 0, 1, or 2 receivers configured, depending // on if the device supports audio and video, and if we were able to // successfully negotiate a receiver configuration. - ConfiguredReceivers( - std::unique_ptr<Receiver> audio_receiver, - const absl::optional<SessionConfig> audio_receiver_config, - std::unique_ptr<Receiver> video_receiver, - const absl::optional<SessionConfig> video_receiver_config); - ConfiguredReceivers(const ConfiguredReceivers&) = delete; - ConfiguredReceivers(ConfiguredReceivers&&) noexcept; - ConfiguredReceivers& operator=(const ConfiguredReceivers&) = delete; - ConfiguredReceivers& operator=(ConfiguredReceivers&&) noexcept; - ~ConfiguredReceivers(); + + // NOTES ON LIFETIMES: The audio and video receiver pointers are expected + // to be valid until the OnConfiguredReceiversDestroyed event is fired, at + // which point they become invalid and need to replaced by the results of + // the ensuing OnNegotiated call. // If the receiver is audio- or video-only, either of the receivers // may be nullptr. However, in the majority of cases they will be populated. - Receiver* audio_receiver() const { return audio_receiver_.get(); } - const absl::optional<SessionConfig>& audio_session_config() const { - return audio_receiver_config_; - } - Receiver* video_receiver() const { return video_receiver_.get(); } - const absl::optional<SessionConfig>& video_session_config() const { - return video_receiver_config_; - } - - private: - std::unique_ptr<Receiver> audio_receiver_; - absl::optional<SessionConfig> audio_receiver_config_; - std::unique_ptr<Receiver> video_receiver_; - absl::optional<SessionConfig> video_receiver_config_; + absl::optional<ConfiguredReceiver<AudioStream>> audio; + absl::optional<ConfiguredReceiver<VideoStream>> video; }; + // The embedder should provide a client for handling connections. + // When a connection is established, the OnNegotiated callback is called. class Client { public: - virtual ~Client() = 0; - virtual void OnOffer(std::vector<SessionConfig> offers) = 0; - virtual void OnNegotiated(ConfiguredReceivers receivers) = 0; + // This method is called when a new set of receivers has been negotiated. + virtual void OnNegotiated(const ReceiverSession* session, + ConfiguredReceivers receivers) = 0; + + // This method is called immediately preceding the invalidation of + // this session's receivers. + virtual void OnConfiguredReceiversDestroyed( + const ReceiverSession* session) = 0; + + virtual void OnError(const ReceiverSession* session, Error error) = 0; }; - ReceiverSession(Client* client, ReceiverPacketRouter* router); + // The embedder has the option of providing a list of prioritized + // preferences for selecting from the offer. + enum class AudioCodec : int { kAac, kOpus }; + enum class VideoCodec : int { kH264, kVp8, kHevc, kVp9 }; + + // Note: embedders are required to implement the following + // codecs to be Cast V2 compliant: H264, VP8, AAC, Opus. + struct Preferences { + Preferences(); + Preferences(std::vector<VideoCodec> video_codecs, + std::vector<AudioCodec> audio_codecs); + Preferences(std::vector<VideoCodec> video_codecs, + std::vector<AudioCodec> audio_codecs, + std::unique_ptr<Constraints> constraints, + std::unique_ptr<DisplayDescription> description); + + Preferences(Preferences&&) noexcept; + Preferences(const Preferences&) = delete; + Preferences& operator=(Preferences&&) noexcept; + Preferences& operator=(const Preferences&) = delete; + + std::vector<VideoCodec> video_codecs{VideoCodec::kVp8, VideoCodec::kH264}; + std::vector<AudioCodec> audio_codecs{AudioCodec::kOpus, AudioCodec::kAac}; + + // The embedder has the option of directly specifying the display + // information and video/audio constraints that will be passed along to + // senders during the offer/answer exchange. If nullptr, these are ignored. + std::unique_ptr<Constraints> constraints; + std::unique_ptr<DisplayDescription> display_description; + }; + + ReceiverSession(Client* const client, + Environment* environment, + MessagePort* message_port, + Preferences preferences); ReceiverSession(const ReceiverSession&) = delete; - ReceiverSession(ReceiverSession&&) noexcept; + ReceiverSession(ReceiverSession&&) = delete; ReceiverSession& operator=(const ReceiverSession&) = delete; - ReceiverSession& operator=(ReceiverSession&&) noexcept; + ReceiverSession& operator=(ReceiverSession&&) = delete; ~ReceiverSession(); - void SelectOffer(const SessionConfig& selected_offer); + // MessagePort::Client overrides + void OnMessage(absl::string_view sender_id, + absl::string_view message_namespace, + absl::string_view message) override; + void OnError(Error error) override; private: - Client* client_; - ReceiverPacketRouter* router_; + struct Message { + const std::string sender_id = {}; + const std::string message_namespace = {}; + const int sequence_number = 0; + Json::Value body; + }; + + // Message handlers + void OnOffer(Message* message); + + std::pair<SessionConfig, std::unique_ptr<Receiver>> ConstructReceiver( + const Stream& stream); + + // Either stream input to this method may be null, however if both + // are null this method returns error. + ErrorOr<ConfiguredReceivers> TrySpawningReceivers(const AudioStream* audio, + const VideoStream* video); + + // Callers of this method should ensure at least one stream is non-null. + Answer ConstructAnswer(Message* message, + const AudioStream* audio, + const VideoStream* video); + + void SendMessage(Message* message); + + // Handles resetting receivers and notifying the client. + void ResetReceivers(); + + Client* const client_; + Environment* const environment_; + MessagePort* const message_port_; + const Preferences preferences_; + + CastMode cast_mode_; + bool supports_wifi_status_reporting_ = false; + ReceiverPacketRouter packet_router_; + + std::unique_ptr<Receiver> current_audio_receiver_; + std::unique_ptr<Receiver> current_video_receiver_; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_RECEIVER_SESSION_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver_session_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/receiver_session_unittest.cc new file mode 100644 index 00000000000..ce9a0079cd7 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/receiver_session_unittest.cc @@ -0,0 +1,536 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/receiver_session.h" + +#include <utility> + +#include "cast/streaming/mock_environment.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "platform/base/ip_address.h" +#include "platform/test/fake_clock.h" +#include "platform/test/fake_task_runner.h" + +using ::testing::_; +using ::testing::Invoke; +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::StrictMock; + +namespace openscreen { +namespace cast { + +namespace { + +constexpr char kValidOfferMessage[] = R"({ + "type": "OFFER", + "seqNum": 1337, + "offer": { + "castMode": "mirroring", + "receiverGetStatus": true, + "supportedStreams": [ + { + "index": 31337, + "type": "video_source", + "codecName": "vp9", + "rtpProfile": "cast", + "rtpPayloadType": 127, + "ssrc": 19088743, + "maxFrameRate": "60000/1000", + "timeBase": "1/90000", + "maxBitRate": 5000000, + "profile": "main", + "level": "4", + "aesKey": "bbf109bf84513b456b13a184453b66ce", + "aesIvMask": "edaf9e4536e2b66191f560d9c04b2a69", + "resolutions": [ + { + "width": 1280, + "height": 720 + } + ] + }, + { + "index": 31338, + "type": "video_source", + "codecName": "vp8", + "rtpProfile": "cast", + "rtpPayloadType": 127, + "ssrc": 19088745, + "maxFrameRate": "60000/1000", + "timeBase": "1/90000", + "maxBitRate": 5000000, + "profile": "main", + "level": "4", + "aesKey": "040d756791711fd3adb939066e6d8690", + "aesIvMask": "9ff0f022a959150e70a2d05a6c184aed", + "resolutions": [ + { + "width": 1280, + "height": 720 + } + ] + }, + { + "index": 1337, + "type": "audio_source", + "codecName": "opus", + "rtpProfile": "cast", + "rtpPayloadType": 97, + "ssrc": 19088747, + "bitRate": 124000, + "timeBase": "1/48000", + "channels": 2, + "aesKey": "51027e4e2347cbcb49d57ef10177aebc", + "aesIvMask": "7f12a19be62a36c04ae4116caaeff6d1" + } + ] + } +})"; + +constexpr char kNoAudioOfferMessage[] = R"({ + "type": "OFFER", + "seqNum": 1337, + "offer": { + "castMode": "mirroring", + "receiverGetStatus": true, + "supportedStreams": [ + { + "index": 31338, + "type": "video_source", + "codecName": "vp8", + "rtpProfile": "cast", + "rtpPayloadType": 127, + "ssrc": 19088745, + "maxFrameRate": "60000/1000", + "timeBase": "1/90000", + "maxBitRate": 5000000, + "profile": "main", + "level": "4", + "aesKey": "040d756791711fd3adb939066e6d8690", + "aesIvMask": "9ff0f022a959150e70a2d05a6c184aed", + "resolutions": [ + { + "width": 1280, + "height": 720 + } + ] + } + ] + } +})"; + +constexpr char kNoVideoOfferMessage[] = R"({ + "type": "OFFER", + "seqNum": 1337, + "offer": { + "castMode": "mirroring", + "receiverGetStatus": true, + "supportedStreams": [ + { + "index": 1337, + "type": "audio_source", + "codecName": "opus", + "rtpProfile": "cast", + "rtpPayloadType": 97, + "ssrc": 19088747, + "bitRate": 124000, + "timeBase": "1/48000", + "channels": 2, + "aesKey": "51027e4e2347cbcb49d57ef10177aebc", + "aesIvMask": "7f12a19be62a36c04ae4116caaeff6d1" + } + ] + } +})"; + +constexpr char kNoAudioOrVideoOfferMessage[] = R"({ + "type": "OFFER", + "seqNum": 1337, + "offer": { + "castMode": "mirroring", + "receiverGetStatus": true, + "supportedStreams": [] + } +})"; + +constexpr char kInvalidJsonOfferMessage[] = R"({ + "type": "OFFER", + "seqNum": 1337,,, + "offer": + "castMode": "mirroring", + "receiverGetStatus": true, + "supportedStreams": [ + } +})"; + +class SimpleMessagePort : public MessagePort { + public: + ~SimpleMessagePort() override {} + void SetClient(MessagePort::Client* client) override { client_ = client; } + + void ReceiveMessage(absl::string_view message) { + ASSERT_NE(client_, nullptr); + client_->OnMessage("sender-id", "namespace", message); + } + + void ReceiveError(Error error) { + ASSERT_NE(client_, nullptr); + client_->OnError(error); + } + + void PostMessage(absl::string_view sender_id, + absl::string_view message_namespace, + absl::string_view message) override { + posted_messages_.emplace_back(std::move(message)); + } + + MessagePort::Client* client() const { return client_; } + const std::vector<std::string> posted_messages() const { + return posted_messages_; + } + + private: + MessagePort::Client* client_ = nullptr; + std::vector<std::string> posted_messages_; +}; + +class FakeClient : public ReceiverSession::Client { + public: + MOCK_METHOD(void, + OnNegotiated, + (const ReceiverSession*, ReceiverSession::ConfiguredReceivers), + (override)); + MOCK_METHOD(void, + OnConfiguredReceiversDestroyed, + (const ReceiverSession*), + (override)); + MOCK_METHOD(void, OnError, (const ReceiverSession*, Error error), (override)); +}; + +void ExpectIsErrorAnswerMessage(const ErrorOr<Json::Value>& message_or_error) { + EXPECT_TRUE(message_or_error.is_value()); + const Json::Value message = std::move(message_or_error.value()); + EXPECT_TRUE(message["answer"].isNull()); + EXPECT_EQ("error", message["result"].asString()); + EXPECT_EQ(1337, message["seqNum"].asInt()); + EXPECT_EQ("ANSWER", message["type"].asString()); + + const Json::Value& error = message["error"]; + EXPECT_TRUE(error.isObject()); + EXPECT_GT(error["code"].asInt(), 0); + EXPECT_EQ("", error["description"].asString()); +} + +} // namespace + +class ReceiverSessionTest : public ::testing::Test { + public: + ReceiverSessionTest() : clock_(Clock::time_point{}), task_runner_(&clock_) {} + + std::unique_ptr<MockEnvironment> MakeEnvironment() { + auto environment = std::make_unique<NiceMock<MockEnvironment>>( + &FakeClock::now, &task_runner_); + ON_CALL(*environment, GetBoundLocalEndpoint()) + .WillByDefault( + Return(IPEndpoint{IPAddress::Parse("127.0.0.1").value(), 12345})); + return environment; + } + + private: + FakeClock clock_; + FakeTaskRunner task_runner_; +}; + +TEST_F(ReceiverSessionTest, RegistersSelfOnMessagePump) { + auto message_port = std::make_unique<SimpleMessagePort>(); + // This should be safe, since the message_port location should not move + // just because of being moved into the ReceiverSession. + StrictMock<FakeClient> client; + + auto environment = MakeEnvironment(); + auto session = std::make_unique<ReceiverSession>( + &client, environment.get(), message_port.get(), + ReceiverSession::Preferences{}); + EXPECT_EQ(message_port->client(), session.get()); +} + +TEST_F(ReceiverSessionTest, CanNegotiateWithDefaultPreferences) { + auto message_port = std::make_unique<SimpleMessagePort>(); + StrictMock<FakeClient> client; + auto environment = MakeEnvironment(); + ReceiverSession session(&client, environment.get(), message_port.get(), + ReceiverSession::Preferences{}); + + EXPECT_CALL(client, OnNegotiated(&session, _)) + .WillOnce([](const ReceiverSession* session, + ReceiverSession::ConfiguredReceivers cr) { + EXPECT_TRUE(cr.audio); + EXPECT_EQ(cr.audio.value().receiver_config.sender_ssrc, 19088747u); + EXPECT_EQ(cr.audio.value().receiver_config.receiver_ssrc, 19088748u); + EXPECT_EQ(cr.audio.value().receiver_config.channels, 2); + EXPECT_EQ(cr.audio.value().receiver_config.rtp_timebase, 48000); + + // We should have chosen opus + EXPECT_EQ(cr.audio.value().selected_stream.stream.index, 1337); + EXPECT_EQ(cr.audio.value().selected_stream.stream.type, + Stream::Type::kAudioSource); + EXPECT_EQ(cr.audio.value().selected_stream.stream.codec_name, "opus"); + EXPECT_EQ(cr.audio.value().selected_stream.stream.channels, 2); + + EXPECT_TRUE(cr.video); + EXPECT_EQ(cr.video.value().receiver_config.sender_ssrc, 19088745u); + EXPECT_EQ(cr.video.value().receiver_config.receiver_ssrc, 19088746u); + EXPECT_EQ(cr.video.value().receiver_config.channels, 1); + EXPECT_EQ(cr.video.value().receiver_config.rtp_timebase, 90000); + + // We should have chosen vp8 + EXPECT_EQ(cr.video.value().selected_stream.stream.index, 31338); + EXPECT_EQ(cr.video.value().selected_stream.stream.type, + Stream::Type::kVideoSource); + EXPECT_EQ(cr.video.value().selected_stream.stream.codec_name, "vp8"); + EXPECT_EQ(cr.video.value().selected_stream.stream.channels, 1); + }); + EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(1); + + message_port->ReceiveMessage(kValidOfferMessage); + + const auto& messages = message_port->posted_messages(); + ASSERT_EQ(1u, messages.size()); + + auto message_body = json::Parse(messages[0]); + EXPECT_TRUE(message_body.is_value()); + const Json::Value answer = std::move(message_body.value()); + + EXPECT_EQ("ANSWER", answer["type"].asString()); + EXPECT_EQ(1337, answer["seqNum"].asInt()); + EXPECT_EQ("ok", answer["result"].asString()); + + const Json::Value& answer_body = answer["answer"]; + EXPECT_TRUE(answer_body.isObject()); + + // Spot check the answer body fields. We have more in depth testing + // of answer behavior in answer_messages_unittest, but here we can + // ensure that the ReceiverSession properly configured the answer. + EXPECT_EQ("mirroring", answer_body["castMode"].asString()); + EXPECT_EQ(1337, answer_body["sendIndexes"][0].asInt()); + EXPECT_EQ(31338, answer_body["sendIndexes"][1].asInt()); + EXPECT_LT(0, answer_body["udpPort"].asInt()); + EXPECT_GT(65535, answer_body["udpPort"].asInt()); + + // Get status should always be false, as we have no plans to implement it. + EXPECT_EQ(false, answer_body["receiverGetStatus"].asBool()); + + // Constraints and display should not be present with no preferences. + EXPECT_TRUE(answer_body["constraints"].isNull()); + EXPECT_TRUE(answer_body["display"].isNull()); +} + +TEST_F(ReceiverSessionTest, CanNegotiateWithCustomCodecPreferences) { + auto message_port = std::make_unique<SimpleMessagePort>(); + StrictMock<FakeClient> client; + auto environment = MakeEnvironment(); + ReceiverSession session( + &client, environment.get(), message_port.get(), + ReceiverSession::Preferences{{ReceiverSession::VideoCodec::kVp9}, + {ReceiverSession::AudioCodec::kOpus}}); + + EXPECT_CALL(client, OnNegotiated(&session, _)) + .WillOnce([](const ReceiverSession* session, + ReceiverSession::ConfiguredReceivers cr) { + EXPECT_TRUE(cr.audio); + EXPECT_EQ(cr.audio.value().receiver_config.sender_ssrc, 19088747u); + EXPECT_EQ(cr.audio.value().receiver_config.receiver_ssrc, 19088748u); + EXPECT_EQ(cr.audio.value().receiver_config.channels, 2); + EXPECT_EQ(cr.audio.value().receiver_config.rtp_timebase, 48000); + + EXPECT_TRUE(cr.video); + // We should have chosen vp9 + EXPECT_EQ(cr.video.value().receiver_config.sender_ssrc, 19088743u); + EXPECT_EQ(cr.video.value().receiver_config.receiver_ssrc, 19088744u); + EXPECT_EQ(cr.video.value().receiver_config.channels, 1); + EXPECT_EQ(cr.video.value().receiver_config.rtp_timebase, 90000); + }); + EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(1); + message_port->ReceiveMessage(kValidOfferMessage); +} + +TEST_F(ReceiverSessionTest, CanNegotiateWithCustomConstraints) { + auto message_port = std::make_unique<SimpleMessagePort>(); + StrictMock<FakeClient> client; + + auto constraints = std::unique_ptr<Constraints>{new Constraints{ + AudioConstraints{1, 2, 3, 4}, + VideoConstraints{3.14159, Dimensions{320, 240, SimpleFraction{24, 1}}, + Dimensions{1920, 1080, SimpleFraction{144, 1}}, 3000, + 90000000, std::chrono::milliseconds{1000}}}}; + + auto display = std::unique_ptr<DisplayDescription>{new DisplayDescription{ + Dimensions{640, 480, SimpleFraction{60, 1}}, AspectRatio{16, 9}, + AspectRatioConstraint::kFixed}}; + + auto environment = MakeEnvironment(); + ReceiverSession session( + &client, environment.get(), message_port.get(), + ReceiverSession::Preferences{{ReceiverSession::VideoCodec::kVp9}, + {ReceiverSession::AudioCodec::kOpus}, + std::move(constraints), + std::move(display)}); + + EXPECT_CALL(client, OnNegotiated(&session, _)).Times(1); + EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(1); + message_port->ReceiveMessage(kValidOfferMessage); + + const auto& messages = message_port->posted_messages(); + EXPECT_EQ(1u, messages.size()); + + auto message_body = json::Parse(messages[0]); + ASSERT_TRUE(message_body.is_value()); + const Json::Value answer = std::move(message_body.value()); + + const Json::Value& answer_body = answer["answer"]; + ASSERT_TRUE(answer_body.isObject()); + + // Constraints and display should be valid with valid preferences. + ASSERT_FALSE(answer_body["constraints"].isNull()); + ASSERT_FALSE(answer_body["display"].isNull()); + + const Json::Value& display_json = answer_body["display"]; + EXPECT_EQ("16:9", display_json["aspectRatio"].asString()); + EXPECT_EQ("60", display_json["dimensions"]["frameRate"].asString()); + EXPECT_EQ(640, display_json["dimensions"]["width"].asInt()); + EXPECT_EQ(480, display_json["dimensions"]["height"].asInt()); + EXPECT_EQ("sender", display_json["scaling"].asString()); + + const Json::Value& constraints_json = answer_body["constraints"]; + ASSERT_TRUE(constraints_json.isObject()); + + const Json::Value& audio = constraints_json["audio"]; + ASSERT_TRUE(audio.isObject()); + EXPECT_EQ(4, audio["maxBitRate"].asInt()); + EXPECT_EQ(2, audio["maxChannels"].asInt()); + EXPECT_EQ(0, audio["maxDelay"].asInt()); + EXPECT_EQ(1, audio["maxSampleRate"].asInt()); + EXPECT_EQ(3, audio["minBitRate"].asInt()); + + const Json::Value& video = constraints_json["video"]; + ASSERT_TRUE(video.isObject()); + EXPECT_EQ(90000000, video["maxBitRate"].asInt()); + EXPECT_EQ(1000, video["maxDelay"].asInt()); + EXPECT_EQ("144", video["maxDimensions"]["frameRate"].asString()); + EXPECT_EQ(1920, video["maxDimensions"]["width"].asInt()); + EXPECT_EQ(1080, video["maxDimensions"]["height"].asInt()); + EXPECT_DOUBLE_EQ(3.14159, video["maxPixelsPerSecond"].asDouble()); + EXPECT_EQ(3000, video["minBitRate"].asInt()); + EXPECT_EQ("24", video["minDimensions"]["frameRate"].asString()); + EXPECT_EQ(320, video["minDimensions"]["width"].asInt()); + EXPECT_EQ(240, video["minDimensions"]["height"].asInt()); +} + +TEST_F(ReceiverSessionTest, HandlesNoValidAudioStream) { + auto message_port = std::make_unique<SimpleMessagePort>(); + StrictMock<FakeClient> client; + auto environment = MakeEnvironment(); + ReceiverSession session(&client, environment.get(), message_port.get(), + ReceiverSession::Preferences{}); + + EXPECT_CALL(client, OnNegotiated(&session, _)).Times(1); + EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(1); + + message_port->ReceiveMessage(kNoAudioOfferMessage); + const auto& messages = message_port->posted_messages(); + EXPECT_EQ(1u, messages.size()); + + auto message_body = json::Parse(messages[0]); + EXPECT_TRUE(message_body.is_value()); + const Json::Value& answer_body = message_body.value()["answer"]; + EXPECT_TRUE(answer_body.isObject()); + + // Should still select video stream. + EXPECT_EQ(1u, answer_body["sendIndexes"].size()); + EXPECT_EQ(31338, answer_body["sendIndexes"][0].asInt()); + EXPECT_EQ(1u, answer_body["ssrcs"].size()); + EXPECT_EQ(19088746, answer_body["ssrcs"][0].asInt()); +} + +TEST_F(ReceiverSessionTest, HandlesNoValidVideoStream) { + auto message_port = std::make_unique<SimpleMessagePort>(); + StrictMock<FakeClient> client; + auto environment = MakeEnvironment(); + ReceiverSession session(&client, environment.get(), message_port.get(), + ReceiverSession::Preferences{}); + + EXPECT_CALL(client, OnNegotiated(&session, _)).Times(1); + EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(1); + + message_port->ReceiveMessage(kNoVideoOfferMessage); + const auto& messages = message_port->posted_messages(); + EXPECT_EQ(1u, messages.size()); + + auto message_body = json::Parse(messages[0]); + EXPECT_TRUE(message_body.is_value()); + const Json::Value& answer_body = message_body.value()["answer"]; + EXPECT_TRUE(answer_body.isObject()); + + // Should still select audio stream. + EXPECT_EQ(1u, answer_body["sendIndexes"].size()); + EXPECT_EQ(1337, answer_body["sendIndexes"][0].asInt()); + EXPECT_EQ(1u, answer_body["ssrcs"].size()); + EXPECT_EQ(19088748, answer_body["ssrcs"][0].asInt()); +} + +TEST_F(ReceiverSessionTest, HandlesNoValidStreams) { + auto message_port = std::make_unique<SimpleMessagePort>(); + StrictMock<FakeClient> client; + + auto environment = MakeEnvironment(); + ReceiverSession session(&client, environment.get(), message_port.get(), + ReceiverSession::Preferences{}); + + // We shouldn't call OnNegotiated if we failed to negotiate any streams. + EXPECT_CALL(client, OnNegotiated(&session, _)).Times(0); + EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(0); + + message_port->ReceiveMessage(kNoAudioOrVideoOfferMessage); + const auto& messages = message_port->posted_messages(); + EXPECT_EQ(1u, messages.size()); + + auto message_body = json::Parse(messages[0]); + ExpectIsErrorAnswerMessage(message_body); +} + +TEST_F(ReceiverSessionTest, HandlesMalformedOffer) { + auto message_port = std::make_unique<SimpleMessagePort>(); + StrictMock<FakeClient> client; + auto environment = MakeEnvironment(); + ReceiverSession session(&client, environment.get(), message_port.get(), + ReceiverSession::Preferences{}); + + // We shouldn't call OnNegotiated if we failed to negotiate any streams. + // Note that unlike when we simply don't select any streams, when the offer + // is actually completely invalid we call OnError. + EXPECT_CALL(client, OnNegotiated(&session, _)).Times(0); + EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(0); + EXPECT_CALL(client, OnError(&session, Error(Error::Code::kJsonParseError))) + .Times(1); + + message_port->ReceiveMessage(kInvalidJsonOfferMessage); +} + +TEST_F(ReceiverSessionTest, NotifiesReceiverDestruction) { + auto message_port = std::make_unique<SimpleMessagePort>(); + StrictMock<FakeClient> client; + auto environment = MakeEnvironment(); + ReceiverSession session(&client, environment.get(), message_port.get(), + ReceiverSession::Preferences{}); + + EXPECT_CALL(client, OnNegotiated(&session, _)).Times(2); + EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(2); + + message_port->ReceiveMessage(kNoAudioOfferMessage); + message_port->ReceiveMessage(kValidOfferMessage); +} +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/receiver_unittest.cc index 2426376aed8..826af0b0663 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/receiver_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/receiver_unittest.cc @@ -8,6 +8,7 @@ #include <algorithm> #include <array> +#include <utility> #include <vector> #include "absl/types/span.h" @@ -15,6 +16,7 @@ #include "cast/streaming/constants.h" #include "cast/streaming/encoded_frame.h" #include "cast/streaming/frame_crypto.h" +#include "cast/streaming/mock_environment.h" #include "cast/streaming/receiver_packet_router.h" #include "cast/streaming/rtcp_common.h" #include "cast/streaming/rtcp_session.h" @@ -35,18 +37,6 @@ #include "platform/test/fake_task_runner.h" #include "util/logging.h" -using openscreen::Error; -using openscreen::ErrorOr; -using openscreen::IPAddress; -using openscreen::IPEndpoint; -using openscreen::platform::Clock; -using openscreen::platform::ClockNowFunctionPtr; -using openscreen::platform::FakeClock; -using openscreen::platform::FakeTaskRunner; -using openscreen::platform::TaskRunner; -using openscreen::platform::UdpPacket; -using openscreen::platform::UdpSocket; - using std::chrono::duration_cast; using std::chrono::microseconds; using std::chrono::milliseconds; @@ -58,8 +48,8 @@ using testing::Gt; using testing::Invoke; using testing::SaveArg; +namespace openscreen { namespace cast { -namespace streaming { namespace { // Receiver configuration. @@ -80,8 +70,12 @@ constexpr milliseconds kTargetPlayoutDelayChange{800}; constexpr RtpPayloadType kRtpPayloadType = RtpPayloadType::kVideoVp8; constexpr int kMaxRtpPacketSize = 64; -// A simulated one-way network delay. -constexpr auto kOneWayNetworkDelay = milliseconds(23); +// A simulated one-way network delay, and round-trip network delay. +constexpr auto kOneWayNetworkDelay = milliseconds(3); +constexpr auto kRoundTripNetworkDelay = 2 * kOneWayNetworkDelay; +static_assert(kRoundTripNetworkDelay < kTargetPlayoutDelay && + kRoundTripNetworkDelay < kTargetPlayoutDelayChange, + "Network delay must be smaller than target playout delay."); // An EncodedFrame for unit testing, one of a sequence of simulated frames, each // of 10 ms duration. The first frame will be a key frame; and any later frames @@ -180,10 +174,11 @@ class MockSender : public CompoundRtcpParser::Client { UdpPacket packet_to_send(packet_and_report_id.first.begin(), packet_and_report_id.first.end()); packet_to_send.set_source(sender_endpoint_); - task_runner_->PostTask( + task_runner_->PostTaskWithDelay( [receiver = receiver_, packet = std::move(packet_to_send)]() mutable { receiver->OnRead(nullptr, ErrorOr<UdpPacket>(std::move(packet))); - }); + }, + kOneWayNetworkDelay); return packet_and_report_id.second; } @@ -224,10 +219,11 @@ class MockSender : public CompoundRtcpParser::Client { rtp_packetizer_.GeneratePacket(frame_being_sent_, packet_id, buffer); UdpPacket packet_to_send(span.begin(), span.end()); packet_to_send.set_source(sender_endpoint_); - task_runner_->PostTask( + task_runner_->PostTaskWithDelay( [receiver = receiver_, packet = std::move(packet_to_send)]() mutable { receiver->OnRead(nullptr, ErrorOr<UdpPacket>(std::move(packet))); - }); + }, + kOneWayNetworkDelay); } } @@ -263,19 +259,6 @@ class MockSender : public CompoundRtcpParser::Client { EncryptedFrame frame_being_sent_; }; -// An Environment that can intercept all packet sends. ReceiverTest will connect -// the SendPacket() method calls to the MockSender. -class MockEnvironment : public Environment { - public: - MockEnvironment(ClockNowFunctionPtr now_function, TaskRunner* task_runner) - : Environment(now_function, task_runner) {} - - ~MockEnvironment() override = default; - - // Used for intercepting packet sends from the implementation under test. - MOCK_METHOD1(SendPacket, void(absl::Span<const uint8_t> packet)); -}; - class MockConsumer : public Receiver::Consumer { public: MOCK_METHOD1(OnFramesReady, void(int next_frame_buffer_size)); @@ -294,14 +277,21 @@ class ReceiverTest : public testing::Test { /* .receiver_ssrc = */ kReceiverSsrc, /* .rtp_timebase = */ kRtpTimebase, /* .channels = */ 2, + /* .target_playout_delay = */ kTargetPlayoutDelay, /* .aes_secret_key = */ kAesKey, - /* .aes_iv_mask = */ kCastIvMask}, - kTargetPlayoutDelay), + /* .aes_iv_mask = */ kCastIvMask}), sender_(&task_runner_, &env_) { env_.set_socket_error_handler( [](Error error) { ASSERT_TRUE(error.ok()) << error; }); ON_CALL(env_, SendPacket(_)) - .WillByDefault(Invoke(&sender_, &MockSender::OnPacketFromReceiver)); + .WillByDefault(Invoke([this](absl::Span<const uint8_t> packet) { + task_runner_.PostTaskWithDelay( + [sender = &sender_, copy_of_packet = std::vector<uint8_t>( + packet.begin(), packet.end())]() mutable { + sender->OnPacketFromReceiver(std::move(copy_of_packet)); + }, + kOneWayNetworkDelay); + })); receiver_.SetConsumer(&consumer_); } @@ -314,6 +304,22 @@ class ReceiverTest : public testing::Test { void AdvanceClockAndRunTasks(Clock::duration delta) { clock_.Advance(delta); } void RunTasksUntilIdle() { task_runner_.RunTasksUntilIdle(); } + // Sends the initial Sender Report with lip-sync timing information to + // "unblock" the Receiver, and confirms the Receiver immediately replies with + // a corresponding Receiver Report. + void ExchangeInitialReportPackets() { + const Clock::time_point start_time = FakeClock::now(); + sender_.SendSenderReport(start_time, SimulatedFrame::GetRtpStartTime()); + AdvanceClockAndRunTasks( + kOneWayNetworkDelay); // Transmit report to Receiver. + // The Receiver will immediately reply with a Receiver Report. + EXPECT_CALL(sender_, + OnReceiverCheckpoint(FrameId::leader(), kTargetPlayoutDelay)) + .Times(1); + AdvanceClockAndRunTasks(kOneWayNetworkDelay); // Transmit reply to Sender. + testing::Mock::VerifyAndClearExpectations(&sender_); + } + // Consume one frame from the Receiver, and verify that it is the same as the // |sent_frame|. Exception: The |reference_time| is the playout time on the // Receiver's end, while it refers to the capture time on the Sender's end. @@ -371,7 +377,7 @@ TEST_F(ReceiverTest, ReceivesAndSendsRtcpPackets) { EXPECT_CALL(*sender(), OnReceiverReport(_)) .WillOnce(SaveArg<0>(&receiver_report)); EXPECT_CALL(*sender(), - OnReceiverCheckpoint(FrameId::first() - 1, kTargetPlayoutDelay)) + OnReceiverCheckpoint(FrameId::leader(), kTargetPlayoutDelay)) .Times(1); // Have the MockSender send a Sender Report with lip-sync timing information. @@ -380,7 +386,8 @@ TEST_F(ReceiverTest, ReceivesAndSendsRtcpPackets) { RtpTimeTicks::FromTimeSinceOrigin(seconds(1), kRtpTimebase); const StatusReportId sender_report_id = sender()->SendSenderReport(sender_reference_time, sender_rtp_timestamp); - AdvanceClockAndRunTasks(kOneWayNetworkDelay); + + AdvanceClockAndRunTasks(kRoundTripNetworkDelay); // Expect the MockSender got back a Receiver Report that includes its SSRC and // the last Sender Report ID. @@ -389,7 +396,7 @@ TEST_F(ReceiverTest, ReceivesAndSendsRtcpPackets) { EXPECT_EQ(sender_report_id, receiver_report.last_status_report_id); // Confirm the clock offset math: Since the Receiver and MockSender share the - // same underlying FakeClock, the Receiver should be 10ms ahead of the Sender, + // same underlying FakeClock, the Receiver should be ahead of the Sender, // which reflects the simulated one-way network packet travel time (of the // Sender Report). // @@ -422,11 +429,8 @@ TEST_F(ReceiverTest, ReceivesAndSendsRtcpPackets) { // out of order, but such that each frame is completely received in-order. Also, // confirms that target playout delay changes are processed/applied correctly. TEST_F(ReceiverTest, ReceivesFramesInOrder) { - // Send the initial Sender Report with lip-sync timing information to - // "unblock" the Receiver. const Clock::time_point start_time = FakeClock::now(); - sender()->SendSenderReport(start_time, SimulatedFrame::GetRtpStartTime()); - AdvanceClockAndRunTasks(kOneWayNetworkDelay); + ExchangeInitialReportPackets(); EXPECT_CALL(*consumer(), OnFramesReady(Gt(0))).Times(10); for (int i = 0; i <= 9; ++i) { @@ -442,10 +446,15 @@ TEST_F(ReceiverTest, ReceivesFramesInOrder) { const int permutation = (i % 2) ? i : 0; sender()->SendRtpPackets(sender()->GetAllPacketIds(permutation)); + AdvanceClockAndRunTasks(kRoundTripNetworkDelay); + // The Receiver should immediately ACK once it has received all the RTP // packets to complete the frame. - RunTasksUntilIdle(); testing::Mock::VerifyAndClearExpectations(sender()); + + // Advance to next frame transmission time. + AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration - + kRoundTripNetworkDelay); } // When the Receiver has all of the frames and they are complete, it should @@ -466,11 +475,8 @@ TEST_F(ReceiverTest, ReceivesFramesInOrder) { // order, and issues the appropriate ACK/NACK feedback to the Sender as it // realizes what it has and what it's missing. TEST_F(ReceiverTest, ReceivesFramesOutOfOrder) { - // Send the initial Sender Report with lip-sync timing information to - // "unblock" the Receiver. const Clock::time_point start_time = FakeClock::now(); - sender()->SendSenderReport(start_time, SimulatedFrame::GetRtpStartTime()); - AdvanceClockAndRunTasks(kOneWayNetworkDelay); + ExchangeInitialReportPackets(); constexpr static int kOutOfOrderFrames[] = {3, 4, 2, 0, 1}; for (int i : kOutOfOrderFrames) { @@ -479,7 +485,7 @@ TEST_F(ReceiverTest, ReceivesFramesOutOfOrder) { case 3: { // Note that frame 4 will not yet be known to the Receiver, and so it // should not be mentioned in any of the feedback for this case. - EXPECT_CALL(*sender(), OnReceiverCheckpoint(FrameId::first() - 1, + EXPECT_CALL(*sender(), OnReceiverCheckpoint(FrameId::leader(), kTargetPlayoutDelay)) .Times(AtLeast(1)); EXPECT_CALL( @@ -498,7 +504,7 @@ TEST_F(ReceiverTest, ReceivesFramesOutOfOrder) { } case 4: { - EXPECT_CALL(*sender(), OnReceiverCheckpoint(FrameId::first() - 1, + EXPECT_CALL(*sender(), OnReceiverCheckpoint(FrameId::leader(), kTargetPlayoutDelay)) .Times(AtLeast(1)); EXPECT_CALL(*sender(), @@ -517,7 +523,7 @@ TEST_F(ReceiverTest, ReceivesFramesOutOfOrder) { } case 2: { - EXPECT_CALL(*sender(), OnReceiverCheckpoint(FrameId::first() - 1, + EXPECT_CALL(*sender(), OnReceiverCheckpoint(FrameId::leader(), kTargetPlayoutDelay)) .Times(AtLeast(1)); EXPECT_CALL(*sender(), OnReceiverHasFrames(std::vector<FrameId>( @@ -567,11 +573,13 @@ TEST_F(ReceiverTest, ReceivesFramesOutOfOrder) { sender()->SetFrameBeingSent(SimulatedFrame(start_time, i)); sender()->SendRtpPackets(sender()->GetAllPacketIds(i)); + AdvanceClockAndRunTasks(kRoundTripNetworkDelay); + // While there are known incomplete frames, the Receiver should send RTCP // packets more frequently than the default "ping" interval. Thus, advancing // the clock by this much should result in several feedback reports // transmitted to the Sender. - AdvanceClockAndRunTasks(kRtcpReportInterval); + AdvanceClockAndRunTasks(kRtcpReportInterval - kRoundTripNetworkDelay); testing::Mock::VerifyAndClearExpectations(sender()); testing::Mock::VerifyAndClearExpectations(consumer()); @@ -585,11 +593,8 @@ TEST_F(ReceiverTest, ReceivesFramesOutOfOrder) { // by sending a Picture Loss Indicator (PLI) to the Sender, and then will // automatically stop sending the PLI once a key frame has been received. TEST_F(ReceiverTest, RequestsKeyFrameToRectifyPictureLoss) { - // Send the initial Sender Report with lip-sync timing information to - // "unblock" the Receiver. const Clock::time_point start_time = FakeClock::now(); - sender()->SendSenderReport(start_time, SimulatedFrame::GetRtpStartTime()); - AdvanceClockAndRunTasks(kOneWayNetworkDelay); + ExchangeInitialReportPackets(); // Send and Receive three frames in-order, normally. for (int i = 0; i <= 2; ++i) { @@ -599,9 +604,12 @@ TEST_F(ReceiverTest, RequestsKeyFrameToRectifyPictureLoss) { .Times(1); sender()->SetFrameBeingSent(SimulatedFrame(start_time, i)); sender()->SendRtpPackets(sender()->GetAllPacketIds(0)); - RunTasksUntilIdle(); + AdvanceClockAndRunTasks(kRoundTripNetworkDelay); testing::Mock::VerifyAndClearExpectations(sender()); testing::Mock::VerifyAndClearExpectations(consumer()); + // Advance to next frame transmission time. + AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration - + kRoundTripNetworkDelay); } ConsumeAndVerifyFrames(0, 2, start_time); @@ -609,7 +617,7 @@ TEST_F(ReceiverTest, RequestsKeyFrameToRectifyPictureLoss) { // decoder failure). Ensure the Sender is immediately notified. EXPECT_CALL(*sender(), OnReceiverIndicatesPictureLoss()).Times(1); receiver()->RequestKeyFrame(); - RunTasksUntilIdle(); + AdvanceClockAndRunTasks(kOneWayNetworkDelay); // Propagate request to Sender. testing::Mock::VerifyAndClearExpectations(sender()); // The Sender sends another frame that is not a key frame and, upon receipt, @@ -618,10 +626,10 @@ TEST_F(ReceiverTest, RequestsKeyFrameToRectifyPictureLoss) { EXPECT_CALL(*sender(), OnReceiverCheckpoint(FrameId::first() + 3, kTargetPlayoutDelay)) .Times(1); - EXPECT_CALL(*sender(), OnReceiverIndicatesPictureLoss()).Times(1); + EXPECT_CALL(*sender(), OnReceiverIndicatesPictureLoss()).Times(AtLeast(1)); sender()->SetFrameBeingSent(SimulatedFrame(start_time, 3)); sender()->SendRtpPackets(sender()->GetAllPacketIds(0)); - RunTasksUntilIdle(); + AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration - kOneWayNetworkDelay); testing::Mock::VerifyAndClearExpectations(sender()); testing::Mock::VerifyAndClearExpectations(consumer()); ConsumeAndVerifyFrames(3, 3, start_time); @@ -639,7 +647,7 @@ TEST_F(ReceiverTest, RequestsKeyFrameToRectifyPictureLoss) { key_frame.referenced_frame_id = key_frame.frame_id; sender()->SetFrameBeingSent(key_frame); sender()->SendRtpPackets(sender()->GetAllPacketIds(0)); - RunTasksUntilIdle(); + AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration); testing::Mock::VerifyAndClearExpectations(sender()); testing::Mock::VerifyAndClearExpectations(consumer()); @@ -647,7 +655,7 @@ TEST_F(ReceiverTest, RequestsKeyFrameToRectifyPictureLoss) { // RequestKeyFrame() should not set the PLI condition again. EXPECT_CALL(*sender(), OnReceiverIndicatesPictureLoss()).Times(0); receiver()->RequestKeyFrame(); - RunTasksUntilIdle(); + AdvanceClockAndRunTasks(kOneWayNetworkDelay); testing::Mock::VerifyAndClearExpectations(sender()); // After consuming the requested key frame, the client should be able to set @@ -655,7 +663,7 @@ TEST_F(ReceiverTest, RequestsKeyFrameToRectifyPictureLoss) { ConsumeAndVerifyFrame(key_frame); EXPECT_CALL(*sender(), OnReceiverIndicatesPictureLoss()).Times(1); receiver()->RequestKeyFrame(); - RunTasksUntilIdle(); + AdvanceClockAndRunTasks(kOneWayNetworkDelay); testing::Mock::VerifyAndClearExpectations(sender()); } @@ -663,11 +671,8 @@ TEST_F(ReceiverTest, RequestsKeyFrameToRectifyPictureLoss) { // full (i.e., when the consumer is not pulling them out of the queue). Since // the Receiver will stop ACK'ing frames, the Sender will become stalled. TEST_F(ReceiverTest, EatsItsFill) { - // Send the initial Sender Report with lip-sync timing information to - // "unblock" the Receiver. const Clock::time_point start_time = FakeClock::now(); - sender()->SendSenderReport(start_time, SimulatedFrame::GetRtpStartTime()); - AdvanceClockAndRunTasks(kOneWayNetworkDelay); + ExchangeInitialReportPackets(); // Send and Receive the maximum possible number of frames in-order, normally. for (int i = 0; i < kMaxUnackedFrames; ++i) { @@ -678,7 +683,7 @@ TEST_F(ReceiverTest, EatsItsFill) { .Times(1); sender()->SetFrameBeingSent(SimulatedFrame(start_time, i)); sender()->SendRtpPackets(sender()->GetAllPacketIds(0)); - RunTasksUntilIdle(); + AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration); testing::Mock::VerifyAndClearExpectations(sender()); testing::Mock::VerifyAndClearExpectations(consumer()); } @@ -692,11 +697,11 @@ TEST_F(ReceiverTest, EatsItsFill) { EXPECT_CALL(*sender(), OnReceiverCheckpoint(FrameId::first() + (ignored_frame - 1), kTargetPlayoutDelayChange)) - .Times(AtLeast(1)); + .Times(AtLeast(0)); EXPECT_CALL(*sender(), OnReceiverIsMissingPackets(_)).Times(0); sender()->SetFrameBeingSent(SimulatedFrame(start_time, ignored_frame)); sender()->SendRtpPackets(sender()->GetAllPacketIds(0)); - AdvanceClockAndRunTasks(kRtcpReportInterval); + AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration); testing::Mock::VerifyAndClearExpectations(sender()); testing::Mock::VerifyAndClearExpectations(consumer()); } @@ -706,7 +711,7 @@ TEST_F(ReceiverTest, EatsItsFill) { ConsumeAndVerifyFrames(0, 0, start_time); int no_longer_ignored_frame = ignored_frame; ++ignored_frame; - EXPECT_CALL(*consumer(), OnFramesReady(Gt(0))).Times(1); + EXPECT_CALL(*consumer(), OnFramesReady(Gt(0))).Times(AtLeast(1)); EXPECT_CALL(*sender(), OnReceiverCheckpoint(FrameId::first() + no_longer_ignored_frame, kTargetPlayoutDelayChange)) @@ -716,10 +721,11 @@ TEST_F(ReceiverTest, EatsItsFill) { sender()->SetFrameBeingSent( SimulatedFrame(start_time, no_longer_ignored_frame)); sender()->SendRtpPackets(sender()->GetAllPacketIds(0)); + AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration); // This second frame should be ignored, however. sender()->SetFrameBeingSent(SimulatedFrame(start_time, ignored_frame)); sender()->SendRtpPackets(sender()->GetAllPacketIds(0)); - AdvanceClockAndRunTasks(kRtcpReportInterval); + AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration); testing::Mock::VerifyAndClearExpectations(sender()); testing::Mock::VerifyAndClearExpectations(consumer()); } @@ -728,11 +734,8 @@ TEST_F(ReceiverTest, EatsItsFill) { // but only as inter-frame data dependency requirements permit, and only if no // target playout delay change information would have been missed. TEST_F(ReceiverTest, DropsLateFrames) { - // Send the initial Sender Report with lip-sync timing information to - // "unblock" the Receiver. const Clock::time_point start_time = FakeClock::now(); - sender()->SendSenderReport(start_time, SimulatedFrame::GetRtpStartTime()); - AdvanceClockAndRunTasks(kOneWayNetworkDelay); + ExchangeInitialReportPackets(); // Before any packets have been sent/received, the Receiver should indicate no // frames are ready. @@ -766,8 +769,8 @@ TEST_F(ReceiverTest, DropsLateFrames) { // is not exercising the logic meaningfully. ASSERT_LE(size_t{3}, sender()->GetAllPacketIds(0).size()); sender()->SendRtpPackets({FramePacketId{1}}); + AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration); } - RunTasksUntilIdle(); testing::Mock::VerifyAndClearExpectations(consumer()); testing::Mock::VerifyAndClearExpectations(sender()); EXPECT_EQ(Receiver::kNoFramesReady, receiver()->AdvanceToNextFrame()); @@ -782,7 +785,7 @@ TEST_F(ReceiverTest, DropsLateFrames) { sender()->SetFrameBeingSent(frames[i]); sender()->SendRtpPackets(sender()->GetAllPacketIds(0)); } - RunTasksUntilIdle(); + AdvanceClockAndRunTasks(kRoundTripNetworkDelay); testing::Mock::VerifyAndClearExpectations(consumer()); testing::Mock::VerifyAndClearExpectations(sender()); EXPECT_EQ(Receiver::kNoFramesReady, receiver()->AdvanceToNextFrame()); @@ -800,7 +803,7 @@ TEST_F(ReceiverTest, DropsLateFrames) { sender()->SetFrameBeingSent(frames[i]); sender()->SendRtpPackets({FramePacketId{0}}); } - RunTasksUntilIdle(); + AdvanceClockAndRunTasks(kRoundTripNetworkDelay); testing::Mock::VerifyAndClearExpectations(consumer()); testing::Mock::VerifyAndClearExpectations(sender()); EXPECT_EQ(Receiver::kNoFramesReady, receiver()->AdvanceToNextFrame()); @@ -816,7 +819,7 @@ TEST_F(ReceiverTest, DropsLateFrames) { .Times(1); sender()->SetFrameBeingSent(frames[5]); sender()->SendRtpPackets({FramePacketId{0}}); - RunTasksUntilIdle(); + AdvanceClockAndRunTasks(kRoundTripNetworkDelay); // Note: Consuming Frame 6 will trigger the checkpoint advancement, since the // call to AdvanceToNextFrame() contains the frame skipping/dropping logic. ConsumeAndVerifyFrame(frames[6]); @@ -826,7 +829,7 @@ TEST_F(ReceiverTest, DropsLateFrames) { // After consuming Frame 6, the Receiver knows Frame 7 is also available and // should have scheduled an immediate task to notify the Consumer of this. EXPECT_CALL(*consumer(), OnFramesReady(Gt(0))).Times(1); - RunTasksUntilIdle(); + AdvanceClockAndRunTasks(kOneWayNetworkDelay); testing::Mock::VerifyAndClearExpectations(consumer()); // Now consume Frame 7. This shouldn't trigger any further checkpoint @@ -834,10 +837,11 @@ TEST_F(ReceiverTest, DropsLateFrames) { EXPECT_CALL(*consumer(), OnFramesReady(_)).Times(0); EXPECT_CALL(*sender(), OnReceiverCheckpoint(_, _)).Times(0); ConsumeAndVerifyFrame(frames[7]); + AdvanceClockAndRunTasks(kOneWayNetworkDelay); testing::Mock::VerifyAndClearExpectations(consumer()); testing::Mock::VerifyAndClearExpectations(sender()); } } // namespace -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.cc b/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.cc index 30d587a24ed..696849bb501 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.cc @@ -9,11 +9,8 @@ #include "cast/streaming/packet_util.h" #include "util/saturate_cast.h" -using openscreen::saturate_cast; -using openscreen::platform::Clock; - +namespace openscreen { namespace cast { -namespace streaming { RtcpCommonHeader::RtcpCommonHeader() = default; RtcpCommonHeader::~RtcpCommonHeader() = default; @@ -30,6 +27,9 @@ void RtcpCommonHeader::AppendFields(absl::Span<uint8_t>* buffer) const { FieldBitmask<int>(kRtcpReportCountFieldNumBits)); byte0 |= with.report_count; break; + case RtcpPacketType::kSourceDescription: + OSP_UNIMPLEMENTED(); + break; case RtcpPacketType::kApplicationDefined: case RtcpPacketType::kPayloadSpecific: switch (with.subtype) { @@ -242,5 +242,5 @@ absl::optional<RtcpReportBlock> RtcpReportBlock::ParseOne( RtcpSenderReport::RtcpSenderReport() = default; RtcpSenderReport::~RtcpSenderReport() = default; -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.h b/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.h index 3b6959c0e67..25e2c2ed7e5 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.h +++ b/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.h @@ -18,8 +18,8 @@ #include "cast/streaming/rtp_time.h" #include "cast/streaming/ssrc.h" +namespace openscreen { namespace cast { -namespace streaming { struct RtcpCommonHeader { RtcpCommonHeader(); @@ -116,8 +116,7 @@ struct RtcpReportBlock { // Convenience helper to convert the given |local_clock_delay| to the // RtcpReportBlock::Delay timebase, then clamp and assign it to // |delay_since_last_report|. - void SetDelaySinceLastReport( - openscreen::platform::Clock::duration local_clock_delay); + void SetDelaySinceLastReport(Clock::duration local_clock_delay); // Serializes this report block in the first |kRtcpReportBlockSize| bytes of // the given |buffer| and adjusts |buffer| to point to the first byte after @@ -139,7 +138,7 @@ struct RtcpSenderReport { // common reference clock shared by all RTP streams; 2) the RTP timestamp on // the media capture/playout timeline. Together, these are used by a Receiver // to achieve A/V synchronization across RTP streams for playout. - openscreen::platform::Clock::time_point reference_time{}; + Clock::time_point reference_time{}; RtpTimeTicks rtp_timestamp; // The total number of RTP packets transmitted since the start of the session @@ -178,7 +177,7 @@ struct PacketNack { } }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_RTCP_COMMON_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtcp_common_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/rtcp_common_unittest.cc index 8d088c2c972..d593e4f02ee 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtcp_common_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/rtcp_common_unittest.cc @@ -11,10 +11,8 @@ #include "gtest/gtest.h" #include "platform/api/time.h" -using openscreen::platform::Clock; - +namespace openscreen { namespace cast { -namespace streaming { namespace { template <typename T> @@ -308,5 +306,5 @@ TEST(RtcpCommonTest, ComputesDelayForReportBlocks) { } } // namespace -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtcp_session.cc b/chromium/third_party/openscreen/src/cast/streaming/rtcp_session.cc index a8b8ab6a0ef..c9bf34ffbbf 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtcp_session.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/rtcp_session.cc @@ -6,12 +6,12 @@ #include "util/logging.h" +namespace openscreen { namespace cast { -namespace streaming { RtcpSession::RtcpSession(Ssrc sender_ssrc, Ssrc receiver_ssrc, - openscreen::platform::Clock::time_point start_time) + Clock::time_point start_time) : sender_ssrc_(sender_ssrc), receiver_ssrc_(receiver_ssrc), ntp_converter_(start_time) { @@ -22,5 +22,5 @@ RtcpSession::RtcpSession(Ssrc sender_ssrc, RtcpSession::~RtcpSession() = default; -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtcp_session.h b/chromium/third_party/openscreen/src/cast/streaming/rtcp_session.h index d896cb946bd..0c0b6a45364 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtcp_session.h +++ b/chromium/third_party/openscreen/src/cast/streaming/rtcp_session.h @@ -8,8 +8,8 @@ #include "cast/streaming/ntp_time.h" #include "cast/streaming/ssrc.h" +namespace openscreen { namespace cast { -namespace streaming { // Session-level configuration and shared components for the RTCP messaging // associated with a single Cast RTP stream. Multiple packet serialization and @@ -21,7 +21,7 @@ class RtcpSession { // world" wall time. RtcpSession(Ssrc sender_ssrc, Ssrc receiver_ssrc, - openscreen::platform::Clock::time_point start_time); + Clock::time_point start_time); ~RtcpSession(); Ssrc sender_ssrc() const { return sender_ssrc_; } @@ -36,7 +36,7 @@ class RtcpSession { NtpTimeConverter ntp_converter_; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_RTCP_SESSION_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_defines.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_defines.cc index 0c3591299af..215b9cfcec0 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtp_defines.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_defines.cc @@ -4,8 +4,8 @@ #include "cast/streaming/rtp_defines.h" +namespace openscreen { namespace cast { -namespace streaming { bool IsRtpPayloadType(uint8_t raw_byte) { switch (static_cast<RtpPayloadType>(raw_byte)) { @@ -31,6 +31,7 @@ bool IsRtcpPacketType(uint8_t raw_byte) { switch (static_cast<RtcpPacketType>(raw_byte)) { case RtcpPacketType::kSenderReport: case RtcpPacketType::kReceiverReport: + case RtcpPacketType::kSourceDescription: case RtcpPacketType::kApplicationDefined: case RtcpPacketType::kPayloadSpecific: case RtcpPacketType::kExtendedReports: @@ -42,5 +43,5 @@ bool IsRtcpPacketType(uint8_t raw_byte) { return false; } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_defines.h b/chromium/third_party/openscreen/src/cast/streaming/rtp_defines.h index 94294cd6394..335c06a04d1 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtp_defines.h +++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_defines.h @@ -7,8 +7,8 @@ #include <stdint.h> +namespace openscreen { namespace cast { -namespace streaming { // Note: Cast Streaming uses a subset of the messages in the RTP/RTCP // specification, but also adds some of its own extensions. See: @@ -90,6 +90,7 @@ enum class RtpPayloadType : uint8_t { kVideoVp8 = 100, kVideoH264 = 101, kVideoVarious = 102, // Codec being used is not fixed. + kVideoLast = 102, // Some AndroidTV receivers require the payload type for audio to be 127, and // video to be 96; regardless of the codecs actually being used. This is @@ -141,6 +142,7 @@ enum class RtcpPacketType : uint8_t { kSenderReport = 200, kReceiverReport = 201, + kSourceDescription = 202, kApplicationDefined = 204, kPayloadSpecific = 206, kExtendedReports = 207, @@ -338,7 +340,7 @@ constexpr int kRtcpReceiverReferenceTimeReportBlockSize = 8; // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ constexpr int kRtcpPictureLossIndicatorHeaderSize = 8; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_RTP_DEFINES_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.cc index 5fbf005abeb..b7a838e3923 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.cc @@ -10,10 +10,8 @@ #include "cast/streaming/packet_util.h" #include "util/logging.h" -using openscreen::ReadBigEndian; - +namespace openscreen { namespace cast { -namespace streaming { RtpPacketParser::RtpPacketParser(Ssrc sender_ssrc) : sender_ssrc_(sender_ssrc), highest_rtp_frame_id_(FrameId::first()) {} @@ -114,5 +112,5 @@ absl::optional<RtpPacketParser::ParseResult> RtpPacketParser::Parse( RtpPacketParser::ParseResult::ParseResult() = default; RtpPacketParser::ParseResult::~ParseResult() = default; -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.h b/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.h index b2be4c5394b..b8ce126f42d 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.h +++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.h @@ -14,8 +14,8 @@ #include "cast/streaming/rtp_time.h" #include "cast/streaming/ssrc.h" +namespace openscreen { namespace cast { -namespace streaming { // Parses RTP packets for all frames in the same Cast RTP stream. One // RtpPacketParser instance should be used for all RTP packets having the same @@ -72,7 +72,7 @@ class RtpPacketParser { FrameId highest_rtp_frame_id_; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_RTP_PACKET_PARSER_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_fuzzer.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_fuzzer.cc index 24169ce7a7a..73bde613e78 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_fuzzer.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_fuzzer.cc @@ -7,8 +7,8 @@ #include "cast/streaming/rtp_packet_parser.h" extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { - using cast::streaming::RtpPacketParser; - using cast::streaming::Ssrc; + using openscreen::cast::RtpPacketParser; + using openscreen::cast::Ssrc; constexpr Ssrc kSenderSsrcInSeedCorpus = 0x01020304; RtpPacketParser parser(kSenderSsrcInSeedCorpus); diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_unittest.cc index 927521c001c..5f9db875560 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_unittest.cc @@ -8,11 +8,8 @@ #include "gtest/gtest.h" #include "util/big_endian.h" -using openscreen::ReadBigEndian; -using openscreen::WriteBigEndian; - +namespace openscreen { namespace cast { -namespace streaming { namespace { // Tests that a simple packet for a key frame can be parsed. @@ -309,5 +306,5 @@ TEST(RtpPacketParserTest, RejectsPacketWithBadFramePacketIds) { } } // namespace -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.cc index 675a4a97cc5..610791fcb43 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.cc @@ -14,10 +14,8 @@ #include "util/integer_division.h" #include "util/logging.h" -using openscreen::platform::Clock; - +namespace openscreen { namespace cast { -namespace streaming { namespace { @@ -124,7 +122,7 @@ absl::Span<uint8_t> RtpPacketizer::GeneratePacket(const EncryptedFrame& frame, int RtpPacketizer::ComputeNumberOfPackets(const EncryptedFrame& frame) const { // The total number of packets is computed by assuming the payload will be // split-up across as few packets as possible. - int num_packets = openscreen::DividePositivesRoundingUp( + int num_packets = DividePositivesRoundingUp( static_cast<int>(frame.data.size()), max_payload_size()); // Edge case: There must always be at least one packet, even when there are no // payload bytes. Some audio codecs, for example, use zero bytes to represent @@ -135,5 +133,5 @@ int RtpPacketizer::ComputeNumberOfPackets(const EncryptedFrame& frame) const { return num_packets <= int{kMaxAllowedFramePacketId} ? num_packets : -1; } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.h b/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.h index a3811948b5f..7e4f59b77bc 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.h +++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.h @@ -12,8 +12,8 @@ #include "cast/streaming/rtp_defines.h" #include "cast/streaming/ssrc.h" +namespace openscreen { namespace cast { -namespace streaming { // Transforms a logical sequence of EncryptedFrames into RTP packets for // transmission. A single instance of RtpPacketizer should be used for all the @@ -47,6 +47,15 @@ class RtpPacketizer { // packetized. int ComputeNumberOfPackets(const EncryptedFrame& frame) const; + // See rtp_defines.h for wire-format diagram. + static constexpr int kBaseRtpHeaderSize = + // Plus one byte, because this implementation always includes the 8-bit + // Reference Frame ID field. + kRtpPacketMinValidSize + 1; + static constexpr int kAdaptiveLatencyHeaderSize = 4; + static constexpr int kMaxRtpHeaderSize = + kBaseRtpHeaderSize + kAdaptiveLatencyHeaderSize; + private: int max_payload_size() const { // Start with the configured max packet size, then subtract reserved space @@ -64,18 +73,9 @@ class RtpPacketizer { // re-transmitted, must have different sequence numbers (within wrap-around // concerns) per the RTP spec. uint16_t sequence_number_; - - // See rtp_defines.h for wire-format diagram. - static constexpr int kBaseRtpHeaderSize = - // Plus one byte, because this implementation always includes the 8-bit - // Reference Frame ID field. - kRtpPacketMinValidSize + 1; - static constexpr int kAdaptiveLatencyHeaderSize = 4; - static constexpr int kMaxRtpHeaderSize = - kBaseRtpHeaderSize + kAdaptiveLatencyHeaderSize; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_RTP_PACKETIZER_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer_unittest.cc index 52b460de2dd..d99eefa011e 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer_unittest.cc @@ -11,8 +11,8 @@ #include "cast/streaming/ssrc.h" #include "gtest/gtest.h" +namespace openscreen { namespace cast { -namespace streaming { namespace { constexpr RtpPayloadType kPayloadType = RtpPayloadType::kAudioOpus; @@ -42,7 +42,7 @@ class RtpPacketizerTest : public testing::Test { frame.frame_id = frame_id; frame.referenced_frame_id = is_key_frame ? frame_id : (frame_id - 1); frame.rtp_timestamp = RtpTimeTicks() + RtpTimeDelta::FromTicks(987); - frame.reference_time = openscreen::platform::Clock::now(); + frame.reference_time = Clock::now(); frame.new_playout_delay = new_playout_delay; std::unique_ptr<uint8_t[]> buffer(new uint8_t[payload_size]); @@ -203,5 +203,5 @@ TEST_F(RtpPacketizerTest, GeneratesPacketForRetransmission) { } } // namespace -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_time.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_time.cc index 3f73715eb57..ee6078d45aa 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtp_time.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_time.cc @@ -6,8 +6,8 @@ #include <sstream> +namespace openscreen { namespace cast { -namespace streaming { std::ostream& operator<<(std::ostream& out, const RtpTimeDelta rhs) { if (rhs.value_ >= 0) @@ -21,5 +21,5 @@ std::ostream& operator<<(std::ostream& out, const RtpTimeTicks rhs) { return out << "RTP@" << rhs.value_; } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_time.h b/chromium/third_party/openscreen/src/cast/streaming/rtp_time.h index 0a7f5d86458..536e7dd3ed1 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtp_time.h +++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_time.h @@ -11,12 +11,14 @@ #include <cmath> #include <limits> #include <sstream> +#include <type_traits> #include "cast/streaming/expanded_value_base.h" #include "platform/api/time.h" +#include "util/saturate_cast.h" +namespace openscreen { namespace cast { -namespace streaming { // Forward declarations (see below). class RtpTimeDelta; @@ -155,25 +157,15 @@ class RtpTimeDelta : public ExpandedValueBase<int64_t, RtpTimeDelta> { constexpr int64_t value() const { return value_; } template <typename Rep> - static Rep ToNearestRepresentativeValue(double ticks) { - if (ticks <= std::numeric_limits<Rep>::min()) { - return std::numeric_limits<Rep>::min(); - } else if (ticks >= std::numeric_limits<Rep>::max()) { - return std::numeric_limits<Rep>::max(); - } + static std::enable_if_t<std::is_floating_point<Rep>::value, Rep> + ToNearestRepresentativeValue(double ticks) { + return Rep(ticks); + } - static_assert( - std::is_floating_point<Rep>::value || - (std::is_integral<Rep>::value && - sizeof(Rep) <= sizeof(decltype(llround(ticks)))), - "Rep must be an integer (<= 64 bits) or a floating-point type."); - if (std::is_floating_point<Rep>::value) { - return Rep(ticks); - } - if (sizeof(Rep) <= sizeof(decltype(lround(ticks)))) { - return Rep(lround(ticks)); - } - return Rep(llround(ticks)); + template <typename Rep> + static std::enable_if_t<std::is_integral<Rep>::value, Rep> + ToNearestRepresentativeValue(double ticks) { + return rounded_saturate_cast<Rep>(ticks); } }; @@ -262,7 +254,7 @@ class RtpTimeTicks : public ExpandedValueBase<int64_t, RtpTimeTicks> { constexpr int64_t value() const { return value_; } }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_RTP_TIME_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_time_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_time_unittest.cc index 44312e7bfaa..8427a5ea201 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtp_time_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_time_unittest.cc @@ -6,8 +6,8 @@ #include "gtest/gtest.h" +namespace openscreen { namespace cast { -namespace streaming { // Tests that conversions between std::chrono durations and RtpTimeDelta are // accurate. Note that this implicitly tests the conversions to/from @@ -71,5 +71,5 @@ TEST(RtpTimeDeltaTest, ConversionToAndFromDurations) { .ToDuration<microseconds>(kTimebase)); } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender.cc b/chromium/third_party/openscreen/src/cast/streaming/sender.cc new file mode 100644 index 00000000000..fa3268a9f68 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/sender.cc @@ -0,0 +1,541 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/sender.h" + +#include <algorithm> +#include <ratio> // NOLINT + +#include "cast/streaming/session_config.h" +#include "util/logging.h" +#include "util/std_util.h" + +namespace openscreen { +namespace cast { + +using std::chrono::duration_cast; +using std::chrono::microseconds; +using std::chrono::milliseconds; + +using openscreen::operator<<; // For std::chrono::duration logging. + +Sender::Sender(Environment* environment, + SenderPacketRouter* packet_router, + const SessionConfig& config, + RtpPayloadType rtp_payload_type) + : packet_router_(packet_router), + rtcp_session_(config.sender_ssrc, + config.receiver_ssrc, + environment->now()), + rtcp_parser_(&rtcp_session_, this), + sender_report_builder_(&rtcp_session_), + rtp_packetizer_(rtp_payload_type, + config.sender_ssrc, + packet_router_->max_packet_size()), + rtp_timebase_(config.rtp_timebase), + crypto_(config.aes_secret_key, config.aes_iv_mask), + target_playout_delay_(config.target_playout_delay) { + OSP_DCHECK(packet_router_); + OSP_DCHECK_NE(rtcp_session_.sender_ssrc(), rtcp_session_.receiver_ssrc()); + OSP_DCHECK_GT(rtp_timebase_, 0); + OSP_DCHECK(target_playout_delay_ > milliseconds::zero()); + + pending_sender_report_.reference_time = SenderPacketRouter::kNever; + + packet_router_->OnSenderCreated(rtcp_session_.receiver_ssrc(), this); +} + +Sender::~Sender() { + packet_router_->OnSenderDestroyed(rtcp_session_.receiver_ssrc()); +} + +void Sender::SetObserver(Sender::Observer* observer) { + observer_ = observer; +} + +int Sender::GetInFlightFrameCount() const { + return num_frames_in_flight_; +} + +Clock::duration Sender::GetInFlightMediaDuration( + RtpTimeTicks next_frame_rtp_timestamp) const { + if (num_frames_in_flight_ == 0) { + return Clock::duration::zero(); // No frames are currently in-flight. + } + + const PendingFrameSlot& oldest_slot = *get_slot_for(checkpoint_frame_id_ + 1); + // Note: The oldest slot's frame cannot have been canceled because the + // protocol does not allow ACK'ing this particular frame without also moving + // the checkpoint forward. See "CST2 feedback" discussion in rtp_defines.h. + OSP_DCHECK(oldest_slot.is_active_for_frame(checkpoint_frame_id_ + 1)); + + return (next_frame_rtp_timestamp - oldest_slot.frame->rtp_timestamp) + .ToDuration<Clock::duration>(rtp_timebase_); +} + +Clock::duration Sender::GetMaxInFlightMediaDuration() const { + // Assumption: The total amount of allowed in-flight media should equal the + // half of the playout delay window, plus the amount of time it takes to + // receive an ACK from the Receiver. + // + // Why half of the playout delay window? It's assumed here that capture and + // media encoding, which occur before EnqueueFrame() is called, are executing + // within the first half of the playout delay window. This leaves the second + // half for executing all network transmits/re-transmits, plus decoding and + // play-out at the Receiver. + return (target_playout_delay_ / 2) + (round_trip_time_ / 2); +} + +bool Sender::NeedsKeyFrame() const { + return last_enqueued_key_frame_id_ <= picture_lost_at_frame_id_; +} + +FrameId Sender::GetNextFrameId() const { + return last_enqueued_frame_id_ + 1; +} + +Sender::EnqueueFrameResult Sender::EnqueueFrame(const EncodedFrame& frame) { + // Assume the fields of the |frame| have all been set correctly, with + // monotonically increasing timestamps and a valid pointer to the data. + OSP_DCHECK_EQ(frame.frame_id, GetNextFrameId()); + OSP_DCHECK_GE(frame.referenced_frame_id, FrameId::first()); + if (frame.frame_id != FrameId::first()) { + OSP_DCHECK_GT(frame.rtp_timestamp, pending_sender_report_.rtp_timestamp); + OSP_DCHECK_GT(frame.reference_time, pending_sender_report_.reference_time); + } + OSP_DCHECK(frame.data.data()); + + // Check whether enqueuing the frame would exceed the design limit for the + // span of FrameIds. Even if |num_frames_in_flight_| is less than + // kMaxUnackedFrames, it's the span of FrameIds that is restricted. + if ((frame.frame_id - checkpoint_frame_id_) > kMaxUnackedFrames) { + return REACHED_ID_SPAN_LIMIT; + } + + // Check whether enqueuing the frame would exceed the current maximum media + // duration limit. + if (GetInFlightMediaDuration(frame.rtp_timestamp) > + GetMaxInFlightMediaDuration()) { + return MAX_DURATION_IN_FLIGHT; + } + + // Encrypt the frame and initialize the slot tracking its sending. + PendingFrameSlot* const slot = get_slot_for(frame.frame_id); + OSP_DCHECK(!slot->frame); + slot->frame = crypto_.Encrypt(frame); + const int packet_count = rtp_packetizer_.ComputeNumberOfPackets(*slot->frame); + if (packet_count <= 0) { + slot->frame.reset(); + return PAYLOAD_TOO_LARGE; + } + slot->send_flags.Resize(packet_count, YetAnotherBitVector::SET); + slot->packet_sent_times.assign(packet_count, SenderPacketRouter::kNever); + + // Officially record the "enqueue." + ++num_frames_in_flight_; + last_enqueued_frame_id_ = slot->frame->frame_id; + OSP_DCHECK_LE(num_frames_in_flight_, + last_enqueued_frame_id_ - checkpoint_frame_id_); + if (slot->frame->dependency == EncodedFrame::KEY_FRAME) { + last_enqueued_key_frame_id_ = slot->frame->frame_id; + } + + // Update the target playout delay, if necessary. + if (slot->frame->new_playout_delay > milliseconds::zero()) { + target_playout_delay_ = slot->frame->new_playout_delay; + playout_delay_change_at_frame_id_ = slot->frame->frame_id; + } + + // Update the lip-sync information for the next Sender Report. + pending_sender_report_.reference_time = slot->frame->reference_time; + pending_sender_report_.rtp_timestamp = slot->frame->rtp_timestamp; + + // If the round trip time hasn't been computed yet, immediately send a RTCP + // packet (i.e., before the RTP packets are sent). The RTCP packet will + // provide a Sender Report which contains the required lip-sync information + // the Receiver needs for timing the media playout. + // + // Detail: Working backwards, if the round trip time is not known, then this + // Sender has never processed a Receiver Report. Thus, the Receiver has never + // provided a Receiver Report, which it can only do after having processed a + // Sender Report from this Sender. Thus, this Sender really needs to send + // that, right now! + if (round_trip_time_ == Clock::duration::zero()) { + packet_router_->RequestRtcpSend(rtcp_session_.receiver_ssrc()); + } + + // Re-activate RTP sending if it was suspended. + packet_router_->RequestRtpSend(rtcp_session_.receiver_ssrc()); + + return OK; +} + +void Sender::OnReceivedRtcpPacket(Clock::time_point arrival_time, + absl::Span<const uint8_t> packet) { + rtcp_packet_arrival_time_ = arrival_time; + // This call to Parse() invoke zero or more of the OnReceiverXYZ() methods in + // the current call stack: + if (rtcp_parser_.Parse(packet, last_enqueued_frame_id_)) { + packet_router_->OnRtcpReceived(arrival_time, round_trip_time_); + } +} + +absl::Span<uint8_t> Sender::GetRtcpPacketForImmediateSend( + Clock::time_point send_time, + absl::Span<uint8_t> buffer) { + if (pending_sender_report_.reference_time == SenderPacketRouter::kNever) { + // Cannot send a report if one is not available (i.e., a frame has never + // been enqueued). + return buffer.subspan(0, 0); + } + + // The Sender Report to be sent is a snapshot of the "pending Sender Report," + // but with its timestamp fields modified. First, the reference time is set to + // the RTCP packet's send time. Then, the corresponding RTP timestamp is + // translated to match (for lip-sync). + RtcpSenderReport sender_report = pending_sender_report_; + sender_report.reference_time = send_time; + sender_report.rtp_timestamp += RtpTimeDelta::FromDuration( + sender_report.reference_time - pending_sender_report_.reference_time, + rtp_timebase_); + + return sender_report_builder_.BuildPacket(sender_report, buffer).first; +} + +absl::Span<uint8_t> Sender::GetRtpPacketForImmediateSend( + Clock::time_point send_time, + absl::Span<uint8_t> buffer) { + ChosenPacket chosen = ChooseNextRtpPacketNeedingSend(); + + // If no packets need sending (i.e., all packets have been sent at least once + // and do not need to be re-sent yet), check whether a Kickstart packet should + // be sent. It's possible that there has been complete packet loss of some + // frames, and the Receiver may not be aware of the existence of the latest + // frame(s). Kickstarting is the only way the Receiver can discover the newer + // frames it doesn't know about. + if (!chosen) { + const ChosenPacketAndWhen kickstart = ChooseKickstartPacket(); + if (kickstart.when > send_time) { + // Nothing to send, so return "empty" signal to the packet router. The + // packet router will suspend RTP sending until this Sender explicitly + // resumes it. + return buffer.subspan(0, 0); + } + chosen = kickstart; + OSP_DCHECK(chosen); + } + + const absl::Span<uint8_t> result = rtp_packetizer_.GeneratePacket( + *chosen.slot->frame, chosen.packet_id, buffer); + chosen.slot->send_flags.Clear(chosen.packet_id); + chosen.slot->packet_sent_times[chosen.packet_id] = send_time; + + ++pending_sender_report_.send_packet_count; + // According to RFC3550, the octet count does not include the RTP header. The + // following is just a good approximation, however, because the header size + // will very infrequently be 4 bytes greater (see + // RtpPacketizer::kAdaptiveLatencyHeaderSize). No known Cast Streaming + // Receiver implementations use this for anything, and so this should be fine. + const int approximate_octet_count = + static_cast<int>(result.size()) - RtpPacketizer::kBaseRtpHeaderSize; + OSP_DCHECK_GE(approximate_octet_count, 0); + pending_sender_report_.send_octet_count += approximate_octet_count; + + return result; +} + +Clock::time_point Sender::GetRtpResumeTime() { + if (ChooseNextRtpPacketNeedingSend()) { + return Alarm::kImmediately; + } + return ChooseKickstartPacket().when; +} + +void Sender::OnReceiverReferenceTimeAdvanced(Clock::time_point reference_time) { + // Not used. +} + +void Sender::OnReceiverReport(const RtcpReportBlock& receiver_report) { + OSP_DCHECK_NE(rtcp_packet_arrival_time_, SenderPacketRouter::kNever); + + const Clock::duration total_delay = + rtcp_packet_arrival_time_ - + sender_report_builder_.GetRecentReportTime( + receiver_report.last_status_report_id, rtcp_packet_arrival_time_); + const auto non_network_delay = + duration_cast<Clock::duration>(receiver_report.delay_since_last_report); + + // Round trip time measurement: This is the time elapsed since the Sender + // Report was sent, minus the time the Receiver did other stuff before sending + // the Receiver Report back. + // + // If the round trip time seems to be less than or equal to zero, assume clock + // imprecision by one or both peers caused a bad value to be calculated. The + // true value is likely very close to zero (i.e., this is ideal network + // behavior); and so just represent this as 75 µs, an optimistic + // wired-Ethernet LAN ping time. + constexpr auto kNearZeroRoundTripTime = + duration_cast<Clock::duration>(microseconds(75)); + static_assert(kNearZeroRoundTripTime > Clock::duration::zero(), + "More precision in Clock::duration needed!"); + const Clock::duration measurement = + std::max(total_delay - non_network_delay, kNearZeroRoundTripTime); + + // Validate the measurement by using the current target playout delay as a + // "reasonable upper-bound." It's certainly possible that the actual network + // round-trip time could exceed the target playout delay, but that would mean + // the current network performance is totally inadequate for streaming anyway. + if (measurement > target_playout_delay_) { + OSP_LOG_WARN << "Invalidating a round-trip time measurement (" + << measurement + << ") since it exceeds the current target playout delay (" + << target_playout_delay_ << ")."; + return; + } + + // Measurements will typically have high variance. Use a simple smoothing + // filter to track a short-term average that changes less drastically. + if (round_trip_time_ == Clock::duration::zero()) { + round_trip_time_ = measurement; + } else { + // Arbitrary constant, to provide 1/8 weight to the new measurement, and 7/8 + // weight to the old estimate, which seems to work well for de-noising the + // estimate. + constexpr int kInertia = 7; + round_trip_time_ = + (kInertia * round_trip_time_ + measurement) / (kInertia + 1); + } + // TODO(miu): Add tracing event here to note the updated RTT. +} + +void Sender::OnReceiverIndicatesPictureLoss() { + // The Receiver will continue the PLI notifications until it has received a + // key frame. Thus, if a key frame is already in-flight, don't make a state + // change that would cause this Sender to force another expensive key frame. + if (checkpoint_frame_id_ < last_enqueued_key_frame_id_) { + return; + } + + picture_lost_at_frame_id_ = checkpoint_frame_id_; + + if (observer_) { + observer_->OnPictureLost(); + } + + // Note: It may seem that all pending frames should be canceled until + // EnqueueFrame() is called with a key frame. However: + // + // 1. The Receiver should still be the main authority on what frames/packets + // are being ACK'ed and NACK'ed. + // + // 2. It may be desirable for the Receiver to be "limping along" in the + // meantime. For example, video may be corrupted but mostly watchable, + // and so it's best for the Sender to continue sending the non-key frames + // until the Receiver indicates otherwise. +} + +void Sender::OnReceiverCheckpoint(FrameId frame_id, + milliseconds playout_delay) { + if (frame_id > last_enqueued_frame_id_) { + OSP_LOG_ERROR + << "Ignoring checkpoint for " << latest_expected_frame_id_ + << " because this Sender could not have sent any frames after " + << last_enqueued_frame_id_ << '.'; + return; + } + // CompoundRtcpParser should guarantee this: + OSP_DCHECK(playout_delay >= milliseconds::zero()); + + while (checkpoint_frame_id_ < frame_id) { + ++checkpoint_frame_id_; + CancelPendingFrame(checkpoint_frame_id_); + } + latest_expected_frame_id_ = std::max(latest_expected_frame_id_, frame_id); + + if (playout_delay != target_playout_delay_ && + frame_id >= playout_delay_change_at_frame_id_) { + OSP_LOG_WARN << "Sender's target playout delay (" << target_playout_delay_ + << ") disagrees with the Receiver's (" << playout_delay << ")"; + } +} + +void Sender::OnReceiverHasFrames(std::vector<FrameId> acks) { + OSP_DCHECK(!acks.empty() && AreElementsSortedAndUnique(acks)); + + if (acks.back() > last_enqueued_frame_id_) { + OSP_LOG_ERROR << "Ignoring individual frame ACKs: ACKing frame " + << latest_expected_frame_id_ + << " is invalid because this Sender could not have sent any " + "frames after " + << last_enqueued_frame_id_ << '.'; + return; + } + + for (FrameId id : acks) { + CancelPendingFrame(id); + } + latest_expected_frame_id_ = std::max(latest_expected_frame_id_, acks.back()); +} + +void Sender::OnReceiverIsMissingPackets(std::vector<PacketNack> nacks) { + OSP_DCHECK(!nacks.empty() && AreElementsSortedAndUnique(nacks)); + OSP_DCHECK_NE(rtcp_packet_arrival_time_, SenderPacketRouter::kNever); + + // This is a point-in-time threshold that indicates whether each NACK will + // trigger a packet retransmit. The threshold is based on the network round + // trip time because a Receiver's NACK may have been issued while the needed + // packet was in-flight from the Sender. In such cases, the Receiver's NACK is + // likely stale and this Sender should not redundantly re-transmit the packet + // again. + const Clock::time_point too_recent_a_send_time = + rtcp_packet_arrival_time_ - round_trip_time_; + + // Iterate over all the NACKs... + bool need_to_send = false; + for (auto nack_it = nacks.begin(); nack_it != nacks.end();) { + // Find the slot associated with the NACK's frame ID. + const FrameId frame_id = nack_it->frame_id; + PendingFrameSlot* slot = nullptr; + if (frame_id <= last_enqueued_frame_id_) { + PendingFrameSlot* const candidate_slot = get_slot_for(frame_id); + if (candidate_slot->is_active_for_frame(frame_id)) { + slot = candidate_slot; + } + } + + // If no slot was found (i.e., the NACK is invalid) for the frame, skip-over + // all other NACKs for the same frame. While it seems to be a bug that the + // Receiver would attempt to NACK a frame that does not yet exist, this can + // happen in rare cases where RTCP packets arrive out-of-order (i.e., the + // network shuffled them). + if (!slot) { + // TODO(miu): Add tracing event here to record this. + for (++nack_it; nack_it != nacks.end() && nack_it->frame_id == frame_id; + ++nack_it) + ; + continue; + } + + latest_expected_frame_id_ = std::max(latest_expected_frame_id_, frame_id); + + const auto HandleIndividualNack = [&](FramePacketId packet_id) { + if (slot->packet_sent_times[packet_id] <= too_recent_a_send_time) { + slot->send_flags.Set(packet_id); + need_to_send = true; + } + }; + const FramePacketId range_end = slot->packet_sent_times.size(); + if (nack_it->packet_id == kAllPacketsLost) { + for (FramePacketId packet_id = 0; packet_id < range_end; ++packet_id) { + HandleIndividualNack(packet_id); + } + ++nack_it; + } else { + do { + if (nack_it->packet_id < range_end) { + HandleIndividualNack(nack_it->packet_id); + } else { + OSP_LOG_WARN + << "Ignoring NACK for packet that doesn't exist in frame " + << frame_id << ": " << static_cast<int>(nack_it->packet_id); + } + ++nack_it; + } while (nack_it != nacks.end() && nack_it->frame_id == frame_id); + } + } + + if (need_to_send) { + packet_router_->RequestRtpSend(rtcp_session_.receiver_ssrc()); + } +} + +Sender::ChosenPacket Sender::ChooseNextRtpPacketNeedingSend() { + // Find the oldest packet needing to be sent (or re-sent). + for (FrameId frame_id = checkpoint_frame_id_ + 1; + frame_id <= last_enqueued_frame_id_; ++frame_id) { + PendingFrameSlot* const slot = get_slot_for(frame_id); + if (!slot->is_active_for_frame(frame_id)) { + continue; // Frame was canceled. None of its packets need to be sent. + } + const FramePacketId packet_id = slot->send_flags.FindFirstSet(); + if (packet_id < slot->send_flags.size()) { + return {slot, packet_id}; + } + } + + return {}; // Nothing needs to be sent. +} + +Sender::ChosenPacketAndWhen Sender::ChooseKickstartPacket() { + if (latest_expected_frame_id_ >= last_enqueued_frame_id_) { + // Since the Receiver must know about all of the frames currently queued, no + // Kickstart packet is necessary. + return {}; + } + + // The Kickstart packet is always in the last-enqueued frame, so that the + // Receiver will know about every frame the Sender has. However, which packet + // should be chosen? Any would do, since all packets contain the frame's total + // packet count. For historical reasons, all sender implementations have + // always just sent the last packet; and so that tradition is continued here. + ChosenPacketAndWhen chosen; + chosen.slot = get_slot_for(last_enqueued_frame_id_); + // Note: This frame cannot have been canceled since + // |latest_expected_frame_id_| hasn't yet reached this point. + OSP_DCHECK(chosen.slot->is_active_for_frame(last_enqueued_frame_id_)); + chosen.packet_id = chosen.slot->send_flags.size() - 1; + + const Clock::time_point time_last_sent = + chosen.slot->packet_sent_times[chosen.packet_id]; + // Sanity-check: This method should not be called to choose a packet while + // there are still unsent packets. + OSP_DCHECK_NE(time_last_sent, SenderPacketRouter::kNever); + + // The desired Kickstart interval is a fraction of the total + // |target_playout_delay_|. The reason for the specific ratio here is based on + // lost knowledge (from legacy implementations); but it makes sense (i.e., to + // be a good "network citizen") to be less aggressive for larger playout delay + // windows, and more aggressive for shorter ones to avoid too-late packet + // arrivals. + using kWaitFraction = std::ratio<1, 20>; + const Clock::duration desired_kickstart_interval = + duration_cast<Clock::duration>(target_playout_delay_) * + kWaitFraction::num / kWaitFraction::den; + // The actual interval used is increased, if current network performance + // warrants waiting longer. Don't send a Kickstart packet until no NACKs + // have been received for two network round-trip periods. + constexpr int kLowerBoundRoundTrips = 2; + const Clock::duration kickstart_interval = std::max( + desired_kickstart_interval, round_trip_time_ * kLowerBoundRoundTrips); + chosen.when = time_last_sent + kickstart_interval; + + return chosen; +} + +void Sender::CancelPendingFrame(FrameId frame_id) { + PendingFrameSlot* const slot = get_slot_for(frame_id); + if (!slot->is_active_for_frame(frame_id)) { + return; // Frame was already canceled. + } + + packet_router_->OnPayloadReceived( + slot->frame->data.size(), rtcp_packet_arrival_time_, round_trip_time_); + + slot->frame.reset(); + OSP_DCHECK_GT(num_frames_in_flight_, 0); + --num_frames_in_flight_; + if (observer_) { + observer_->OnFrameCanceled(frame_id); + } +} + +void Sender::Observer::OnFrameCanceled(FrameId frame_id) {} +void Sender::Observer::OnPictureLost() {} +Sender::Observer::~Observer() = default; + +Sender::PendingFrameSlot::PendingFrameSlot() = default; +Sender::PendingFrameSlot::~PendingFrameSlot() = default; + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender.h b/chromium/third_party/openscreen/src/cast/streaming/sender.h new file mode 100644 index 00000000000..48f651c04de --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/sender.h @@ -0,0 +1,318 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STREAMING_SENDER_H_ +#define CAST_STREAMING_SENDER_H_ + +#include <stdint.h> + +#include <array> +#include <chrono> // NOLINT +#include <vector> + +#include "absl/types/span.h" +#include "cast/streaming/compound_rtcp_parser.h" +#include "cast/streaming/constants.h" +#include "cast/streaming/frame_crypto.h" +#include "cast/streaming/frame_id.h" +#include "cast/streaming/rtp_defines.h" +#include "cast/streaming/rtp_packetizer.h" +#include "cast/streaming/rtp_time.h" +#include "cast/streaming/sender_packet_router.h" +#include "cast/streaming/sender_report_builder.h" +#include "platform/api/time.h" +#include "util/yet_another_bit_vector.h" + +namespace openscreen { +namespace cast { + +class Environment; +struct SessionConfig; + +// The Cast Streaming Sender, a peer corresponding to some Cast Streaming +// Receiver at the other end of a network link. See class level comments for +// Receiver for a high-level overview. +// +// The Sender is the peer responsible for enqueuing EncodedFrames for streaming, +// guaranteeing their delivery to a Receiver, and handling feedback events from +// a Receiver. Some feedback events are used for managing the Sender's internal +// queue of in-flight frames, requesting network packet re-transmits, etc.; +// while others are exposed via the Sender's public interface. For example, +// sometimes the Receiver signals that it needs a a key frame to resolve a +// picture loss condition, and the modules upstream of the Sender (e.g., where +// encoding happens) should call NeedsKeyFrame() to check for, and handle that. +// +// There are usually one or two Senders in a streaming session, one for audio +// and one for video. Both senders work with the same SenderPacketRouter +// instance to schedule their transmission of packets, and provide the necessary +// metrics for estimating bandwidth utilization and availability. +// +// It is the responsibility of upstream code modules to handle congestion +// control. With respect to this Sender, that means the media encoding bit rate +// should be throttled based on network bandwidth availability. This Sender does +// not do any throttling, only flow-control. In other words, this Sender can +// only manage its in-flight queue of frames, and if that queue grows too large, +// it will eventually reject further enqueuing. +// +// General usage: A client should check the in-flight media duration frequently +// to decide when to pause encoding, to avoid wasting system resources on +// encoding frames that will likely be rejected by the Sender. The client should +// also frequently call NeedsKeyFrame() and, when this returns true, direct its +// encoder to produce a key frame soon. Finally, when using EnqueueFrame(), an +// EncodedFrame struct should be prepared with its frame_id field set to +// whatever GetNextFrameId() returns. Please see method comments for +// more-detailed usage info. +class Sender final : public SenderPacketRouter::Sender, + public CompoundRtcpParser::Client { + public: + // Interface for receiving notifications about events of possible interest. + // Handling each of these is optional, but some may be mandatory for certain + // applications (see method comments below). + class Observer { + public: + // Called when a frame was canceled. "Canceled" means that the Receiver has + // either acknowledged successful receipt of the frame or has decided to + // skip over it. Note: Frame cancellations may occur out-of-order. + virtual void OnFrameCanceled(FrameId frame_id); + + // Called when a Receiver begins reporting picture loss, and there is no key + // frame currently enqueued in the Sender. The application should enqueue a + // key frame as soon as possible. Note: An application that pauses frame + // sending (e.g., screen mirroring when the screen is not changing) should + // use this notification to send an out-of-band "refresh frame," encoded as + // a key frame. + virtual void OnPictureLost(); + + protected: + virtual ~Observer(); + }; + + // Result codes for EnqueueFrame(). + enum EnqueueFrameResult { + // The frame has been queued for sending. + OK, + + // The frame's payload was too large. This is typically triggered when + // submitting a payload of several dozen megabytes or more. This result code + // likely indicates some kind of upstream bug. + PAYLOAD_TOO_LARGE, + + // The span of FrameIds is too large. Cast Streaming's protocol design + // imposes a limit in the maximum difference between the highest-valued + // in-flight FrameId and the least-valued one. + REACHED_ID_SPAN_LIMIT, + + // Too-large a media duration is in-flight. Enqueuing another frame would + // automatically cause late play-out at the Receiver. + MAX_DURATION_IN_FLIGHT, + }; + + // Constructs a Sender that attaches to the given |environment|-provided + // resources and |packet_router|. The |config| contains the settings that were + // agreed-upon by both sides from the OFFER/ANSWER exchange (i.e., the part of + // the overall end-to-end connection process that occurs before Cast Streaming + // is started). The |rtp_payload_type| does not affect the behavior of this + // Sender. It is simply passed along to a Receiver in the RTP packet stream. + Sender(Environment* environment, + SenderPacketRouter* packet_router, + const SessionConfig& config, + RtpPayloadType rtp_payload_type); + + ~Sender() final; + + Ssrc ssrc() const { return rtcp_session_.sender_ssrc(); } + int rtp_timebase() const { return rtp_timebase_; } + + // Sets an observer for receiving notifications. Call with nullptr to stop + // observing. + void SetObserver(Observer* observer); + + // Returns the number of frames currently in-flight. This is only meant to be + // informative. Clients should use GetInFlightMediaDuration() to make + // throttling decisions. + int GetInFlightFrameCount() const; + + // Returns the total media duration of the frames currently in-flight, + // assuming the next not-yet-enqueued frame will have the given RTP timestamp. + // For a better user experience, the result should be compared to + // GetMaxInFlightMediaDuration(), and media encoding should be throttled down + // before additional EnqueueFrame() calls would cause this to reach the + // current maximum limit. + Clock::duration GetInFlightMediaDuration( + RtpTimeTicks next_frame_rtp_timestamp) const; + + // Return the maximum acceptable in-flight media duration, given the current + // target playout delay setting and end-to-end network/system conditions. + Clock::duration GetMaxInFlightMediaDuration() const; + + // Returns true if the Receiver requires a key frame. Note that this will + // return true until a key frame is accepted by EnqueueFrame(). Thus, when + // encoding is pipelined, care should be taken to instruct the encoder to + // produce just ONE forced key frame. + bool NeedsKeyFrame() const; + + // Returns the next FrameId, the one after the frame enqueued by the last call + // to EnqueueFrame(). Note that the next call to EnqueueFrame() assumes this + // frame ID be used. + FrameId GetNextFrameId() const; + + // Enqueues the given |frame| for sending as soon as possible. Returns OK if + // the frame is accepted, and some time later Observer::OnFrameCanceled() will + // be called once it is no longer in-flight. + // + // All fields of the |frame| must be set to valid values: the |frame_id| must + // be the same as GetNextFrameId(); both the |rtp_timestamp| and + // |reference_time| fields must be monotonically increasing relative to the + // prior frame; and the frame's |data| pointer must be set. + [[nodiscard]] EnqueueFrameResult EnqueueFrame(const EncodedFrame& frame); + + private: + // Tracking/Storage for frames that are ready-to-send, and until they are + // fully received at the other end. + struct PendingFrameSlot { + // The frame to send, or nullopt if this slot is not in use. + absl::optional<EncryptedFrame> frame; + + // Represents which packets need to be sent. Elements are indexed by + // FramePacketId. A set bit means a packet needs to be sent (or re-sent). + YetAnotherBitVector send_flags; + + // The time when each of the packets was last sent, or + // |SenderPacketRouter::kNever| if the packet has not been sent yet. + // Elements are indexed by FramePacketId. This is used to avoid + // re-transmitting any given packet too frequently. + std::vector<Clock::time_point> packet_sent_times; + + PendingFrameSlot(); + ~PendingFrameSlot(); + + bool is_active_for_frame(FrameId frame_id) const { + return frame && frame->frame_id == frame_id; + } + }; + + // Return value from the ChooseXYZ() helper methods. + struct ChosenPacket { + PendingFrameSlot* slot = nullptr; + FramePacketId packet_id{}; + + explicit operator bool() const { return !!slot; } + }; + + // An extension of ChosenPacket that also includes the point-in-time when the + // packet should be sent. + struct ChosenPacketAndWhen : public ChosenPacket { + Clock::time_point when = SenderPacketRouter::kNever; + }; + + // SenderPacketRouter::Sender implementation. + void OnReceivedRtcpPacket(Clock::time_point arrival_time, + absl::Span<const uint8_t> packet) final; + absl::Span<uint8_t> GetRtcpPacketForImmediateSend( + Clock::time_point send_time, + absl::Span<uint8_t> buffer) final; + absl::Span<uint8_t> GetRtpPacketForImmediateSend( + Clock::time_point send_time, + absl::Span<uint8_t> buffer) final; + Clock::time_point GetRtpResumeTime() final; + + // CompoundRtcpParser::Client implementation. + void OnReceiverReferenceTimeAdvanced(Clock::time_point reference_time) final; + void OnReceiverReport(const RtcpReportBlock& receiver_report) final; + void OnReceiverIndicatesPictureLoss() final; + void OnReceiverCheckpoint(FrameId frame_id, + std::chrono::milliseconds playout_delay) final; + void OnReceiverHasFrames(std::vector<FrameId> acks) final; + void OnReceiverIsMissingPackets(std::vector<PacketNack> nacks) final; + + // Helper to choose which packet to send, from those that have been flagged as + // "need to send." Returns a "false" result if nothing needs to be sent. + ChosenPacket ChooseNextRtpPacketNeedingSend(); + + // Helper that returns the packet that should be used to kick-start the + // Receiver, and the time at which the packet should be sent. Returns a kNever + // result if kick-starting is not needed. + ChosenPacketAndWhen ChooseKickstartPacket(); + + // Cancels the given frame once it is known to have been fully received (i.e., + // based on the ACK feedback from the Receiver in a RTCP packet). This clears + // the corresponding entry in |pending_frames_| and notifies the Observer. + void CancelPendingFrame(FrameId frame_id); + + // Inline helper to return the slot that would contain the tracking info for + // the given |frame_id|. + const PendingFrameSlot* get_slot_for(FrameId frame_id) const { + return &pending_frames_[(frame_id - FrameId::first()) % + pending_frames_.size()]; + } + PendingFrameSlot* get_slot_for(FrameId frame_id) { + return &pending_frames_[(frame_id - FrameId::first()) % + pending_frames_.size()]; + } + + SenderPacketRouter* const packet_router_; + RtcpSession rtcp_session_; + CompoundRtcpParser rtcp_parser_; + SenderReportBuilder sender_report_builder_; + RtpPacketizer rtp_packetizer_; + const int rtp_timebase_; + FrameCrypto crypto_; + + // Ring buffer of PendingFrameSlots. The frame having FrameId x will always + // be slotted at position x % pending_frames_.size(). Use get_slot_for() to + // access the correct slot for a given FrameId. + std::array<PendingFrameSlot, kMaxUnackedFrames> pending_frames_{}; + + // A count of the number of frames in-flight (i.e., the number of active + // entries in |pending_frames_|). + int num_frames_in_flight_ = 0; + + // The ID of the last frame enqueued. + FrameId last_enqueued_frame_id_ = FrameId::leader(); + + // Indicates that all of the packets for all frames up to and including this + // FrameId have been successfully received (or otherwise do not need to be + // re-transmitted). + FrameId checkpoint_frame_id_ = FrameId::leader(); + + // The ID of the latest frame the Receiver seems to be aware of. + FrameId latest_expected_frame_id_ = FrameId::leader(); + + // The target playout delay for the last-enqueued frame. This is auto-updated + // when a frame is enqueued that changes the delay. + std::chrono::milliseconds target_playout_delay_; + FrameId playout_delay_change_at_frame_id_ = FrameId::first(); + + // The exact arrival time of the last RTCP packet. + Clock::time_point rtcp_packet_arrival_time_ = SenderPacketRouter::kNever; + + // The near-term average round trip time. This is updated with each Sender + // Report → Receiver Report round trip. This is initially zero, indicating the + // round trip time has not been measured yet. + Clock::duration round_trip_time_{0}; + + // Maintain current stats in a Sender Report that is ready for sending at any + // time. This includes up-to-date lip-sync information, and packet and byte + // count stats. + RtcpSenderReport pending_sender_report_; + + // These are used to determine whether a key frame needs to be sent to the + // Receiver. When the Receiver provides a picture loss notification, the + // current checkpoint frame ID is stored in |picture_lost_at_frame_id_|. Then, + // while |last_enqueued_key_frame_id_| is less than or equal to + // |picture_lost_at_frame_id_|, the Sender knows it still needs to send a key + // frame to resolve the picture loss condition. In all other cases, the + // Receiver is either in a good state or is in the process of receiving the + // key frame that will make that happen. + FrameId picture_lost_at_frame_id_ = FrameId::leader(); + FrameId last_enqueued_key_frame_id_ = FrameId::leader(); + + // The current observer (optional). + Observer* observer_ = nullptr; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STREAMING_SENDER_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.cc b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.cc new file mode 100644 index 00000000000..870b085a861 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.cc @@ -0,0 +1,274 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/sender_packet_router.h" + +#include <algorithm> +#include <utility> + +#include "cast/streaming/constants.h" +#include "cast/streaming/packet_util.h" +#include "util/logging.h" +#include "util/saturate_cast.h" +#include "util/stringprintf.h" + +namespace openscreen { +namespace cast { + +using std::chrono::duration_cast; +using std::chrono::milliseconds; +using std::chrono::seconds; + +SenderPacketRouter::SenderPacketRouter(Environment* environment, + int max_burst_bitrate) + : SenderPacketRouter( + environment, + ComputeMaxPacketsPerBurst(max_burst_bitrate, + environment->GetMaxPacketSize(), + kDefaultBurstInterval), + kDefaultBurstInterval) {} + +SenderPacketRouter::SenderPacketRouter(Environment* environment, + int max_packets_per_burst, + milliseconds burst_interval) + : BandwidthEstimator(max_packets_per_burst, + burst_interval, + environment->now()), + environment_(environment), + packet_buffer_size_(environment->GetMaxPacketSize()), + packet_buffer_(new uint8_t[packet_buffer_size_]), + max_packets_per_burst_(max_packets_per_burst), + burst_interval_(burst_interval), + max_burst_bitrate_(ComputeMaxBurstBitrate(packet_buffer_size_, + max_packets_per_burst_, + burst_interval_)), + alarm_(environment_->now_function(), environment_->task_runner()) { + OSP_DCHECK(environment_); + OSP_DCHECK_GT(packet_buffer_size_, kRequiredNetworkPacketSize); +} + +SenderPacketRouter::~SenderPacketRouter() { + OSP_DCHECK(senders_.empty()); +} + +void SenderPacketRouter::OnSenderCreated(Ssrc receiver_ssrc, Sender* sender) { + OSP_DCHECK(FindEntry(receiver_ssrc) == senders_.end()); + senders_.push_back(SenderEntry{receiver_ssrc, sender, kNever, kNever}); + + if (senders_.size() == 1) { + environment_->ConsumeIncomingPackets(this); + } else { + // Sort the list of Senders so that they are iterated in priority order. + std::sort(senders_.begin(), senders_.end()); + } +} + +void SenderPacketRouter::OnSenderDestroyed(Ssrc receiver_ssrc) { + const auto it = FindEntry(receiver_ssrc); + OSP_DCHECK(it != senders_.end()); + senders_.erase(it); + + // If there are no longer any Senders, suspend receiving RTCP packets. + if (senders_.empty()) { + environment_->DropIncomingPackets(); + } +} + +void SenderPacketRouter::RequestRtcpSend(Ssrc receiver_ssrc) { + const auto it = FindEntry(receiver_ssrc); + OSP_DCHECK(it != senders_.end()); + it->next_rtcp_send_time = Alarm::kImmediately; + ScheduleNextBurst(); +} + +void SenderPacketRouter::RequestRtpSend(Ssrc receiver_ssrc) { + const auto it = FindEntry(receiver_ssrc); + OSP_DCHECK(it != senders_.end()); + it->next_rtp_send_time = Alarm::kImmediately; + ScheduleNextBurst(); +} + +void SenderPacketRouter::OnReceivedPacket(const IPEndpoint& source, + Clock::time_point arrival_time, + std::vector<uint8_t> packet) { + // If the packet did not come from the expected endpoint, ignore it. + OSP_DCHECK_NE(source.port, uint16_t{0}); + if (source != environment_->remote_endpoint()) { + return; + } + + // Determine which Sender to dispatch the packet to. Senders may only receive + // RTCP packets from Receivers. Log a warning containing a pretty-printed dump + // if the packet is not an RTCP packet. + const std::pair<ApparentPacketType, Ssrc> seems_like = + InspectPacketForRouting(packet); + if (seems_like.first != ApparentPacketType::RTCP) { + constexpr int kMaxPartiaHexDumpSize = 96; + OSP_LOG_WARN << "UNKNOWN packet of " << packet.size() + << " bytes. Partial hex dump: " + << HexEncode(absl::Span<const uint8_t>(packet).subspan( + 0, kMaxPartiaHexDumpSize)); + return; + } + const auto it = FindEntry(seems_like.second); + if (it != senders_.end()) { + it->sender->OnReceivedRtcpPacket(arrival_time, std::move(packet)); + } +} + +SenderPacketRouter::SenderEntries::iterator SenderPacketRouter::FindEntry( + Ssrc receiver_ssrc) { + return std::find_if(senders_.begin(), senders_.end(), + [receiver_ssrc](const SenderEntry& entry) { + return entry.receiver_ssrc == receiver_ssrc; + }); +} + +void SenderPacketRouter::ScheduleNextBurst() { + // Determine the next burst time by scanning for the earliest of the + // next-scheduled send times for each Sender. + const Clock::time_point earliest_allowed_burst_time = + last_burst_time_ + burst_interval_; + Clock::time_point next_burst_time = kNever; + for (const SenderEntry& entry : senders_) { + const auto next_send_time = + std::min(entry.next_rtcp_send_time, entry.next_rtp_send_time); + if (next_send_time >= next_burst_time) { + continue; + } + if (next_send_time <= earliest_allowed_burst_time) { + next_burst_time = earliest_allowed_burst_time; + // No need to continue, since |next_burst_time| cannot become any earlier. + break; + } + next_burst_time = next_send_time; + } + + // Schedule the alarm for the next burst time unless none of the Senders has + // anything to send. + if (next_burst_time == kNever) { + alarm_.Cancel(); + } else { + alarm_.Schedule([this] { SendBurstOfPackets(); }, next_burst_time); + } +} + +void SenderPacketRouter::SendBurstOfPackets() { + // Treat RTCP packets as "critical priority," and so there is no upper limit + // on the number to send. Practically, this will always be limited by the + // number of Senders; so, this won't be a huge number of packets. + const Clock::time_point burst_time = environment_->now(); + const int num_rtcp_packets_sent = SendJustTheRtcpPackets(burst_time); + // Now send all the RTP packets, up to the maximum number allowed in a burst. + // Higher priority Senders' RTP packets are sent first. + const int num_rtp_packets_sent = SendJustTheRtpPackets( + burst_time, max_packets_per_burst_ - num_rtcp_packets_sent); + last_burst_time_ = burst_time; + + BandwidthEstimator::OnBurstComplete( + num_rtcp_packets_sent + num_rtp_packets_sent, burst_time); + + ScheduleNextBurst(); +} + +int SenderPacketRouter::SendJustTheRtcpPackets(Clock::time_point send_time) { + int num_sent = 0; + for (SenderEntry& entry : senders_) { + if (entry.next_rtcp_send_time > send_time) { + continue; + } + + // Note: Only one RTCP packet is sent from the same Sender in the same + // burst. This is because RTCP packets are supposed to always contain the + // most up-to-date Sender state. Having multiple RTCP packets in the same + // burst would mean that all but the last one are old/irrelevant snapshots + // of Sender state, and this would just thrash/confuse the Receiver. + const absl::Span<uint8_t> packet = + entry.sender->GetRtcpPacketForImmediateSend( + send_time, + absl::Span<uint8_t>(packet_buffer_.get(), packet_buffer_size_)); + if (!packet.empty()) { + environment_->SendPacket(packet); + entry.next_rtcp_send_time = send_time + kRtcpReportInterval; + ++num_sent; + } + } + + return num_sent; +} + +int SenderPacketRouter::SendJustTheRtpPackets(Clock::time_point send_time, + int num_packets_to_send) { + int num_sent = 0; + for (SenderEntry& entry : senders_) { + if (num_sent >= num_packets_to_send) { + break; + } + if (entry.next_rtp_send_time > send_time) { + continue; + } + + for (; num_sent < num_packets_to_send; ++num_sent) { + const absl::Span<uint8_t> packet = + entry.sender->GetRtpPacketForImmediateSend( + send_time, + absl::Span<uint8_t>(packet_buffer_.get(), packet_buffer_size_)); + if (packet.empty()) { + break; + } + environment_->SendPacket(packet); + } + entry.next_rtp_send_time = entry.sender->GetRtpResumeTime(); + } + + return num_sent; +} + +namespace { +constexpr int kBitsPerByte = 8; +constexpr auto kOneSecondInMilliseconds = + duration_cast<milliseconds>(seconds(1)); +} // namespace + +// static +int SenderPacketRouter::ComputeMaxPacketsPerBurst(int max_burst_bitrate, + int packet_size, + milliseconds burst_interval) { + OSP_DCHECK_GT(max_burst_bitrate, 0); + OSP_DCHECK_GT(packet_size, 0); + OSP_DCHECK_GT(burst_interval, milliseconds(0)); + OSP_DCHECK_LE(burst_interval, kOneSecondInMilliseconds); + + const int max_packets_per_second = + max_burst_bitrate / kBitsPerByte / packet_size; + const int bursts_per_second = kOneSecondInMilliseconds / burst_interval; + return std::max(max_packets_per_second / bursts_per_second, 1); +} + +// static +int SenderPacketRouter::ComputeMaxBurstBitrate(int packet_size, + int max_packets_per_burst, + milliseconds burst_interval) { + OSP_DCHECK_GT(packet_size, 0); + OSP_DCHECK_GT(max_packets_per_burst, 0); + OSP_DCHECK_GT(burst_interval, milliseconds(0)); + OSP_DCHECK_LE(burst_interval, kOneSecondInMilliseconds); + + const int64_t max_bits_per_burst = + int64_t{packet_size} * kBitsPerByte * max_packets_per_burst; + const int bursts_per_second = kOneSecondInMilliseconds / burst_interval; + return saturate_cast<int>(max_bits_per_burst * bursts_per_second); +} + +SenderPacketRouter::Sender::~Sender() = default; + +// static +constexpr int SenderPacketRouter::kDefaultMaxBurstBitrate; +// static +constexpr milliseconds SenderPacketRouter::kDefaultBurstInterval; +// static +constexpr Clock::time_point SenderPacketRouter::kNever; + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.h b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.h new file mode 100644 index 00000000000..e73e73eca3b --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.h @@ -0,0 +1,196 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STREAMING_SENDER_PACKET_ROUTER_H_ +#define CAST_STREAMING_SENDER_PACKET_ROUTER_H_ + +#include <stdint.h> + +#include <chrono> // NOLINT +#include <memory> +#include <vector> + +#include "absl/types/span.h" +#include "cast/streaming/bandwidth_estimator.h" +#include "cast/streaming/environment.h" +#include "cast/streaming/ssrc.h" +#include "platform/api/time.h" +#include "util/alarm.h" + +namespace openscreen { +namespace cast { + +// Manages network packet transmission for one or more Senders, directing each +// inbound packet to a specific Sender instance, pacing the transmission of +// outbound packets, and employing network bandwidth/availability monitoring and +// congestion control. +// +// Instead of just sending packets whenever they want, Senders must request +// transmission from the SenderPacketRouter. The router then calls-back to each +// Sender, in the near future, when it has allocated an available time slice for +// transmission. The Sender is allowed to decide, at that exact moment, which +// packet most needs to be sent. +// +// Pacing strategy: Packets are sent in bursts. This allows the platform +// (operating system) to collect many small packets into a short-term buffer, +// which allows for optimizations at the link layer. For example, multiple +// packets can be sent together as one larger transmission unit, and this can be +// critical for good performance over shared-medium networks (such as 802.11 +// WiFi). https://en.wikipedia.org/wiki/Frame-bursting +class SenderPacketRouter : public BandwidthEstimator, + public Environment::PacketConsumer { + public: + class Sender { + public: + // Called to provide the Sender with what looks like a RTCP packet meant for + // it specifically (among other Senders) to process. |arrival_time| + // indicates when the packet arrived (i.e., when it was received from the + // platform). + virtual void OnReceivedRtcpPacket(Clock::time_point arrival_time, + absl::Span<const uint8_t> packet) = 0; + + // Populates the given |buffer| with a RTCP/RTP packet that will be sent + // immediately. Returns the portion of |buffer| contaning the packet, or an + // empty Span if nothing is ready to send. + virtual absl::Span<uint8_t> GetRtcpPacketForImmediateSend( + Clock::time_point send_time, + absl::Span<uint8_t> buffer) = 0; + virtual absl::Span<uint8_t> GetRtpPacketForImmediateSend( + Clock::time_point send_time, + absl::Span<uint8_t> buffer) = 0; + + // Returns the point-in-time at which RTP sending should resume, or kNever + // if it should be suspended until an explicit call to RequestRtpSend(). The + // implementation may return a value on or before "now" to indicate an + // immediate resume is desired. + virtual Clock::time_point GetRtpResumeTime() = 0; + + protected: + virtual ~Sender(); + }; + + // Constructs an instance with default burst parameters appropriate for the + // given |max_burst_bitrate|. + explicit SenderPacketRouter(Environment* environment, + int max_burst_bitrate = kDefaultMaxBurstBitrate); + + // Constructs an instance with specific burst parameters. The maximum bitrate + // will be computed based on these (and Environment::GetMaxPacketSize()). + SenderPacketRouter(Environment* environment, + int max_packets_per_burst, + std::chrono::milliseconds burst_interval); + + ~SenderPacketRouter(); + + int max_packet_size() const { return packet_buffer_size_; } + int max_burst_bitrate() const { return max_burst_bitrate_; } + + // Called from a Sender constructor/destructor to register/deregister a Sender + // instance that processes RTP/RTCP packets from a Receiver having the given + // SSRC. + void OnSenderCreated(Ssrc receiver_ssrc, Sender* client); + void OnSenderDestroyed(Ssrc receiver_ssrc); + + // Requests an immediate send of a RTCP packet, and then RTCP sending will + // repeat at regular intervals (see kRtcpSendInterval) until the Sender is + // de-registered. + void RequestRtcpSend(Ssrc receiver_ssrc); + + // Requests an immediate send of a RTP packet. RTP sending will continue until + // the Sender stops providing packet data. + // + // See also: Sender::GetRtpResumeTime(). + void RequestRtpSend(Ssrc receiver_ssrc); + + // A reasonable default maximum bitrate for bursting. Congestion control + // should always be employed to limit the Senders' sustained/average outbound + // data volume for "fair" use of the network. + static constexpr int kDefaultMaxBurstBitrate = 24 << 20; // 24 megabits/sec + + // The minimum amount of time between burst-sends. The methodology by which + // this value was determined is lost knowledge, but is likely the result of + // experimentation with various network and operating system configurations. + // This value came from the original Chrome Cast Streaming implementation. + static constexpr std::chrono::milliseconds kDefaultBurstInterval{10}; + + // A special time_point value representing "never." + static constexpr Clock::time_point kNever = Clock::time_point::max(); + + private: + struct SenderEntry { + Ssrc receiver_ssrc; + Sender* sender; + Clock::time_point next_rtcp_send_time; + Clock::time_point next_rtp_send_time; + + // Entries are ordered by the transmission priority (high→low), as implied + // by their SSRC. See ssrc.h for details. + bool operator<(const SenderEntry& other) const { + return ComparePriority(receiver_ssrc, other.receiver_ssrc) < 0; + } + }; + + using SenderEntries = std::vector<SenderEntry>; + + // Environment::PacketConsumer implementation. + void OnReceivedPacket(const IPEndpoint& source, + Clock::time_point arrival_time, + std::vector<uint8_t> packet) final; + + // Helper to return an iterator pointing to the entry corresponding to the + // given |receiver_ssrc|, or "end" if not found. + SenderEntries::iterator FindEntry(Ssrc receiver_ssrc); + + // Examine the next send time for all Senders, and decide whether to schedule + // a burst-send. + void ScheduleNextBurst(); + + // Performs a burst-send of packets. This is called whevener the Alarm fires. + void SendBurstOfPackets(); + + // Send an RTCP packet from each Sender that has one ready, and return the + // number of packets sent. + int SendJustTheRtcpPackets(Clock::time_point send_time); + + // Send zero or more RTP packets from each Sender, up to a maximum of + // |num_packets_to_send|, and return the number of packets sent. + int SendJustTheRtpPackets(Clock::time_point send_time, + int num_packets_to_send); + + // Returns the maximum number of packets to send in one burst, based on the + // given parameters. + static int ComputeMaxPacketsPerBurst( + int max_burst_bitrate, + int packet_size, + std::chrono::milliseconds burst_interval); + + // Returns the maximum bitrate inferred by the given parameters. + static int ComputeMaxBurstBitrate(int packet_size, + int max_packets_per_burst, + std::chrono::milliseconds burst_interval); + + Environment* const environment_; + const int packet_buffer_size_; + const std::unique_ptr<uint8_t[]> packet_buffer_; + const int max_packets_per_burst_; + const std::chrono::milliseconds burst_interval_; + const int max_burst_bitrate_; + + // Schedules the task that calls back into this SenderPacketRouter at a later + // time to send the next burst of packets. + Alarm alarm_; + + // The current list of Senders and their timing information. This is + // maintained in order of the priority implied by the Sender SSRC's. + SenderEntries senders_; + + // The last time a burst of packets was sent. This is used to determine the + // next burst time. + Clock::time_point last_burst_time_ = Clock::time_point::min(); +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STREAMING_SENDER_PACKET_ROUTER_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router_unittest.cc new file mode 100644 index 00000000000..75196bdbe68 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router_unittest.cc @@ -0,0 +1,616 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/sender_packet_router.h" + +#include "cast/streaming/constants.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "platform/base/ip_address.h" +#include "platform/test/fake_clock.h" +#include "platform/test/fake_task_runner.h" +#include "util/big_endian.h" +#include "util/logging.h" + +using std::chrono::milliseconds; +using std::chrono::seconds; + +using testing::_; +using testing::Invoke; +using testing::Mock; +using testing::Return; + +namespace openscreen { +namespace cast { +namespace { + +const IPEndpoint kRemoteEndpoint{ + // Use a random IPv6 address in the range reserved for "documentation + // purposes." + IPAddress::Parse("2001:db8:0d93:69c2:fd1a:49a6:a7c0:e8a6").value(), 25476}; + +const IPEndpoint kUnexpectedEndpoint{ + IPAddress::Parse("2001:db8:0d93:69c2:fd1a:49a6:a7c0:e8a7").value(), 25476}; + +// Limited burst parameters to simplify unit testing. +constexpr int kMaxPacketsPerBurst = 3; +constexpr auto kBurstInterval = milliseconds(10); + +constexpr Ssrc kAudioReceiverSsrc = 2; +constexpr Ssrc kVideoReceiverSsrc = 32; + +const uint8_t kGarbagePacket[] = { + 0x42, 0x61, 0x16, 0x17, 0x26, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, + 0x6e, 0x67, 0x2f, 0x63, 0x61, 0x73, 0x74, 0x2f, 0x63, 0x6f, 0x6d, 0x70, + 0x6f, 0x75, 0x6e, 0x64, 0x5f, 0x72, 0x74, 0x63, 0x70, 0x5f}; + +// clang-format off +const uint8_t kValidAudioRtcpPacket[] = { + 0b10000000, // Version=2, Padding=no, ReportCount=0. + 201, // RTCP Packet type byte. + 0x00, 0x01, // Length of remainder of packet, in 32-bit words. + 0x00, 0x00, 0x00, 0x02, // Receiver SSRC. +}; + +const uint8_t kValidAudioRtpPacket[] = { + 0b10000000, // Version/Padding byte. + 96, // Payload type byte. + 0xbe, 0xef, // Sequence number. + 9, 8, 7, 6, // RTP timestamp. + 0, 0, 0, 2, // SSRC. + 0b10000000, // Is key frame, no extensions. + 5, // Frame ID. + 0xa, 0xb, // Packet ID. + 0xa, 0xc, // Max packet ID. + 0xf, 0xe, 0xd, 0xc, 0xb, 0xa, 0x9, 0x8, // Payload. +}; +// clang-format on + +// Returns a copy of an |original| RTCP packet, but with its send-to SSRC +// modified to the given |alternate_ssrc|. +std::vector<uint8_t> MakeRtcpPacketWithAlternateReceiverSsrc( + absl::Span<const uint8_t> original, + Ssrc alternate_ssrc) { + constexpr int kOffsetToSsrcField = 4; + std::vector<uint8_t> out(original.begin(), original.end()); + OSP_CHECK_GE(out.size(), kOffsetToSsrcField + sizeof(uint32_t)); + WriteBigEndian(uint32_t{alternate_ssrc}, out.data() + kOffsetToSsrcField); + return out; +} + +// Serializes the |flag| and |send_time| into the front of |buffer| so the tests +// can make unique packets and confirm their identities after passing through +// various components. +absl::Span<uint8_t> MakeFakePacketWithFlag(char flag, + Clock::time_point send_time, + absl::Span<uint8_t> buffer) { + const Clock::duration::rep ticks = send_time.time_since_epoch().count(); + const auto packet_size = sizeof(ticks) + sizeof(flag); + buffer = buffer.subspan(0, packet_size); + OSP_CHECK_EQ(buffer.size(), packet_size); + WriteBigEndian(ticks, buffer.data()); + buffer[sizeof(ticks)] = flag; + return buffer; +} + +// Same as MakeFakePacketWithFlag(), but for tests that don't use the flag. +absl::Span<uint8_t> MakeFakePacket(Clock::time_point send_time, + absl::Span<uint8_t> buffer) { + return MakeFakePacketWithFlag('?', send_time, buffer); +} + +// Returns the flag that was placed in the given |fake_packet|, or '?' if +// unknown. +char ParseFlag(absl::Span<const uint8_t> fake_packet) { + constexpr auto kFlagOffset = sizeof(Clock::duration::rep); + if (fake_packet.size() == (kFlagOffset + sizeof(char))) { + return static_cast<char>(fake_packet[kFlagOffset]); + } + return '?'; +} + +// Deserializes and returns the timestamp that was placed in the given |packet| +// by MakeFakePacketWithFlag(). +Clock::time_point ParseTimestamp(absl::Span<const uint8_t> fake_packet) { + Clock::duration::rep ticks = 0; + if (fake_packet.size() >= sizeof(ticks)) { + ticks = ReadBigEndian<Clock::duration::rep>(fake_packet.data()); + } + return Clock::time_point() + Clock::duration(ticks); +} + +// Returns an empty version of |buffer|. +absl::Span<uint8_t> ToEmptyPacketBuffer(Clock::time_point send_time, + absl::Span<uint8_t> buffer) { + return buffer.subspan(0, 0); +} + +class MockEnvironment : public Environment { + public: + MockEnvironment(ClockNowFunctionPtr now_function, TaskRunner* task_runner) + : Environment(now_function, task_runner) {} + + ~MockEnvironment() override = default; + + MOCK_METHOD1(SendPacket, void(absl::Span<const uint8_t> packet)); +}; + +class MockSender : public SenderPacketRouter::Sender { + public: + MockSender() = default; + ~MockSender() override = default; + + MOCK_METHOD(void, + OnReceivedRtcpPacket, + (Clock::time_point arrival_time, + absl::Span<const uint8_t> packet), + (override)); + MOCK_METHOD(absl::Span<uint8_t>, + GetRtcpPacketForImmediateSend, + (Clock::time_point send_time, absl::Span<uint8_t> buffer), + (override)); + MOCK_METHOD(absl::Span<uint8_t>, + GetRtpPacketForImmediateSend, + (Clock::time_point send_time, absl::Span<uint8_t> buffer), + (override)); + MOCK_METHOD(Clock::time_point, GetRtpResumeTime, (), (override)); +}; + +class SenderPacketRouterTest : public testing::Test { + public: + SenderPacketRouterTest() + : clock_(Clock::now()), + task_runner_(&clock_), + env_(&FakeClock::now, &task_runner_), + router_(&env_, kMaxPacketsPerBurst, kBurstInterval) { + env_.set_socket_error_handler( + [](Error error) { ASSERT_TRUE(error.ok()) << error; }); + } + + ~SenderPacketRouterTest() override = default; + + MockEnvironment* env() { return &env_; } + SenderPacketRouter* router() { return &router_; } + MockSender* audio_sender() { return &audio_sender_; } + MockSender* video_sender() { return &video_sender_; } + + void SimulatePacketArrivedNow(const IPEndpoint& source, + absl::Span<const uint8_t> packet) { + static_cast<Environment::PacketConsumer*>(&router_)->OnReceivedPacket( + source, env_.now(), std::vector<uint8_t>(packet.begin(), packet.end())); + } + + void AdvanceClockAndRunTasks(Clock::duration delta) { clock_.Advance(delta); } + void RunTasksUntilIdle() { task_runner_.RunTasksUntilIdle(); } + + private: + FakeClock clock_; + FakeTaskRunner task_runner_; + testing::NiceMock<MockEnvironment> env_; + SenderPacketRouter router_; + testing::NiceMock<MockSender> audio_sender_; + testing::NiceMock<MockSender> video_sender_; +}; + +// Tests that the SenderPacketRouter is correctly configured from the specific +// burst parameters that were passed to its constructor. This confirms internal +// calculations based on these parameters. +TEST_F(SenderPacketRouterTest, IsConfiguredFromBurstParameters) { + EXPECT_EQ(env()->GetMaxPacketSize(), router()->max_packet_size()); + + // The following lower-bound/upper-bound values were hand-calculated based on + // the arguments that were passed to the SenderPacketRouter constructor, and + // assuming a packet size anywhere from 256 bytes to one megabyte. + // + // The exact value for max_burst_bitrate() is not known here because + // Environment::GetMaxPacketSize() depends on the platform and network medium. + // To test for an exact value would require duplicating the math in + // SenderPacketRouter::ComputeMaxBurstBitrate() here (and then *what* would we + // be testing?). + EXPECT_LE(614400, router()->max_burst_bitrate()); + EXPECT_GE(2147483647, router()->max_burst_bitrate()); +} + +TEST_F(SenderPacketRouterTest, IgnoresPacketsFromUnexpectedSources) { + env()->set_remote_endpoint(kRemoteEndpoint); + router()->OnSenderCreated(kAudioReceiverSsrc, audio_sender()); + EXPECT_CALL(*audio_sender(), OnReceivedRtcpPacket(_, _)).Times(0); + SimulatePacketArrivedNow(kUnexpectedEndpoint, + absl::Span<const uint8_t>(kValidAudioRtcpPacket)); + router()->OnSenderDestroyed(kAudioReceiverSsrc); +} + +TEST_F(SenderPacketRouterTest, IgnoresInboundPacketsContainingGarbage) { + env()->set_remote_endpoint(kRemoteEndpoint); + router()->OnSenderCreated(kAudioReceiverSsrc, audio_sender()); + EXPECT_CALL(*audio_sender(), OnReceivedRtcpPacket(_, _)).Times(0); + SimulatePacketArrivedNow(kUnexpectedEndpoint, + absl::Span<const uint8_t>(kGarbagePacket)); + SimulatePacketArrivedNow(kRemoteEndpoint, + absl::Span<const uint8_t>(kGarbagePacket)); + router()->OnSenderDestroyed(kAudioReceiverSsrc); +} + +// Note: RTP packets should be ignored since it wouldn't make sense for a +// Receiver to stream media to a Sender. +TEST_F(SenderPacketRouterTest, IgnoresInboundRtpPackets) { + env()->set_remote_endpoint(kRemoteEndpoint); + router()->OnSenderCreated(kAudioReceiverSsrc, audio_sender()); + EXPECT_CALL(*audio_sender(), OnReceivedRtcpPacket(_, _)).Times(0); + SimulatePacketArrivedNow(kUnexpectedEndpoint, + absl::Span<const uint8_t>(kValidAudioRtpPacket)); + SimulatePacketArrivedNow(kRemoteEndpoint, + absl::Span<const uint8_t>(kValidAudioRtpPacket)); + router()->OnSenderDestroyed(kAudioReceiverSsrc); +} + +TEST_F(SenderPacketRouterTest, IgnoresInboundRtcpPacketsFromUnknownReceivers) { + env()->set_remote_endpoint(kRemoteEndpoint); + router()->OnSenderCreated(kAudioReceiverSsrc, audio_sender()); + const std::vector<uint8_t> rtcp_packet_not_for_me = + MakeRtcpPacketWithAlternateReceiverSsrc(kValidAudioRtcpPacket, + kAudioReceiverSsrc + 1); + EXPECT_CALL(*audio_sender(), OnReceivedRtcpPacket(_, _)).Times(0); + SimulatePacketArrivedNow(kUnexpectedEndpoint, + absl::Span<const uint8_t>(rtcp_packet_not_for_me)); + SimulatePacketArrivedNow(kRemoteEndpoint, + absl::Span<const uint8_t>(rtcp_packet_not_for_me)); + router()->OnSenderDestroyed(kAudioReceiverSsrc); +} + +// Tests that the SenderPacketRouter forwards packets from Receivers to the +// appropriate Sender. +TEST_F(SenderPacketRouterTest, RoutesRTCPPacketsFromReceivers) { + EXPECT_CALL(*env(), SendPacket(_)).Times(0); + + const absl::Span<const uint8_t> audio_rtcp_packet(kValidAudioRtcpPacket); + std::vector<uint8_t> video_rtcp_packet = + MakeRtcpPacketWithAlternateReceiverSsrc(audio_rtcp_packet, + kVideoReceiverSsrc); + + env()->set_remote_endpoint(kRemoteEndpoint); + router()->OnSenderCreated(kAudioReceiverSsrc, audio_sender()); + + // It should route a valid audio RTCP packet to the audio Sender, and ignore a + // valid video RTCP packet (since the video Sender is not yet known to the + // SenderPacketRouter). + { + Clock::time_point arrival_time{}; + std::vector<uint8_t> received_packet; + EXPECT_CALL(*audio_sender(), OnReceivedRtcpPacket(_, _)) + .WillOnce(Invoke( + [&](Clock::time_point when, absl::Span<const uint8_t> packet) { + arrival_time = when; + received_packet.assign(packet.begin(), packet.end()); + })); + EXPECT_CALL(*video_sender(), OnReceivedRtcpPacket(_, _)).Times(0); + + const Clock::time_point expected_arrival_time = env()->now(); + SimulatePacketArrivedNow(kRemoteEndpoint, audio_rtcp_packet); + SimulatePacketArrivedNow(kRemoteEndpoint, video_rtcp_packet); + + Mock::VerifyAndClear(audio_sender()); + EXPECT_EQ(expected_arrival_time, arrival_time); + EXPECT_EQ(audio_rtcp_packet, received_packet); + + Mock::VerifyAndClear(video_sender()); + } + + AdvanceClockAndRunTasks(seconds(1)); + + // Register the video Sender with the router. Now, confirm audio RTCP packets + // still go to the audio Sender and video RTCP packets go to the video Sender. + router()->OnSenderCreated(kVideoReceiverSsrc, video_sender()); + { + Clock::time_point audio_arrival_time{}, video_arrival_time{}; + std::vector<uint8_t> received_audio_packet, received_video_packet; + EXPECT_CALL(*audio_sender(), OnReceivedRtcpPacket(_, _)) + .WillOnce(Invoke( + [&](Clock::time_point when, absl::Span<const uint8_t> packet) { + audio_arrival_time = when; + received_audio_packet.assign(packet.begin(), packet.end()); + })); + EXPECT_CALL(*video_sender(), OnReceivedRtcpPacket(_, _)) + .WillOnce(Invoke( + [&](Clock::time_point when, absl::Span<const uint8_t> packet) { + video_arrival_time = when; + received_video_packet.assign(packet.begin(), packet.end()); + })); + + const Clock::time_point expected_audio_arrival_time = env()->now(); + SimulatePacketArrivedNow(kRemoteEndpoint, audio_rtcp_packet); + + AdvanceClockAndRunTasks(milliseconds(11)); + + const Clock::time_point expected_video_arrival_time = env()->now(); + SimulatePacketArrivedNow(kRemoteEndpoint, video_rtcp_packet); + + Mock::VerifyAndClear(audio_sender()); + EXPECT_EQ(expected_audio_arrival_time, audio_arrival_time); + EXPECT_EQ(audio_rtcp_packet, received_audio_packet); + + Mock::VerifyAndClear(video_sender()); + EXPECT_EQ(expected_video_arrival_time, video_arrival_time); + EXPECT_EQ(video_rtcp_packet, received_video_packet); + } + + router()->OnSenderDestroyed(kAudioReceiverSsrc); + router()->OnSenderDestroyed(kVideoReceiverSsrc); +} + +// Tests that the SenderPacketRouter schedules periodic RTCP packet sends, +// starting once the Sender requests the first RTCP send. +TEST_F(SenderPacketRouterTest, SchedulesPeriodicTransmissionOfRTCPPackets) { + env()->set_remote_endpoint(kRemoteEndpoint); + router()->OnSenderCreated(kAudioReceiverSsrc, audio_sender()); + + constexpr int kNumIterations = 5; + + EXPECT_CALL(*audio_sender(), OnReceivedRtcpPacket(_, _)).Times(0); + EXPECT_CALL(*audio_sender(), GetRtcpPacketForImmediateSend(_, _)) + .Times(kNumIterations) + .WillRepeatedly(Invoke(&MakeFakePacket)); + EXPECT_CALL(*audio_sender(), GetRtpPacketForImmediateSend(_, _)).Times(0); + ON_CALL(*audio_sender(), GetRtpResumeTime()) + .WillByDefault(Return(SenderPacketRouter::kNever)); + + // Capture every packet sent for analysis at the end of this test. + std::vector<std::vector<uint8_t>> packets_sent; + EXPECT_CALL(*env(), SendPacket(_)) + .WillRepeatedly(Invoke([&](absl::Span<const uint8_t> packet) { + packets_sent.emplace_back(packet.begin(), packet.end()); + })); + + const Clock::time_point first_send_time = env()->now(); + router()->RequestRtcpSend(kAudioReceiverSsrc); + RunTasksUntilIdle(); // The first RTCP packet should be sent immediately. + for (int i = 1; i < kNumIterations; ++i) { + AdvanceClockAndRunTasks(kRtcpReportInterval); + } + + // Ensure each RTCP packet was sent and in-sequence. + Mock::VerifyAndClear(env()); + ASSERT_EQ(kNumIterations, static_cast<int>(packets_sent.size())); + for (int i = 0; i < kNumIterations; ++i) { + const Clock::time_point expected_send_time = + first_send_time + i * kRtcpReportInterval; + EXPECT_EQ(expected_send_time, ParseTimestamp(packets_sent[i])); + } + + router()->OnSenderDestroyed(kAudioReceiverSsrc); +} + +// Tests that the SenderPacketRouter schedules RTP packet bursts from a single +// Sender. +TEST_F(SenderPacketRouterTest, SchedulesAndTransmitsRTPBursts) { + env()->set_remote_endpoint(kRemoteEndpoint); + router()->OnSenderCreated(kVideoReceiverSsrc, video_sender()); + + // Capture every packet sent for analysis at the end of this test. + std::vector<std::vector<uint8_t>> packets_sent; + EXPECT_CALL(*env(), SendPacket(_)) + .WillRepeatedly(Invoke([&](absl::Span<const uint8_t> packet) { + packets_sent.emplace_back(packet.begin(), packet.end()); + })); + + // Simulate a typical video Sender RTP at-startup sending sequence: First, at + // t=0ms, the Sender wants to send its large 10-packet key frame. This will + // require four bursts, since only 3 packets can be sent per burst. + // + // While the first frame is being sent, a smaller 4-packet frame is enqueued, + // and the Sender will want to start sending this immediately after the first + // frame. Part of this second frame will be sent in the fourth burst, and the + // rest in the fifth burst. + // + // After the fifth burst, the Sender will schedule a "kickstart packet" for + // 25ms later. However, when the SenderPacketRouter later asks the Sender for + // that packet, the Sender will change its mind and decide not to send + // anything. + // + // At t=100ms, the next frame of video is enqueued in the Sender and it + // requests that RTP sending resume for that. This is a small 1-packet frame. + const Clock::time_point start_time = env()->now(); + int num_get_rtp_calls = 0; + EXPECT_CALL(*video_sender(), GetRtpPacketForImmediateSend(_, _)) + .Times(14 + 2) + .WillRepeatedly( + Invoke([&](Clock::time_point send_time, absl::Span<uint8_t> buffer) { + ++num_get_rtp_calls; + + // 14 packets are sent: The first through fourth bursts send three + // packets each, and the fifth burst sends two. + if (num_get_rtp_calls <= 14) { + return MakeFakePacket(send_time, buffer); + } + + // 2 "done signals" are then sent: One is at the end of the fifth + // burst, one is for a "nothing to send" sixth burst. + return ToEmptyPacketBuffer(send_time, buffer); + })); + const Clock::time_point kickstart_time = + start_time + 4 * kBurstInterval + milliseconds(25); + int num_get_resume_calls = 0; + EXPECT_CALL(*video_sender(), GetRtpResumeTime()) + .Times(4 + 1 + 1) + .WillRepeatedly(Invoke([&] { + ++num_get_resume_calls; + + // After each of the first through fourth bursts, the Sender wants to + // transmit more right away. + if (num_get_resume_calls <= 4) { + return env()->now(); + } + + // After the fifth burst, the Sender requests resuming for kickstart + // later. + if (num_get_resume_calls == 5) { + return kickstart_time; + } + + // After the sixth burst, the Sender pauses RTP sending indefinitely. + return SenderPacketRouter::kNever; + })); + router()->RequestRtpSend(kVideoReceiverSsrc); + // Execute first burst. + RunTasksUntilIdle(); + // Execute second through fifth bursts. + for (int i = 1; i <= 4; ++i) { + AdvanceClockAndRunTasks(kBurstInterval); + } + // Execute the sixth burst at the kickstart time. + AdvanceClockAndRunTasks(kickstart_time - env()->now()); + Mock::VerifyAndClear(video_sender()); + + // Now, resume RTP sending for one more 1-packet frame, and then pause RTP + // sending again. + EXPECT_CALL(*video_sender(), GetRtpPacketForImmediateSend(_, _)) + .WillOnce(Invoke(&MakeFakePacket)) // Frame 2, only packet. + .WillOnce(Invoke(&ToEmptyPacketBuffer)); // Done for now. + // After the seventh burst, the Sender pauses RTP sending again. + EXPECT_CALL(*video_sender(), GetRtpResumeTime()) + .WillOnce(Return(SenderPacketRouter::kNever)); + // Advance to the resume time. Nothing should happen until RequestRtpSend() is + // called. + const Clock::time_point resume_time = start_time + milliseconds(100); + AdvanceClockAndRunTasks(resume_time - env()->now()); + router()->RequestRtpSend(kVideoReceiverSsrc); + // Execute seventh burst. + RunTasksUntilIdle(); + // Run for one more second, but nothing should be happening since sending is + // paused. + AdvanceClockAndRunTasks(seconds(1)); + Mock::VerifyAndClear(video_sender()); + + // Confirm 15 packets got sent and contain the expected data (which tracks + // when they were sent). + ASSERT_EQ(15, static_cast<int>(packets_sent.size())); + Clock::time_point expected_time; + int packet_idx = 0; + // First burst through fourth burst. + for (int burst_number = 0; burst_number < 4; ++burst_number) { + expected_time = start_time + burst_number * kBurstInterval; + EXPECT_EQ(expected_time, ParseTimestamp(packets_sent[packet_idx++])); + EXPECT_EQ(expected_time, ParseTimestamp(packets_sent[packet_idx++])); + EXPECT_EQ(expected_time, ParseTimestamp(packets_sent[packet_idx++])); + } + // Fifth burst. + expected_time += kBurstInterval; + EXPECT_EQ(expected_time, ParseTimestamp(packets_sent[packet_idx++])); + EXPECT_EQ(expected_time, ParseTimestamp(packets_sent[packet_idx++])); + // Seventh burst (sixth burst sent nothing). + EXPECT_EQ(resume_time, ParseTimestamp(packets_sent[packet_idx++])); + + router()->OnSenderDestroyed(kVideoReceiverSsrc); +} + +// Tests that the SenderPacketRouter schedules packet sends based on transmit +// prority: RTCP before RTP, and the audio Sender's packets before the video +// Sender's. +TEST_F(SenderPacketRouterTest, SchedulesAndTransmitsAccountingForPriority) { + env()->set_remote_endpoint(kRemoteEndpoint); + ASSERT_LT(ComparePriority(kAudioReceiverSsrc, kVideoReceiverSsrc), 0); + router()->OnSenderCreated(kVideoReceiverSsrc, video_sender()); + router()->OnSenderCreated(kAudioReceiverSsrc, audio_sender()); + + // Capture every packet sent for analysis at the end of this test. + std::vector<std::vector<uint8_t>> packets_sent; + EXPECT_CALL(*env(), SendPacket(_)) + .WillRepeatedly(Invoke([&](absl::Span<const uint8_t> packet) { + packets_sent.emplace_back(packet.begin(), packet.end()); + })); + + // These indicate how often one packet will be sent from each Sender. + constexpr Clock::duration kAudioRtpInterval = milliseconds(10); + constexpr Clock::duration kVideoRtpInterval = milliseconds(33); + + // Note: The priority flags used in this test ('0'..'3') indicate + // lowest-to-highest priority. + EXPECT_CALL(*audio_sender(), GetRtcpPacketForImmediateSend(_, _)) + .WillRepeatedly( + Invoke([](Clock::time_point send_time, absl::Span<uint8_t> buffer) { + return MakeFakePacketWithFlag('3', send_time, buffer); + })); + int num_audio_get_rtp_calls = 0; + EXPECT_CALL(*audio_sender(), GetRtpPacketForImmediateSend(_, _)) + .WillRepeatedly( + Invoke([&](Clock::time_point send_time, absl::Span<uint8_t> buffer) { + // Alternate between returning a single packet and a "done for now" + // signal. + ++num_audio_get_rtp_calls; + if (num_audio_get_rtp_calls % 2) { + return MakeFakePacketWithFlag('1', send_time, buffer); + } + return buffer.subspan(0, 0); + })); + EXPECT_CALL(*video_sender(), GetRtcpPacketForImmediateSend(_, _)) + .WillRepeatedly( + Invoke([](Clock::time_point send_time, absl::Span<uint8_t> buffer) { + return MakeFakePacketWithFlag('2', send_time, buffer); + })); + int num_video_get_rtp_calls = 0; + EXPECT_CALL(*video_sender(), GetRtpPacketForImmediateSend(_, _)) + .WillRepeatedly( + Invoke([&](Clock::time_point send_time, absl::Span<uint8_t> buffer) { + // Alternate between returning a single packet and a "done for now" + // signal. + ++num_video_get_rtp_calls; + if (num_video_get_rtp_calls % 2) { + return MakeFakePacketWithFlag('0', send_time, buffer); + } + return buffer.subspan(0, 0); + })); + EXPECT_CALL(*audio_sender(), GetRtpResumeTime()).WillRepeatedly(Invoke([&] { + return env()->now() + kAudioRtpInterval; + })); + EXPECT_CALL(*video_sender(), GetRtpResumeTime()).WillRepeatedly(Invoke([&] { + return env()->now() + kVideoRtpInterval; + })); + + // Request starting both RTCP and RTP sends for both Senders, in a random + // order. + router()->RequestRtcpSend(kVideoReceiverSsrc); + router()->RequestRtpSend(kAudioReceiverSsrc); + router()->RequestRtcpSend(kAudioReceiverSsrc); + router()->RequestRtpSend(kVideoReceiverSsrc); + + // Run the SenderPacketRouter for 3 seconds. + constexpr Clock::duration kSimulationDuration = seconds(3); + constexpr Clock::duration kSimulationStepPeriod = milliseconds(1); + const Clock::time_point start_time = env()->now(); + RunTasksUntilIdle(); + const Clock::time_point end_time = start_time + kSimulationDuration; + while (env()->now() <= end_time) { + AdvanceClockAndRunTasks(kSimulationStepPeriod); + } + + // Examine the packets that were actually sent, and confirm that the priority + // ordering was maintained. + ASSERT_EQ(384, static_cast<int>(packets_sent.size())); + // The very first packet sent should be an audio RTCP packet. + EXPECT_EQ('3', ParseFlag(packets_sent[0])); + EXPECT_EQ(start_time, ParseTimestamp(packets_sent[0])); + // Scan the rest, checking that packets sent in the same burst (i.e., having + // the same send timestamp) were sent in priority order. + char last_priority_flag = '3'; + Clock::time_point last_timestamp = start_time; + for (int i = 1; i < static_cast<int>(packets_sent.size()) && + !testing::Test::HasFailure(); + ++i) { + const char priority_flag = ParseFlag(packets_sent[i]); + const Clock::time_point timestamp = ParseTimestamp(packets_sent[i]); + EXPECT_LE(last_timestamp, timestamp) << "packet[" << i << ']'; + if (timestamp == last_timestamp) { + EXPECT_GT(last_priority_flag, priority_flag) << "packet[" << i << ']'; + } + last_priority_flag = priority_flag; + last_timestamp = timestamp; + } + + router()->OnSenderDestroyed(kVideoReceiverSsrc); + router()->OnSenderDestroyed(kAudioReceiverSsrc); +} + +} // namespace +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.cc b/chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.cc index 56d100d8b17..5e230da45a4 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.cc @@ -7,8 +7,8 @@ #include "cast/streaming/packet_util.h" #include "util/logging.h" +namespace openscreen { namespace cast { -namespace streaming { SenderReportBuilder::SenderReportBuilder(RtcpSession* session) : session_(session) { @@ -52,5 +52,33 @@ std::pair<absl::Span<uint8_t>, StatusReportId> SenderReportBuilder::BuildPacket( ToStatusReportId(ntp_timestamp)); } -} // namespace streaming +Clock::time_point SenderReportBuilder::GetRecentReportTime( + StatusReportId report_id, + Clock::time_point on_or_before) const { + // Assumption: The |report_id| is the middle 32 bits of a 64-bit NtpTimestamp. + static_assert(ToStatusReportId(NtpTimestamp{0x0192a3b4c5d6e7f8}) == + StatusReportId{0xa3b4c5d6}, + "FIXME: ToStatusReportId() implementation changed."); + + // Compute the maximum possible NtpTimestamp. Then, use its uppermost 16 bits + // and the 32 bits from the report_id to produce a reconstructed NtpTimestamp. + const NtpTimestamp max_timestamp = + session_->ntp_converter().ToNtpTimestamp(on_or_before); + // max_timestamp: HH...... + // report_id: LLLL + // ↓↓ ↙↙↙↙ + // reconstructed: HHLLLL00 + NtpTimestamp reconstructed = (max_timestamp & (uint64_t{0xffff} << 48)) | + (static_cast<uint64_t>(report_id) << 16); + // If the reconstructed timestamp is greater than the maximum one, rollover + // of the lower 48 bits occurred. Subtract one from the upper 16 bits to + // rectify that. + if (reconstructed > max_timestamp) { + reconstructed -= uint64_t{1} << 48; + } + + return session_->ntp_converter().ToLocalTime(reconstructed); +} + } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.h b/chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.h index 724e7342628..c5367783a95 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.h +++ b/chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.h @@ -13,9 +13,10 @@ #include "cast/streaming/rtcp_common.h" #include "cast/streaming/rtcp_session.h" #include "cast/streaming/rtp_defines.h" +#include "platform/api/time.h" +namespace openscreen { namespace cast { -namespace streaming { // Builds RTCP packets containing one Sender Report. class SenderReportBuilder { @@ -31,6 +32,11 @@ class SenderReportBuilder { const RtcpSenderReport& sender_report, absl::Span<uint8_t> buffer) const; + // Returns the approximate reference time from a recently-built Sender Report, + // based on the given |report_id| and maximum possible reference time. + Clock::time_point GetRecentReportTime(StatusReportId report_id, + Clock::time_point on_or_before) const; + // The required size (in bytes) of the buffer passed to BuildPacket(). static constexpr int kRequiredBufferSize = kRtcpCommonHeaderSize + kRtcpSenderReportSize + kRtcpReportBlockSize; @@ -39,7 +45,7 @@ class SenderReportBuilder { RtcpSession* const session_; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_SENDER_REPORT_BUILDER_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.cc b/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.cc index 4770dd6bacd..d917486371e 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.cc @@ -7,8 +7,8 @@ #include "cast/streaming/packet_util.h" #include "util/logging.h" +namespace openscreen { namespace cast { -namespace streaming { SenderReportParser::SenderReportWithId::SenderReportWithId() = default; SenderReportParser::SenderReportWithId::~SenderReportWithId() = default; @@ -71,5 +71,5 @@ SenderReportParser::Parse(absl::Span<const uint8_t> buffer) { return sender_report; } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.h b/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.h index df7f89a0e25..3c66829c6b4 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.h +++ b/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.h @@ -12,8 +12,8 @@ #include "cast/streaming/rtp_defines.h" #include "cast/streaming/rtp_time.h" +namespace openscreen { namespace cast { -namespace streaming { // Parses RTCP packets from a Sender to extract Sender Reports. Ignores anything // else, since that is all a Receiver would be interested in. @@ -45,7 +45,7 @@ class SenderReportParser { RtpTimeTicks last_parsed_rtp_timestamp_; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_SENDER_REPORT_PARSER_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser_fuzzer.cc b/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser_fuzzer.cc index 8b9d8e74057..c79eaa6aa72 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser_fuzzer.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser_fuzzer.cc @@ -8,10 +8,10 @@ #include "platform/api/time.h" extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { - using cast::streaming::RtcpSenderReport; - using cast::streaming::RtcpSession; - using cast::streaming::SenderReportParser; - using cast::streaming::Ssrc; + using openscreen::cast::RtcpSenderReport; + using openscreen::cast::RtcpSession; + using openscreen::cast::SenderReportParser; + using openscreen::cast::Ssrc; constexpr Ssrc kSenderSsrcInSeedCorpus = 1; constexpr Ssrc kReceiverSsrcInSeedCorpus = 2; @@ -24,7 +24,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wexit-time-destructors" static RtcpSession session(kSenderSsrcInSeedCorpus, kReceiverSsrcInSeedCorpus, - openscreen::platform::Clock::now()); + openscreen::Clock::now()); static SenderReportParser parser(&session); #pragma clang diagnostic pop diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_report_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/sender_report_unittest.cc index bad2cf0afa4..f9ca8298d0c 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/sender_report_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/sender_report_unittest.cc @@ -9,10 +9,12 @@ #include "cast/streaming/sender_report_parser.h" #include "gtest/gtest.h" +namespace openscreen { namespace cast { -namespace streaming { namespace { +using openscreen::operator<<; + constexpr Ssrc kSenderSsrc{1}; constexpr Ssrc kReceiverSsrc{2}; @@ -25,8 +27,7 @@ class SenderReportTest : public testing::Test { } private: - RtcpSession session_{kSenderSsrc, kReceiverSsrc, - openscreen::platform::Clock::now()}; + RtcpSession session_{kSenderSsrc, kReceiverSsrc, Clock::now()}; SenderReportBuilder builder_{&session_}; SenderReportParser parser_{&session_}; }; @@ -120,7 +121,7 @@ TEST_F(SenderReportTest, BuildPackets) { const bool with_report_block = (i == 1); RtcpSenderReport original; - original.reference_time = openscreen::platform::Clock::now(); + original.reference_time = Clock::now(); original.rtp_timestamp = RtpTimeTicks() + RtpTimeDelta::FromTicks(5); original.send_packet_count = 55; original.send_octet_count = 20044; @@ -161,6 +162,32 @@ TEST_F(SenderReportTest, BuildPackets) { } } +TEST_F(SenderReportTest, ComputesTimePointsFromReportIds) { + // Note: The time_points can be off by up to 16 µs because of the loss of + // precision caused by truncating the NtpTimestamps into StatusReportIds. + constexpr std::chrono::microseconds kEpsilon{16}; + + // Test a sampling of time points over the last 65536 seconds to confirm the + // rollover correction logic is working. + Clock::time_point on_or_before = Clock::now() + std::chrono::seconds(65536); + constexpr int kNumIterations = 16; + constexpr int kSecondsPerStep = 4096; + for (int i = 0; i < kNumIterations; ++i) { + const Clock::time_point expected_time = + on_or_before - std::chrono::seconds(i * kSecondsPerStep); + const auto report_id = + ToStatusReportId(ntp_converter().ToNtpTimestamp(expected_time)); + const Clock::time_point report_time = + builder()->GetRecentReportTime(report_id, on_or_before); + EXPECT_GE(on_or_before, report_time); + const auto absolute_difference = (expected_time < report_time) + ? (report_time - expected_time) + : (expected_time - report_time); + EXPECT_LE(absolute_difference, kEpsilon) + << expected_time << " vs " << report_time; + } +} + } // namespace -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/sender_unittest.cc new file mode 100644 index 00000000000..1c93fceed8c --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/sender_unittest.cc @@ -0,0 +1,1138 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/sender.h" + +#include <stdint.h> + +#include <algorithm> +#include <array> +#include <chrono> // NOLINT +#include <limits> +#include <map> +#include <set> +#include <vector> + +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "cast/streaming/compound_rtcp_builder.h" +#include "cast/streaming/constants.h" +#include "cast/streaming/encoded_frame.h" +#include "cast/streaming/frame_collector.h" +#include "cast/streaming/frame_crypto.h" +#include "cast/streaming/frame_id.h" +#include "cast/streaming/mock_environment.h" +#include "cast/streaming/packet_util.h" +#include "cast/streaming/rtcp_session.h" +#include "cast/streaming/rtp_defines.h" +#include "cast/streaming/rtp_packet_parser.h" +#include "cast/streaming/sender_packet_router.h" +#include "cast/streaming/sender_report_parser.h" +#include "cast/streaming/session_config.h" +#include "cast/streaming/ssrc.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "platform/test/fake_clock.h" +#include "platform/test/fake_task_runner.h" +#include "util/alarm.h" +#include "util/yet_another_bit_vector.h" + +using std::chrono::duration_cast; +using std::chrono::microseconds; +using std::chrono::milliseconds; +using std::chrono::nanoseconds; +using std::chrono::seconds; + +using testing::_; +using testing::AtLeast; +using testing::Invoke; +using testing::InvokeWithoutArgs; +using testing::Mock; +using testing::NiceMock; +using testing::Return; +using testing::Sequence; + +namespace openscreen { +namespace cast { +namespace { + +// Sender configuration. +constexpr Ssrc kSenderSsrc = 1; +constexpr Ssrc kReceiverSsrc = 2; +constexpr int kRtpTimebase = 48000; +constexpr milliseconds kTargetPlayoutDelay{400}; +constexpr auto kAesKey = + std::array<uint8_t, 16>{{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f}}; +constexpr auto kCastIvMask = + std::array<uint8_t, 16>{{0xf0, 0xe0, 0xd0, 0xc0, 0xb0, 0xa0, 0x90, 0x80, + 0x70, 0x60, 0x50, 0x40, 0x30, 0x20, 0x10, 0x00}}; +constexpr RtpPayloadType kRtpPayloadType = RtpPayloadType::kVideoVp8; + +// The number of RTP ticks advanced per frame, for 100 FPS media. +constexpr int kRtpTicksPerFrame = kRtpTimebase / 100; + +// The number of milliseconds advanced per frame, for 100 FPS media. +constexpr milliseconds kFrameDuration{1000 / 100}; +static_assert(kFrameDuration < (kTargetPlayoutDelay / 10), + "Kickstart test assumes frame duration is far less than the " + "playout delay."); + +// An Encoded frame that also holds onto its own copy of data. +struct EncodedFrameWithBuffer : public EncodedFrame { + // |EncodedFrame::data| always points inside buffer.begin()...buffer.end(). + std::vector<uint8_t> buffer; +}; + +// SenderPacketRouter configuration for these tests. +constexpr int kNumPacketsPerBurst = 20; +constexpr milliseconds kBurstInterval{10}; + +// An arbitrary value, subtracted from "now," to specify the reference_time on +// frames that are about to be enqueued. This simulates that capture+encode +// happened in the past, before Sender::EnqueueFrame() is called. +constexpr milliseconds kCaptureDelay{11}; + +// In some tests, the computed time values could be off a little bit due to +// imprecision in certain wire-format timestamp types. The following macro +// behaves just like Gtest's EXPECT_NEAR(), but works with all the time types +// too. +#define EXPECT_NEARLY_EQUAL(duration_a, duration_b, epsilon) \ + if ((duration_a) >= (duration_b)) { \ + EXPECT_LE((duration_a), (duration_b) + (epsilon)); \ + } else { \ + EXPECT_GE((duration_a), (duration_b) - (epsilon)); \ + } + +// Simulates UDP/IPv6 traffic in one direction (from Sender→Receiver, or +// Receiver→Sender), with a settable amount of delay. +class SimulatedNetworkPipe { + public: + SimulatedNetworkPipe(TaskRunner* task_runner, + Environment::PacketConsumer* remote) + : task_runner_(task_runner), remote_(remote) { + // Create a fake IPv6 address using the "documentative purposes" prefix + // concatenated with the |this| pointer. + std::array<uint16_t, 8> hextets{}; + hextets[0] = 0x2001; + hextets[1] = 0x0db8; + auto* const this_pointer = this; + static_assert(sizeof(this_pointer) <= (6 * sizeof(uint16_t)), ""); + memcpy(&hextets[2], &this_pointer, sizeof(this_pointer)); + local_endpoint_ = IPEndpoint{IPAddress(hextets), 2344}; + } + + const IPEndpoint& local_endpoint() const { return local_endpoint_; } + + Clock::duration network_delay() const { return network_delay_; } + void set_network_delay(Clock::duration delay) { network_delay_ = delay; } + + // The caller needs to spin the task runner before |packet| will reach the + // other side. + void StartPacketTransmission(std::vector<uint8_t> packet) { + task_runner_->PostTaskWithDelay( + [this, packet = std::move(packet)]() mutable { + remote_->OnReceivedPacket(local_endpoint_, FakeClock::now(), + std::move(packet)); + }, + network_delay_); + } + + private: + TaskRunner* const task_runner_; + Environment::PacketConsumer* const remote_; + + IPEndpoint local_endpoint_; + + // The amount of time for the packet to transmit over this simulated network + // pipe. Defaults to zero to simplify the tests that don't care about delays. + Clock::duration network_delay_{}; +}; + +// Processes packets from the Sender under test, allowing unit tests to set +// expectations for parsed RTP or RTCP packets, to confirm proper behavior of +// the Sender. +class MockReceiver : public Environment::PacketConsumer { + public: + explicit MockReceiver(SimulatedNetworkPipe* pipe_to_sender) + : pipe_to_sender_(pipe_to_sender), + rtcp_session_(kSenderSsrc, kReceiverSsrc, FakeClock::now()), + sender_report_parser_(&rtcp_session_), + rtcp_builder_(&rtcp_session_), + rtp_parser_(kSenderSsrc), + crypto_(kAesKey, kCastIvMask) { + rtcp_builder_.SetPlayoutDelay(kTargetPlayoutDelay); + } + + ~MockReceiver() override = default; + + // Simulate the Receiver ACK'ing all frames up to and including the + // |new_checkpoint|. + void SetCheckpointFrame(FrameId new_checkpoint) { + OSP_CHECK_GE(new_checkpoint, rtcp_builder_.checkpoint_frame()); + rtcp_builder_.SetCheckpointFrame(new_checkpoint); + } + + // Automatically advances the checkpoint based on what is found in + // |complete_frames_|, returning true if the checkpoint moved forward. + bool AutoAdvanceCheckpoint() { + const FrameId old_checkpoint = rtcp_builder_.checkpoint_frame(); + FrameId new_checkpoint = old_checkpoint; + for (auto it = complete_frames_.upper_bound(old_checkpoint); + it != complete_frames_.end(); ++it) { + if (it->first != new_checkpoint + 1) { + break; + } + ++new_checkpoint; + } + if (new_checkpoint > old_checkpoint) { + rtcp_builder_.SetCheckpointFrame(new_checkpoint); + return true; + } + return false; + } + + void SetPictureLossIndicator(bool picture_is_lost) { + rtcp_builder_.SetPictureLossIndicator(picture_is_lost); + } + + void SetReceiverReport(StatusReportId reply_for, + RtcpReportBlock::Delay processing_delay) { + RtcpReportBlock receiver_report; + receiver_report.ssrc = kSenderSsrc; + receiver_report.last_status_report_id = reply_for; + receiver_report.delay_since_last_report = processing_delay; + rtcp_builder_.IncludeReceiverReportInNextPacket(receiver_report); + } + + void SetNacksAndAcks(std::vector<PacketNack> packet_nacks, + std::vector<FrameId> frame_acks) { + rtcp_builder_.IncludeFeedbackInNextPacket(std::move(packet_nacks), + std::move(frame_acks)); + } + + // Builds and sends a RTCP packet containing one or more of: checkpoint, PLI, + // Receiver Report, NACKs, ACKs. + void TransmitRtcpFeedbackPacket() { + uint8_t buffer[kMaxRtpPacketSizeForIpv6UdpOnEthernet]; + const absl::Span<uint8_t> packet = + rtcp_builder_.BuildPacket(FakeClock::now(), buffer); + pipe_to_sender_->StartPacketTransmission( + std::vector<uint8_t>(packet.begin(), packet.end())); + } + + // Used by tests to simulate the Receiver not seeing specific packets come in + // from the network (e.g., because the network dropped the packets). + void SetIgnoreList(std::vector<PacketNack> ignore_list) { + ignore_list_ = ignore_list; + } + + // Environment::PacketConsumer implementation. + // + // Called to process a packet from the Sender, simulating basic RTP frame + // collection and Sender Report parsing/handling. + void OnReceivedPacket(const IPEndpoint& source, + Clock::time_point arrival_time, + std::vector<uint8_t> packet) override { + const auto type_and_ssrc = InspectPacketForRouting(packet); + EXPECT_NE(ApparentPacketType::UNKNOWN, type_and_ssrc.first); + EXPECT_EQ(kSenderSsrc, type_and_ssrc.second); + if (type_and_ssrc.first == ApparentPacketType::RTP) { + const absl::optional<RtpPacketParser::ParseResult> part_of_frame = + rtp_parser_.Parse(packet); + ASSERT_TRUE(part_of_frame); + + // Return early if simulating packet drops over the network. + if (std::find_if(ignore_list_.begin(), ignore_list_.end(), + [&](const PacketNack& baddie) { + return ( + baddie.frame_id == part_of_frame->frame_id && + (baddie.packet_id == kAllPacketsLost || + baddie.packet_id == part_of_frame->packet_id)); + }) != ignore_list_.end()) { + return; + } + + OnRtpPacket(*part_of_frame); + CollectRtpPacket(*part_of_frame, std::move(packet)); + } else if (type_and_ssrc.first == ApparentPacketType::RTCP) { + absl::optional<SenderReportParser::SenderReportWithId> report = + sender_report_parser_.Parse(packet); + ASSERT_TRUE(report); + OnSenderReport(*report); + } + } + + std::map<FrameId, EncodedFrameWithBuffer> TakeCompleteFrames() { + std::map<FrameId, EncodedFrameWithBuffer> result; + result.swap(complete_frames_); + return result; + } + + // Tests set expectations on these mocks to monitor events of interest, and/or + // invoke additional behaviors. + MOCK_METHOD1(OnRtpPacket, + void(const RtpPacketParser::ParseResult& parsed_packet)); + MOCK_METHOD1(OnFrameComplete, void(FrameId frame_id)); + MOCK_METHOD1(OnSenderReport, + void(const SenderReportParser::SenderReportWithId& report)); + + private: + // Collects the individual RTP packets until a whole frame can be formed, then + // calls OnFrameComplete(). Ignores extra RTP packets that are no longer + // needed. + void CollectRtpPacket(const RtpPacketParser::ParseResult& part_of_frame, + std::vector<uint8_t> packet) { + const FrameId frame_id = part_of_frame.frame_id; + if (complete_frames_.find(frame_id) != complete_frames_.end()) { + return; + } + FrameCollector& collector = incomplete_frames_[frame_id]; + collector.set_frame_id(frame_id); + EXPECT_TRUE(collector.CollectRtpPacket(part_of_frame, &packet)); + if (!collector.is_complete()) { + return; + } + const EncryptedFrame& encrypted = collector.PeekAtAssembledFrame(); + EncodedFrameWithBuffer* const decrypted = &complete_frames_[frame_id]; + // Note: Not setting decrypted->reference_time here since the logic around + // calculating the playout time is rather complex, and is definitely outside + // the scope of the testing being done in this module. Instead, end-to-end + // testing should exist elsewhere to confirm frame play-out times with real + // Receivers. + decrypted->buffer.resize(FrameCrypto::GetPlaintextSize(encrypted)); + decrypted->data = absl::Span<uint8_t>(decrypted->buffer); + crypto_.Decrypt(encrypted, decrypted); + incomplete_frames_.erase(frame_id); + OnFrameComplete(frame_id); + } + + SimulatedNetworkPipe* const pipe_to_sender_; + RtcpSession rtcp_session_; + SenderReportParser sender_report_parser_; + CompoundRtcpBuilder rtcp_builder_; + RtpPacketParser rtp_parser_; + FrameCrypto crypto_; + + std::vector<PacketNack> ignore_list_; + std::map<FrameId, FrameCollector> incomplete_frames_; + std::map<FrameId, EncodedFrameWithBuffer> complete_frames_; +}; + +class MockObserver : public Sender::Observer { + public: + MOCK_METHOD1(OnFrameCanceled, void(FrameId frame_id)); + MOCK_METHOD0(OnPictureLost, void()); +}; + +class SenderTest : public testing::Test { + public: + SenderTest() + : fake_clock_(Clock::now()), + task_runner_(&fake_clock_), + sender_environment_(&FakeClock::now, &task_runner_), + sender_packet_router_(&sender_environment_, + kNumPacketsPerBurst, + kBurstInterval), + sender_(&sender_environment_, + &sender_packet_router_, + {/* .sender_ssrc = */ kSenderSsrc, + /* .receiver_ssrc = */ kReceiverSsrc, + /* .rtp_timebase = */ kRtpTimebase, + /* .channels = */ 2, + /* .target_playout_delay = */ kTargetPlayoutDelay, + /* .aes_secret_key = */ kAesKey, + /* .aes_iv_mask = */ kCastIvMask}, + kRtpPayloadType), + receiver_to_sender_pipe_(&task_runner_, &sender_packet_router_), + receiver_(&receiver_to_sender_pipe_), + sender_to_receiver_pipe_(&task_runner_, &receiver_) { + sender_environment_.set_socket_error_handler( + [](Error error) { ASSERT_TRUE(error.ok()) << error; }); + sender_environment_.set_remote_endpoint( + receiver_to_sender_pipe_.local_endpoint()); + ON_CALL(sender_environment_, SendPacket(_)) + .WillByDefault(Invoke([this](absl::Span<const uint8_t> packet) { + sender_to_receiver_pipe_.StartPacketTransmission( + std::vector<uint8_t>(packet.begin(), packet.end())); + })); + } + + ~SenderTest() override = default; + + Sender* sender() { return &sender_; } + MockReceiver* receiver() { return &receiver_; } + + void SetReceiverToSenderNetworkDelay(Clock::duration delay) { + receiver_to_sender_pipe_.set_network_delay(delay); + } + + void SetSenderToReceiverNetworkDelay(Clock::duration delay) { + sender_to_receiver_pipe_.set_network_delay(delay); + } + + void SimulateExecution(Clock::duration how_long = Clock::duration::zero()) { + fake_clock_.Advance(how_long); + } + + static void PopulateFramePayloadBuffer(int seed, + int num_bytes, + std::vector<uint8_t>* payload) { + payload->clear(); + payload->reserve(num_bytes); + for (int i = 0; i < num_bytes; ++i) { + payload->push_back(static_cast<uint8_t>(seed + i)); + } + } + + static void PopulateFrameWithDefaults(FrameId frame_id, + Clock::time_point reference_time, + int seed, + int num_payload_bytes, + EncodedFrameWithBuffer* frame) { + frame->dependency = (frame_id == FrameId::first()) + ? EncodedFrame::KEY_FRAME + : EncodedFrame::DEPENDS_ON_ANOTHER; + frame->frame_id = frame_id; + frame->referenced_frame_id = frame->frame_id; + if (frame_id != FrameId::first()) { + --frame->referenced_frame_id; + } + frame->rtp_timestamp = + RtpTimeTicks() + (RtpTimeDelta::FromTicks(kRtpTicksPerFrame) * + (frame_id - FrameId::first())); + frame->reference_time = reference_time; + PopulateFramePayloadBuffer(seed, num_payload_bytes, &frame->buffer); + frame->data = absl::Span<uint8_t>(frame->buffer); + } + + // Confirms that all |sent_frames| exist in |received_frames|, with identical + // data and metadata. + static void ExpectFramesReceivedCorrectly( + absl::Span<EncodedFrameWithBuffer> sent_frames, + const std::map<FrameId, EncodedFrameWithBuffer> received_frames) { + ASSERT_EQ(sent_frames.size(), received_frames.size()); + + for (const EncodedFrameWithBuffer& sent_frame : sent_frames) { + SCOPED_TRACE(testing::Message() + << "Checking sent frame " << sent_frame.frame_id); + const auto received_it = received_frames.find(sent_frame.frame_id); + if (received_it == received_frames.end()) { + ADD_FAILURE() << "Did not receive frame."; + continue; + } + const EncodedFrame& received_frame = received_it->second; + EXPECT_EQ(sent_frame.dependency, received_frame.dependency); + EXPECT_EQ(sent_frame.referenced_frame_id, + received_frame.referenced_frame_id); + EXPECT_EQ(sent_frame.rtp_timestamp, received_frame.rtp_timestamp); + EXPECT_TRUE(sent_frame.data == received_frame.data); + } + } + + private: + FakeClock fake_clock_; + FakeTaskRunner task_runner_; + NiceMock<MockEnvironment> sender_environment_; + SenderPacketRouter sender_packet_router_; + Sender sender_; + SimulatedNetworkPipe receiver_to_sender_pipe_; + NiceMock<MockReceiver> receiver_; + SimulatedNetworkPipe sender_to_receiver_pipe_; +}; + +// Tests that the Sender can send EncodedFrames over an ideal network (i.e., low +// latency, no loss), and does so without having to transmit the same packet +// twice. +TEST_F(SenderTest, SendsFramesEfficiently) { + constexpr milliseconds kOneWayNetworkDelay{1}; + SetSenderToReceiverNetworkDelay(kOneWayNetworkDelay); + SetReceiverToSenderNetworkDelay(kOneWayNetworkDelay); + + // Expect that each packet is only sent once. + std::set<std::pair<FrameId, FramePacketId>> received_packets; + EXPECT_CALL(*receiver(), OnRtpPacket(_)) + .WillRepeatedly( + Invoke([&](const RtpPacketParser::ParseResult& parsed_packet) { + std::pair<FrameId, FramePacketId> id(parsed_packet.frame_id, + parsed_packet.packet_id); + const auto insert_result = received_packets.insert(id); + EXPECT_TRUE(insert_result.second) + << "Received duplicate packet: " << id.first << ':' + << static_cast<int>(id.second); + })); + + // Simulate normal frame ACK'ing behavior. + ON_CALL(*receiver(), OnFrameComplete(_)).WillByDefault(InvokeWithoutArgs([&] { + if (receiver()->AutoAdvanceCheckpoint()) { + receiver()->TransmitRtcpFeedbackPacket(); + } + })); + + NiceMock<MockObserver> observer; + EXPECT_CALL(observer, OnFrameCanceled(FrameId::first())).Times(1); + EXPECT_CALL(observer, OnFrameCanceled(FrameId::first() + 1)).Times(1); + EXPECT_CALL(observer, OnFrameCanceled(FrameId::first() + 2)).Times(1); + sender()->SetObserver(&observer); + + EncodedFrameWithBuffer frames[3]; + constexpr int kFrameDataSizes[] = {8196, 12, 1900}; + for (int i = 0; i < 3; ++i) { + if (i == 0) { + EXPECT_TRUE(sender()->NeedsKeyFrame()); + } else { + EXPECT_FALSE(sender()->NeedsKeyFrame()); + } + PopulateFrameWithDefaults(FrameId::first() + i, + FakeClock::now() - kCaptureDelay, 0xbf - i, + kFrameDataSizes[i], &frames[i]); + ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frames[i])); + SimulateExecution(kFrameDuration); + } + SimulateExecution(kTargetPlayoutDelay); + + ExpectFramesReceivedCorrectly(frames, receiver()->TakeCompleteFrames()); +} + +// Tests that the Sender correctly computes the current in-flight media +// duration, a backlog signal for clients. +TEST_F(SenderTest, ComputesInFlightMediaDuration) { + // With no frames enqueued, the in-flight media duration should be zero. + EXPECT_EQ(Clock::duration::zero(), + sender()->GetInFlightMediaDuration(RtpTimeTicks())); + EXPECT_EQ(Clock::duration::zero(), + sender()->GetInFlightMediaDuration( + RtpTimeTicks() + RtpTimeDelta::FromTicks(kRtpTicksPerFrame))); + + // Enqueue a frame. + EncodedFrameWithBuffer frame; + PopulateFrameWithDefaults(FrameId::first(), FakeClock::now(), 0, + 13 /* bytes */, &frame); + ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frame)); + + // Now, the in-flight media duration should depend on the RTP timestamp of the + // next frame. + EXPECT_EQ(kFrameDuration, sender()->GetInFlightMediaDuration( + frame.rtp_timestamp + + RtpTimeDelta::FromTicks(kRtpTicksPerFrame))); + EXPECT_EQ(10 * kFrameDuration, + sender()->GetInFlightMediaDuration( + frame.rtp_timestamp + + RtpTimeDelta::FromTicks(10 * kRtpTicksPerFrame))); +} + +// Tests that the Sender computes the maximum in-flight media duration based on +// its analysis of current network conditions. By implication, this demonstrates +// that the Sender is also measuring the network round-trip time. +TEST_F(SenderTest, RespondsToNetworkLatencyChanges) { + // The expected maximum error in time calculations is one tick of the RTCP + // report block's delay type. + constexpr auto kEpsilon = + duration_cast<nanoseconds>(RtcpReportBlock::Delay(1)); + + // Before the Sender has the necessary information to compute the network + // round-trip time, GetMaxInFlightMediaDuration() will return half the target + // playout delay. + EXPECT_NEARLY_EQUAL(kTargetPlayoutDelay / 2, + sender()->GetMaxInFlightMediaDuration(), kEpsilon); + + // No network is perfect. Simulate different one-way network delays. + constexpr milliseconds kOutboundDelay{2}; + constexpr milliseconds kInboundDelay{4}; + constexpr milliseconds kRoundTripDelay = kOutboundDelay + kInboundDelay; + SetSenderToReceiverNetworkDelay(kOutboundDelay); + SetReceiverToSenderNetworkDelay(kInboundDelay); + + // Enqueue a frame in the Sender to start emitting periodic RTCP reports. + { + EncodedFrameWithBuffer frame; + PopulateFrameWithDefaults(FrameId::first(), FakeClock::now(), 0, + 1 /* byte */, &frame); + ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frame)); + } + + // Run one network round-trip from Sender→Receiver→Sender. + StatusReportId sender_report_id{}; + EXPECT_CALL(*receiver(), OnSenderReport(_)) + .WillOnce(Invoke( + [&](const SenderReportParser::SenderReportWithId& sender_report) { + sender_report_id = sender_report.report_id; + })); + // Simulate the passage of time for the Sender Report to reach the Receiver. + SimulateExecution(kOutboundDelay); + // The Receiver should have received the Sender Report at this point. + Mock::VerifyAndClearExpectations(receiver()); + ASSERT_NE(StatusReportId{}, sender_report_id); + // Simulate the passage of time in the Receiver doing "other tasks" before + // replying back to the Sender. This delay is included in the Receiver Report + // so that the Sender can isolate the delays caused by the network. + constexpr milliseconds kReceiverProcessingDelay{2}; + SimulateExecution(kReceiverProcessingDelay); + // Create the Receiver Report "reply," and simulate it being sent across the + // network, back to the Sender. + receiver()->SetReceiverReport( + sender_report_id, + duration_cast<RtcpReportBlock::Delay>(kReceiverProcessingDelay)); + receiver()->TransmitRtcpFeedbackPacket(); + SimulateExecution(kInboundDelay); + + // At this point, the Sender should have computed the network round-trip time, + // and so GetMaxInFlightMediaDuration() will return half the target playout + // delay PLUS half the network round-trip time. + EXPECT_NEARLY_EQUAL(kTargetPlayoutDelay / 2 + kRoundTripDelay / 2, + sender()->GetMaxInFlightMediaDuration(), kEpsilon); + + // Increase the outbound delay, which will increase the total round-trip time. + constexpr milliseconds kIncreasedOutboundDelay{6}; + constexpr milliseconds kIncreasedRoundTripDelay = + kIncreasedOutboundDelay + kInboundDelay; + SetSenderToReceiverNetworkDelay(kIncreasedOutboundDelay); + + // With increased network delay, run several more network round-trips. Expect + // the Sender to gradually converge towards the new network round-trip time. + constexpr int kNumReportIntervals = 50; + EXPECT_CALL(*receiver(), OnSenderReport(_)) + .Times(kNumReportIntervals) + .WillRepeatedly(Invoke( + [&](const SenderReportParser::SenderReportWithId& sender_report) { + receiver()->SetReceiverReport(sender_report.report_id, + RtcpReportBlock::Delay::zero()); + receiver()->TransmitRtcpFeedbackPacket(); + })); + Clock::duration last_max = sender()->GetMaxInFlightMediaDuration(); + for (int i = 0; i < kNumReportIntervals; ++i) { + SimulateExecution(kRtcpReportInterval); + const Clock::duration updated_value = + sender()->GetMaxInFlightMediaDuration(); + EXPECT_LE(last_max, updated_value); + last_max = updated_value; + } + EXPECT_NEARLY_EQUAL(kTargetPlayoutDelay / 2 + kIncreasedRoundTripDelay / 2, + sender()->GetMaxInFlightMediaDuration(), kEpsilon); +} + +// Tests that the Sender rejects frames if too large a span of FrameIds would be +// in-flight at once. +TEST_F(SenderTest, RejectsEnqueuingBeforeProtocolDesignLimit) { + // For this test, use 1000 FPS. This makes the frames all one millisecond + // apart to avoid triggering the media-duration rejection logic. + constexpr int kFramesPerSecond = 1000; + constexpr milliseconds kFrameDuration{1}; + + const auto OverrideRtpTimestamp = [](int frame_count, EncodedFrame* frame) { + const int ticks = frame_count * kRtpTimebase / kFramesPerSecond; + frame->rtp_timestamp = RtpTimeTicks() + RtpTimeDelta::FromTicks(ticks); + }; + + // Send the absolute design-limit maximum number of frames. + int frame_count = 0; + for (; frame_count < kMaxUnackedFrames; ++frame_count) { + EncodedFrameWithBuffer frame; + PopulateFrameWithDefaults(sender()->GetNextFrameId(), FakeClock::now(), 0, + 13 /* bytes */, &frame); + OverrideRtpTimestamp(frame_count, &frame); + ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frame)); + SimulateExecution(kFrameDuration); + } + + // Now, attempting to enqueue just one more frame should fail. + EncodedFrameWithBuffer one_frame_too_much; + PopulateFrameWithDefaults(sender()->GetNextFrameId(), FakeClock::now(), 0, + 13 /* bytes */, &one_frame_too_much); + OverrideRtpTimestamp(frame_count++, &one_frame_too_much); + EXPECT_EQ(Sender::REACHED_ID_SPAN_LIMIT, + sender()->EnqueueFrame(one_frame_too_much)); + SimulateExecution(kFrameDuration); + + // Now, simulate the Receiver ACKing the first frame, and enqueuing should + // then succeed again. + receiver()->SetCheckpointFrame(FrameId::first()); + receiver()->TransmitRtcpFeedbackPacket(); + SimulateExecution(); // RTCP transmitted to Sender. + EXPECT_EQ(Sender::OK, sender()->EnqueueFrame(one_frame_too_much)); + SimulateExecution(kFrameDuration); + + // Finally, attempting to enqueue another frame should fail again. + EncodedFrameWithBuffer another_frame_too_much; + PopulateFrameWithDefaults(sender()->GetNextFrameId(), FakeClock::now(), 0, + 13 /* bytes */, &another_frame_too_much); + OverrideRtpTimestamp(frame_count++, &another_frame_too_much); + EXPECT_EQ(Sender::REACHED_ID_SPAN_LIMIT, + sender()->EnqueueFrame(another_frame_too_much)); + SimulateExecution(kFrameDuration); +} + +// Tests that the Sender rejects frames if too-long a media duration is +// in-flight. This is the Sender's primary flow control mechanism. +TEST_F(SenderTest, RejectsEnqueuingIfTooLongMediaDurationIsInFlight) { + // For this test, use 20 FPS. This makes all frames 50 ms apart, which should + // make it easy to trigger the media-duration rejection logic. + constexpr int kFramesPerSecond = 20; + constexpr milliseconds kFrameDuration{50}; + + const auto OverrideRtpTimestamp = [](int frame_count, EncodedFrame* frame) { + const int ticks = frame_count * kRtpTimebase / kFramesPerSecond; + frame->rtp_timestamp = RtpTimeTicks() + RtpTimeDelta::FromTicks(ticks); + }; + + // Enqueue frames until one is rejected because the in-flight duration would + // be too high. + EncodedFrameWithBuffer frame; + int frame_count = 0; + for (; frame_count < kMaxUnackedFrames; ++frame_count) { + PopulateFrameWithDefaults(sender()->GetNextFrameId(), FakeClock::now(), 0, + 13 /* bytes */, &frame); + OverrideRtpTimestamp(frame_count, &frame); + const auto result = sender()->EnqueueFrame(frame); + SimulateExecution(kFrameDuration); + if (result == Sender::MAX_DURATION_IN_FLIGHT) { + break; + } + ASSERT_EQ(Sender::OK, result); + } + + // Now, simulate the Receiver ACKing the first frame, and enqueuing should + // then succeed again. + receiver()->SetCheckpointFrame(FrameId::first()); + receiver()->TransmitRtcpFeedbackPacket(); + SimulateExecution(); // RTCP transmitted to Sender. + EXPECT_EQ(Sender::OK, sender()->EnqueueFrame(frame)); + SimulateExecution(kFrameDuration); + + // However, attempting to enqueue another frame should fail again. + EncodedFrameWithBuffer one_frame_too_much; + PopulateFrameWithDefaults(sender()->GetNextFrameId(), FakeClock::now(), 0, + 13 /* bytes */, &one_frame_too_much); + OverrideRtpTimestamp(++frame_count, &one_frame_too_much); + EXPECT_EQ(Sender::MAX_DURATION_IN_FLIGHT, + sender()->EnqueueFrame(one_frame_too_much)); + SimulateExecution(kFrameDuration); +} + +// Tests that the Sender propagates the Receiver's picture loss indicator to the +// Observer::OnPictureLost(), and via calls to NeedsKeyFrame(); but only when +// producing a key frame is absolutely necessary. +TEST_F(SenderTest, ManagesReceiverPictureLossWorkflow) { + NiceMock<MockObserver> observer; + sender()->SetObserver(&observer); + + // Send three frames... + EncodedFrameWithBuffer frames[6]; + for (int i = 0; i < 3; ++i) { + if (i == 0) { + EXPECT_TRUE(sender()->NeedsKeyFrame()); + } else { + EXPECT_FALSE(sender()->NeedsKeyFrame()); + } + PopulateFrameWithDefaults(FrameId::first() + i, + FakeClock::now() - kCaptureDelay, 0, + 24 /* bytes */, &frames[i]); + ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frames[i])); + SimulateExecution(kFrameDuration); + } + SimulateExecution(kTargetPlayoutDelay); + + // Simulate the Receiver ACK'ing the first three frames. + EXPECT_CALL(observer, OnFrameCanceled(FrameId::first())).Times(1); + EXPECT_CALL(observer, OnFrameCanceled(FrameId::first() + 1)).Times(1); + EXPECT_CALL(observer, OnFrameCanceled(FrameId::first() + 2)).Times(1); + EXPECT_CALL(observer, OnPictureLost()).Times(0); + receiver()->SetCheckpointFrame(frames[2].frame_id); + receiver()->TransmitRtcpFeedbackPacket(); + SimulateExecution(); // RTCP transmitted to Sender. + Mock::VerifyAndClearExpectations(&observer); + + // Simulate something going wrong in the Receiver, and have it report picture + // loss to the Sender. The Sender should then propagate this to its Observer + // and return true when NeedsKeyFrame() is called. + EXPECT_CALL(observer, OnFrameCanceled(_)).Times(0); + EXPECT_CALL(observer, OnPictureLost()).Times(1); + EXPECT_FALSE(sender()->NeedsKeyFrame()); + receiver()->SetPictureLossIndicator(true); + receiver()->TransmitRtcpFeedbackPacket(); + SimulateExecution(); // RTCP transmitted to Sender. + Mock::VerifyAndClearExpectations(&observer); + EXPECT_TRUE(sender()->NeedsKeyFrame()); + + // Send a non-key frame, and expect NeedsKeyFrame() still returns true. The + // Observer is not re-notified. This accounts for the case where a client's + // media encoder had frames in its processing pipeline before NeedsKeyFrame() + // began returning true. + EXPECT_CALL(observer, OnFrameCanceled(_)).Times(0); + EXPECT_CALL(observer, OnPictureLost()).Times(0); + EncodedFrameWithBuffer& nonkey_frame = frames[3]; + PopulateFrameWithDefaults(FrameId::first() + 3, + FakeClock::now() - kCaptureDelay, 0, 24 /* bytes */, + &nonkey_frame); + ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(nonkey_frame)); + SimulateExecution(kFrameDuration); + Mock::VerifyAndClearExpectations(&observer); + EXPECT_TRUE(sender()->NeedsKeyFrame()); + + // Now send a key frame, and expect NeedsKeyFrame() returns false. Note that + // the Receiver hasn't cleared the PLI condition, but the Sender knows more + // key frames won't be needed. + EXPECT_CALL(observer, OnFrameCanceled(_)).Times(0); + EXPECT_CALL(observer, OnPictureLost()).Times(0); + EncodedFrameWithBuffer& recovery_frame = frames[4]; + PopulateFrameWithDefaults(FrameId::first() + 4, + FakeClock::now() - kCaptureDelay, 0, 24 /* bytes */, + &recovery_frame); + recovery_frame.dependency = EncodedFrame::KEY_FRAME; + recovery_frame.referenced_frame_id = recovery_frame.frame_id; + ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(recovery_frame)); + SimulateExecution(kFrameDuration); + Mock::VerifyAndClearExpectations(&observer); + EXPECT_FALSE(sender()->NeedsKeyFrame()); + + // Let's say the Receiver hasn't received the key frame yet, and it reports + // its picture loss again to the Sender. Observer::OnPictureLost() should not + // be called, and NeedsKeyFrame() should NOT return true, because the Sender + // knows the Receiver hasn't acknowledged the key frame (just sent) yet. + EXPECT_CALL(observer, OnFrameCanceled(nonkey_frame.frame_id)).Times(1); + EXPECT_CALL(observer, OnPictureLost()).Times(0); + receiver()->SetCheckpointFrame(nonkey_frame.frame_id); + receiver()->SetPictureLossIndicator(true); + receiver()->TransmitRtcpFeedbackPacket(); + SimulateExecution(); // RTCP transmitted to Sender. + Mock::VerifyAndClearExpectations(&observer); + EXPECT_FALSE(sender()->NeedsKeyFrame()); + + // Now, simulate the Receiver getting the key frame, but NOT recovering. This + // should cause Observer::OnPictureLost() to be called, and cause + // NeedsKeyFrame() to return true again. + EXPECT_CALL(observer, OnFrameCanceled(recovery_frame.frame_id)).Times(1); + EXPECT_CALL(observer, OnPictureLost()).Times(1); + receiver()->SetCheckpointFrame(recovery_frame.frame_id); + receiver()->SetPictureLossIndicator(true); + receiver()->TransmitRtcpFeedbackPacket(); + SimulateExecution(); // RTCP transmitted to Sender. + Mock::VerifyAndClearExpectations(&observer); + EXPECT_TRUE(sender()->NeedsKeyFrame()); + + // Send another key frame, and expect NeedsKeyFrame() returns false. + EXPECT_CALL(observer, OnFrameCanceled(_)).Times(0); + EXPECT_CALL(observer, OnPictureLost()).Times(0); + EncodedFrameWithBuffer& another_recovery_frame = frames[5]; + PopulateFrameWithDefaults(FrameId::first() + 5, + FakeClock::now() - kCaptureDelay, 0, 24 /* bytes */, + &another_recovery_frame); + another_recovery_frame.dependency = EncodedFrame::KEY_FRAME; + another_recovery_frame.referenced_frame_id = another_recovery_frame.frame_id; + ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(another_recovery_frame)); + SimulateExecution(kFrameDuration); + Mock::VerifyAndClearExpectations(&observer); + EXPECT_FALSE(sender()->NeedsKeyFrame()); + + // Now, simulate the Receiver recovering. It will report this to the Sender, + // and NeedsKeyFrame() will still return false. + EXPECT_CALL(observer, OnFrameCanceled(another_recovery_frame.frame_id)) + .Times(1); + EXPECT_CALL(observer, OnPictureLost()).Times(0); + receiver()->SetCheckpointFrame(another_recovery_frame.frame_id); + receiver()->SetPictureLossIndicator(false); + receiver()->TransmitRtcpFeedbackPacket(); + SimulateExecution(); // RTCP transmitted to Sender. + Mock::VerifyAndClearExpectations(&observer); + EXPECT_FALSE(sender()->NeedsKeyFrame()); + + ExpectFramesReceivedCorrectly(frames, receiver()->TakeCompleteFrames()); +} + +// Tests that the Receiver should get a Sender Report just before the first RTP +// packet, and at regular intervals thereafter. The Sender Report contains the +// lip-sync information necessary for play-out timing. +TEST_F(SenderTest, ProvidesSenderReports) { + std::vector<SenderReportParser::SenderReportWithId> sender_reports; + Sequence packet_sequence; + EXPECT_CALL(*receiver(), OnSenderReport(_)) + .InSequence(packet_sequence) + .WillOnce( + Invoke([&](const SenderReportParser::SenderReportWithId& report) { + sender_reports.push_back(report); + })) + .RetiresOnSaturation(); + EXPECT_CALL(*receiver(), OnRtpPacket(_)).Times(1).InSequence(packet_sequence); + EXPECT_CALL(*receiver(), OnSenderReport(_)) + .Times(3) + .InSequence(packet_sequence) + .WillRepeatedly( + Invoke([&](const SenderReportParser::SenderReportWithId& report) { + sender_reports.push_back(report); + })); + + EncodedFrameWithBuffer frame; + constexpr int kFrameDataSize = 250; + PopulateFrameWithDefaults(FrameId::first(), FakeClock::now(), 0, + kFrameDataSize, &frame); + ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frame)); + SimulateExecution(); // Should send one Sender Report + one RTP packet. + EXPECT_EQ(size_t{1}, sender_reports.size()); + + // Have the Receiver ACK the frame to prevent retransmitting the RTP packet. + receiver()->SetCheckpointFrame(FrameId::first()); + receiver()->TransmitRtcpFeedbackPacket(); + SimulateExecution(); // RTCP transmitted to Sender. + + // Advance through three more reporting intervals. One Sender Report should be + // sent each interval, making a total of 4 reports sent. + constexpr auto kThreeReportIntervals = 3 * kRtcpReportInterval; + SimulateExecution(kThreeReportIntervals); // Three more Sender Reports. + ASSERT_EQ(size_t{4}, sender_reports.size()); + + // The first report should contain the same timestamps as the frame because + // the Clock did not advance. Also, its packet count and octet count fields + // should be zero since the report was sent before the RTP packet. + EXPECT_EQ(frame.reference_time, sender_reports.front().reference_time); + EXPECT_EQ(frame.rtp_timestamp, sender_reports.front().rtp_timestamp); + EXPECT_EQ(uint32_t{0}, sender_reports.front().send_packet_count); + EXPECT_EQ(uint32_t{0}, sender_reports.front().send_octet_count); + + // The last report should contain the timestamps extrapolated into the future + // because the Clock did move forward. Also, the packet count and octet fields + // should now be non-zero because the report was sent after the RTP packet. + EXPECT_EQ(frame.reference_time + kThreeReportIntervals, + sender_reports.back().reference_time); + EXPECT_EQ(frame.rtp_timestamp + + RtpTimeDelta::FromDuration(kThreeReportIntervals, kRtpTimebase), + sender_reports.back().rtp_timestamp); + EXPECT_EQ(uint32_t{1}, sender_reports.back().send_packet_count); + EXPECT_EQ(uint32_t{kFrameDataSize}, sender_reports.back().send_octet_count); +} + +// Tests that the Sender provides Kickstart packets whenever the Receiver may +// not know about new frames. +TEST_F(SenderTest, ProvidesKickstartPacketsIfReceiverDoesNotACK) { + // Have the Receiver move the checkpoint forward only for the first frame, and + // none of the later frames. This will force the Sender to eventually send a + // Kickstart packet. + ON_CALL(*receiver(), OnFrameComplete(_)) + .WillByDefault(Invoke([&](FrameId frame_id) { + if (frame_id == FrameId::first()) { + receiver()->SetCheckpointFrame(FrameId::first()); + receiver()->TransmitRtcpFeedbackPacket(); + } + })); + + // Send three frames, paced to the media. + EncodedFrameWithBuffer frames[3]; + for (int i = 0; i < 3; ++i) { + PopulateFrameWithDefaults(FrameId::first() + i, + FakeClock::now() - kCaptureDelay, i, + 48 /* bytes */, &frames[i]); + ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frames[i])); + SimulateExecution(kFrameDuration); + } + + // Now, do nothing for a while. Because the Receiver isn't moving the + // checkpoint forward, the Sender will have sent all the RTP packets at least + // once, and then will start sending just Kickstart packets. + SimulateExecution(kTargetPlayoutDelay); + + // Keep doing nothing for a while, and confirm the Sender is just sending the + // same Kickstart packet over and over. The Kickstart packet is supposed to be + // the last packet of the latest frame. + std::set<std::pair<FrameId, FramePacketId>> unique_received_packet_ids; + EXPECT_CALL(*receiver(), OnRtpPacket(_)) + .WillRepeatedly( + Invoke([&](const RtpPacketParser::ParseResult& parsed_packet) { + unique_received_packet_ids.emplace(parsed_packet.frame_id, + parsed_packet.packet_id); + })); + SimulateExecution(kTargetPlayoutDelay); + Mock::VerifyAndClearExpectations(receiver()); + EXPECT_EQ(size_t{1}, unique_received_packet_ids.size()); + EXPECT_EQ(frames[2].frame_id, unique_received_packet_ids.begin()->first); + + // Now, simulate the Receiver ACKing all the frames. + receiver()->SetCheckpointFrame(frames[2].frame_id); + receiver()->TransmitRtcpFeedbackPacket(); + SimulateExecution(); // RTCP transmitted to Sender. + + // With all the frames sent, the Sender should not be transmitting anything. + EXPECT_CALL(*receiver(), OnRtpPacket(_)).Times(0); + SimulateExecution(10 * kTargetPlayoutDelay); + + ExpectFramesReceivedCorrectly(frames, receiver()->TakeCompleteFrames()); +} + +// Tests that the Sender only retransmits packets specifically NACK'ed by the +// Receiver. +TEST_F(SenderTest, ResendsIndividuallyNackedPackets) { + // Populate the frame data in each frame with enough bytes to force at least + // three RTP packets per frame. + constexpr int kFrameDataSize = 3 * kMaxRtpPacketSizeForIpv6UdpOnEthernet; + + // Use a 1ms network delay in each direction to make the sequence of events + // clearer in this test. + constexpr milliseconds kOneWayNetworkDelay{1}; + SetSenderToReceiverNetworkDelay(kOneWayNetworkDelay); + SetReceiverToSenderNetworkDelay(kOneWayNetworkDelay); + + // Simulate that three specific packets will be dropped by the network, one + // from each frame (about to be sent). + const std::vector<PacketNack> dropped_packets{ + {FrameId::first(), FramePacketId{2}}, + {FrameId::first() + 1, FramePacketId{1}}, + {FrameId::first() + 2, FramePacketId{0}}, + }; + receiver()->SetIgnoreList(dropped_packets); + + // Send three frames, paced to the media. The Receiver won't completely + // receive any of these frames due to dropped packets. + EXPECT_CALL(*receiver(), OnFrameComplete(_)).Times(0); + EncodedFrameWithBuffer frames[3]; + for (int i = 0; i < 3; ++i) { + PopulateFrameWithDefaults(FrameId::first() + i, + FakeClock::now() - kCaptureDelay, i, + kFrameDataSize, &frames[i]); + ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frames[i])); + SimulateExecution(kFrameDuration); + } + SimulateExecution(kTargetPlayoutDelay); + Mock::VerifyAndClearExpectations(receiver()); + EXPECT_EQ(3, sender()->GetInFlightFrameCount()); + + // The Receiver NACKs the three dropped packets... + receiver()->SetNacksAndAcks(dropped_packets, {}); + receiver()->TransmitRtcpFeedbackPacket(); + + // In the meantime, the network recovers (i.e., no more dropped packets)... + receiver()->SetIgnoreList({}); + + // The NACKs reach the Sender, and it acts on them by retransmitting. + SimulateExecution(kOneWayNetworkDelay); + + // As each retransmitted packet arrives at the Receiver, advance the + // checkpoint forward to notify the Sender of frames that are now completely + // received. Also, confirm that only the three specifically-NACK'ed packets + // were retransmitted. + EXPECT_CALL(*receiver(), OnFrameComplete(_)) + .Times(3) + .WillRepeatedly(InvokeWithoutArgs([&] { + if (receiver()->AutoAdvanceCheckpoint()) { + receiver()->TransmitRtcpFeedbackPacket(); + } + })); + EXPECT_CALL(*receiver(), OnRtpPacket(_)) + .Times(3) + .WillRepeatedly(Invoke([&](const RtpPacketParser::ParseResult& packet) { + EXPECT_FALSE(std::find(dropped_packets.begin(), dropped_packets.end(), + PacketNack{packet.frame_id, packet.packet_id}) == + dropped_packets.end()); + })); + SimulateExecution(kOneWayNetworkDelay); + Mock::VerifyAndClearExpectations(receiver()); + + // The Receiver checkpoint feedback(s) travel back to the Sender, and there + // should no longer be any frames in-flight. + SimulateExecution(kOneWayNetworkDelay); + EXPECT_EQ(0, sender()->GetInFlightFrameCount()); + + // The Sender should not be transmitting anything from now on since all frames + // are known to have been completely received. + EXPECT_CALL(*receiver(), OnRtpPacket(_)).Times(0); + SimulateExecution(10 * kTargetPlayoutDelay); + + ExpectFramesReceivedCorrectly(frames, receiver()->TakeCompleteFrames()); +} + +// Tests that the Sender retransmits an entire frame if the Receiver requests it +// (i.e., a full frame NACK), but does not retransmit any packets for frames +// (before or after) that have been acknowledged. +TEST_F(SenderTest, ResendsMissingFrames) { + // Populate the frame data in each frame with enough bytes to force at least + // three RTP packets per frame. + constexpr int kFrameDataSize = 3 * kMaxRtpPacketSizeForIpv6UdpOnEthernet; + + // Use a 1ms network delay in each direction to make the sequence of events + // clearer in this test. + constexpr milliseconds kOneWayNetworkDelay{1}; + SetSenderToReceiverNetworkDelay(kOneWayNetworkDelay); + SetReceiverToSenderNetworkDelay(kOneWayNetworkDelay); + + // Simulate that all of the packets for the second frame will be dropped by + // the network, but only the packets for that frame. + const std::vector<PacketNack> dropped_packets{ + {FrameId::first() + 1, kAllPacketsLost}, + }; + receiver()->SetIgnoreList(dropped_packets); + + NiceMock<MockObserver> observer; + sender()->SetObserver(&observer); + + // The expectations below track the story and execute simulated Receiver + // responses. The Sender will have three frames enqueued by its client, and + // then... + // + // The first frame is received and the Receiver ACKs it by moving the + // checkpoint forward. + Sequence completion_sequence; + EXPECT_CALL(*receiver(), OnFrameComplete(FrameId::first())) + .InSequence(completion_sequence) + .WillOnce(InvokeWithoutArgs([&] { + receiver()->SetCheckpointFrame(FrameId::first()); + receiver()->TransmitRtcpFeedbackPacket(); + })); + // Since all of the packets for the second frame are being dropped, the third + // frame will finish next. The Receiver responds by NACKing the second frame + // and ACKing the third frame. The checkpoint does not move forward because + // the second frame has not been received yet. + // + // NETWORK CHANGE: After the third frame is received, stop dropping packets. + EXPECT_CALL(*receiver(), OnFrameComplete(FrameId::first() + 2)) + .InSequence(completion_sequence) + .WillOnce(InvokeWithoutArgs([&] { + receiver()->SetNacksAndAcks(dropped_packets, + std::vector<FrameId>{FrameId::first() + 2}); + receiver()->TransmitRtcpFeedbackPacket(); + receiver()->SetIgnoreList({}); + })); + // Finally, the Sender should respond to the whole-frame NACK by re-sending + // all of the packets for the second frame, and so the Receiver should + // completely receive the frame. + EXPECT_CALL(*receiver(), OnFrameComplete(FrameId::first() + 1)) + .InSequence(completion_sequence) + .WillOnce(InvokeWithoutArgs([&] { + receiver()->SetCheckpointFrame(FrameId::first() + 2); + receiver()->TransmitRtcpFeedbackPacket(); + })); + + // From the Sender's perspective, the Receiver will ACK the first frame, then + // the third frame, then the second frame. + Sequence cancel_sequence; + EXPECT_CALL(observer, OnFrameCanceled(FrameId::first())) + .Times(1) + .InSequence(cancel_sequence); + EXPECT_CALL(observer, OnFrameCanceled(FrameId::first() + 2)) + .Times(1) + .InSequence(cancel_sequence); + EXPECT_CALL(observer, OnFrameCanceled(FrameId::first() + 1)) + .Times(1) + .InSequence(cancel_sequence); + + // With all the expectations/sequences in-place, let 'er rip! + EncodedFrameWithBuffer frames[3]; + for (int i = 0; i < 3; ++i) { + PopulateFrameWithDefaults(FrameId::first() + i, + FakeClock::now() - kCaptureDelay, i, + kFrameDataSize, &frames[i]); + ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frames[i])); + SimulateExecution(kFrameDuration); + } + SimulateExecution(kTargetPlayoutDelay); + Mock::VerifyAndClearExpectations(receiver()); + EXPECT_EQ(0, sender()->GetInFlightFrameCount()); + + // The Sender should not be transmitting anything from now on since all frames + // are known to have been completely received. + EXPECT_CALL(*receiver(), OnRtpPacket(_)).Times(0); + SimulateExecution(10 * kTargetPlayoutDelay); + + ExpectFramesReceivedCorrectly(frames, receiver()->TakeCompleteFrames()); +} + +} // namespace +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/session_config.cc b/chromium/third_party/openscreen/src/cast/streaming/session_config.cc index a6b09f003a5..651170294c1 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/session_config.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/session_config.cc @@ -4,21 +4,23 @@ #include "cast/streaming/session_config.h" +namespace openscreen { namespace cast { -namespace streaming { SessionConfig::SessionConfig(Ssrc sender_ssrc, Ssrc receiver_ssrc, int rtp_timebase, int channels, + std::chrono::milliseconds target_playout_delay, std::array<uint8_t, 16> aes_secret_key, std::array<uint8_t, 16> aes_iv_mask) : sender_ssrc(sender_ssrc), receiver_ssrc(receiver_ssrc), rtp_timebase(rtp_timebase), channels(channels), + target_playout_delay(target_playout_delay), aes_secret_key(aes_secret_key), aes_iv_mask(aes_iv_mask) {} -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/session_config.h b/chromium/third_party/openscreen/src/cast/streaming/session_config.h index d61efc11b35..4d611b65e2e 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/session_config.h +++ b/chromium/third_party/openscreen/src/cast/streaming/session_config.h @@ -6,12 +6,13 @@ #define CAST_STREAMING_SESSION_CONFIG_H_ #include <array> +#include <chrono> // NOLINT #include <cstdint> #include "cast/streaming/ssrc.h" +namespace openscreen { namespace cast { -namespace streaming { // Common streaming configuration, established from the OFFER/ANSWER exchange, // that the Sender and Receiver are both assuming. @@ -21,6 +22,7 @@ struct SessionConfig final { Ssrc receiver_ssrc, int rtp_timebase, int channels, + std::chrono::milliseconds target_playout_delay, std::array<uint8_t, 16> aes_secret_key, std::array<uint8_t, 16> aes_iv_mask); SessionConfig(const SessionConfig&) = default; @@ -42,12 +44,15 @@ struct SessionConfig final { // Number of channels. Must be 1 for video, for audio typically 2. int channels = 1; + // Initial target playout delay. + std::chrono::milliseconds target_playout_delay; + // The AES-128 crypto key and initialization vector. std::array<uint8_t, 16> aes_secret_key{}; std::array<uint8_t, 16> aes_iv_mask{}; }; -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_SESSION_CONFIG_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/ssrc.cc b/chromium/third_party/openscreen/src/cast/streaming/ssrc.cc index d3f2446906e..d71b806932e 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/ssrc.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/ssrc.cc @@ -8,8 +8,8 @@ #include "platform/api/time.h" +namespace openscreen { namespace cast { -namespace streaming { namespace { @@ -28,7 +28,7 @@ Ssrc GenerateSsrc(bool higher_priority) { // it is light-weight and does not need to produce unguessable (nor // crypto-secure) values. static std::minstd_rand generator(static_cast<std::minstd_rand::result_type>( - openscreen::platform::Clock::now().time_since_epoch().count())); + Clock::now().time_since_epoch().count())); std::uniform_int_distribution<int> distribution( higher_priority ? kHigherPriorityMin : kNormalPriorityMin, @@ -40,5 +40,5 @@ int ComparePriority(Ssrc ssrc_a, Ssrc ssrc_b) { return static_cast<int>(ssrc_a) - static_cast<int>(ssrc_b); } -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/ssrc.h b/chromium/third_party/openscreen/src/cast/streaming/ssrc.h index 0cd235665d7..b84d7ac00d7 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/ssrc.h +++ b/chromium/third_party/openscreen/src/cast/streaming/ssrc.h @@ -7,8 +7,8 @@ #include <stdint.h> +namespace openscreen { namespace cast { -namespace streaming { // A Synchronization Source is a 32-bit opaque identifier used in RTP packets // for identifying the source (or recipient) of a logical sequence of encoded @@ -33,7 +33,7 @@ Ssrc GenerateSsrc(bool higher_priority); // ret > 0: Stream |ssrc_b| has higher priority. int ComparePriority(Ssrc ssrc_a, Ssrc ssrc_b); -} // namespace streaming } // namespace cast +} // namespace openscreen #endif // CAST_STREAMING_SSRC_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/ssrc_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/ssrc_unittest.cc index 29741409c64..aa9e50edfba 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/ssrc_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/ssrc_unittest.cc @@ -9,10 +9,8 @@ #include "gtest/gtest.h" #include "util/std_util.h" -using openscreen::SortAndDedupeElements; - +namespace openscreen { namespace cast { -namespace streaming { namespace { TEST(SsrcTest, GeneratesUniqueAndPrioritizedSsrcs) { @@ -53,5 +51,5 @@ TEST(SsrcTest, GeneratesUniqueAndPrioritizedSsrcs) { } } // namespace -} // namespace streaming } // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/test/BUILD.gn b/chromium/third_party/openscreen/src/cast/test/BUILD.gn new file mode 100644 index 00000000000..83950d2fd4a --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/test/BUILD.gn @@ -0,0 +1,61 @@ +# Copyright 2019 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +import("//build_overrides/build.gni") + +source_set("unittests") { + testonly = true + sources = [ + "device_auth_test.cc", + ] + + deps = [ + "../../testing/util", + "../../third_party/googletest:gmock", + "../../third_party/googletest:gtest", + "../common:channel", + "../common:test_helpers", + "../common/channel/proto:channel_proto", + "../receiver:channel", + "../receiver:test_helpers", + "../sender:channel", + ] +} + +if (is_posix && !build_with_chromium) { + source_set("e2e_tests") { + testonly = true + sources = [ + "cast_socket_e2e_test.cc", + ] + + deps = [ + "../../platform", + "../../third_party/abseil", + "../../third_party/boringssl", + "../../third_party/googletest:gtest", + "../../util", + "../common:certificate", + "../common:channel", + "../common:test_helpers", + "../receiver:channel", + "../receiver:test_helpers", + "../sender:channel", + ] + } + + executable("make_crl_tests") { + testonly = true + sources = [ + "make_crl_tests.cc", + ] + + deps = [ + "../../third_party/boringssl", + "../../util", + "../common:test_helpers", + "../common/certificate/proto:certificate_proto", + ] + } +} diff --git a/chromium/third_party/openscreen/src/discovery/BUILD.gn b/chromium/third_party/openscreen/src/discovery/BUILD.gn index a7c1f3ef47f..27b5ef682e2 100644 --- a/chromium/third_party/openscreen/src/discovery/BUILD.gn +++ b/chromium/third_party/openscreen/src/discovery/BUILD.gn @@ -3,9 +3,31 @@ # found in the LICENSE file. import("//build_overrides/build.gni") +import("../testing/libfuzzer/fuzzer_test.gni") + +source_set("common") { + sources = [ + "common/config.h", + "common/reporting_client.h", + ] + + deps = [ + "../util", + ] + + public_deps = [ + "../platform", + "../third_party/abseil", + ] +} source_set("mdns") { sources = [ + "mdns/mdns_domain_confirmed_provider.h", + "mdns/mdns_probe.cc", + "mdns/mdns_probe.h", + "mdns/mdns_probe_manager.cc", + "mdns/mdns_probe_manager.h", "mdns/mdns_publisher.cc", "mdns/mdns_publisher.h", "mdns/mdns_querier.cc", @@ -28,6 +50,7 @@ source_set("mdns") { "mdns/mdns_writer.cc", "mdns/mdns_writer.h", "mdns/public/mdns_constants.h", + "mdns/public/mdns_service.cc", "mdns/public/mdns_service.h", ] @@ -36,14 +59,13 @@ source_set("mdns") { ] public_deps = [ + ":common", "../platform", "../third_party/abseil", ] } source_set("dnssd") { - defines = [] - sources = [ "dnssd/impl/conversion_layer.cc", "dnssd/impl/conversion_layer.h", @@ -69,14 +91,29 @@ source_set("dnssd") { ] public_deps = [ + ":common", ":mdns", ] } +source_set("public") { + sources = [ + "public/dns_sd_service_factory.h", + "public/dns_sd_service_publisher.h", + "public/dns_sd_service_watcher.h", + ] + + public_deps = [ + ":common", + ":dnssd", + ] +} + source_set("testing") { testonly = true sources = [ + "common/testing/mock_reporting_client.h", "dnssd/testing/fake_dns_record_factory.cc", "mdns/testing/mdns_test_util.cc", "mdns/testing/mdns_test_util.h", @@ -101,14 +138,19 @@ source_set("unittests") { "dnssd/impl/service_key_unittest.cc", "dnssd/public/dns_sd_instance_record_unittest.cc", "dnssd/public/dns_sd_txt_record_unittest.cc", + "mdns/mdns_probe_manager_unittest.cc", + "mdns/mdns_probe_unittest.cc", + "mdns/mdns_publisher_unittest.cc", "mdns/mdns_querier_unittest.cc", "mdns/mdns_random_unittest.cc", "mdns/mdns_reader_unittest.cc", "mdns/mdns_receiver_unittest.cc", "mdns/mdns_records_unittest.cc", + "mdns/mdns_responder_unittest.cc", "mdns/mdns_sender_unittest.cc", "mdns/mdns_trackers_unittest.cc", "mdns/mdns_writer_unittest.cc", + "public/dns_sd_service_watcher_unittest.cc", ] deps = [ @@ -120,3 +162,18 @@ source_set("unittests") { "../util", ] } + +openscreen_fuzzer_test("mdns_fuzzer") { + sources = [ + "mdns/mdns_reader_fuzztest.cc", + ] + + deps = [ + ":mdns", + ] + + seed_corpus = "mdns/fuzzer_seeds" + + # Note: 512 is the maximum size for a serialized mDNS packet. + libfuzzer_options = [ "max_len=512" ] +} diff --git a/chromium/third_party/openscreen/src/discovery/DEPS b/chromium/third_party/openscreen/src/discovery/DEPS index b9b5cb06605..de7afcec06c 100644 --- a/chromium/third_party/openscreen/src/discovery/DEPS +++ b/chromium/third_party/openscreen/src/discovery/DEPS @@ -1,9 +1,9 @@ # -*- Mode: Python; -*- include_rules = [ - '+cast', + # Intra-discovery dependencies must be explicit. + '-discovery', - # All libcast code can use platform and cast/third_party. - '+cast/third_party', - '+platform' + # All discovery code can use discovery/common + '+discovery/common', ] diff --git a/chromium/third_party/openscreen/src/discovery/common/config.h b/chromium/third_party/openscreen/src/discovery/common/config.h new file mode 100644 index 00000000000..4b0e13e76d8 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/common/config.h @@ -0,0 +1,56 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef DISCOVERY_COMMON_CONFIG_H_ +#define DISCOVERY_COMMON_CONFIG_H_ + +#include "platform/base/interface_info.h" + +namespace openscreen { +namespace discovery { + +// This struct provides parameters needed to initialize the discovery pipeline. +struct Config { + /***************************************** + * Networking Settings + *****************************************/ + + // Network Interface on which mDNS should be run. + InterfaceInfo interface; + + /***************************************** + * Publisher Settings + *****************************************/ + + // Determines whether publishing of services is enabled. + bool enable_publication = true; + + // Number of times new mDNS records should be announced, using an exponential + // back off. See RFC 6762 section 8.3 for further details. Per RFC, this value + // is expected to be in the range of 2 to 8. + int new_record_announcement_count = 8; + + /***************************************** + * Querier Settings + *****************************************/ + + // Determines whether querying is enabled. + bool enable_querying = true; + + // Number of times new mDNS records should be announced, using an exponential + // back off. -1 signifies that there should be no maximum. + // NOTE: This is expected to be -1 in all production scenarios and only be a + // different value during testing. + int new_query_announcement_count = -1; + + // Limit on the size to which the mDNS Querier Cache may grow. This is used to + // prevent a malicious or misbehaving mDNS client from causing the memory + // used by mDNS to grow in an unbounded fashion. + int querier_max_records_cached = 1024; +}; + +} // namespace discovery +} // namespace openscreen + +#endif // DISCOVERY_COMMON_CONFIG_H_ diff --git a/chromium/third_party/openscreen/src/discovery/common/reporting_client.h b/chromium/third_party/openscreen/src/discovery/common/reporting_client.h new file mode 100644 index 00000000000..68011af5ba1 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/common/reporting_client.h @@ -0,0 +1,37 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef DISCOVERY_COMMON_REPORTING_CLIENT_H_ +#define DISCOVERY_COMMON_REPORTING_CLIENT_H_ + +#include "platform/base/error.h" + +namespace openscreen { +namespace discovery { + +// This class is implemented by the embedder who wishes to use discovery. The +// discovery implementation will use this API to report back errors and metrics. +// NOTE: All methods in the reporting client will be called from the task runner +// thread. +// TODO(rwkeane): Report state changes back to the caller. +class ReportingClient { + public: + virtual ~ReportingClient() = default; + + // This method is called when an error is detected by the underlying + // infrastructure from which recovery cannot be initiated. For example, an + // error binding a multicast socket. + virtual void OnFatalError(Error error) = 0; + + // This method is called when an error is detected by the underlying + // infrastructure which does not prevent further functionality of the runtime. + // For example, a conversion failure between DnsSdInstanceRecord and the + // externally supplied class. + virtual void OnRecoverableError(Error error) = 0; +}; + +} // namespace discovery +} // namespace openscreen + +#endif // DISCOVERY_COMMON_REPORTING_CLIENT_H_ diff --git a/chromium/third_party/openscreen/src/discovery/common/testing/mock_reporting_client.h b/chromium/third_party/openscreen/src/discovery/common/testing/mock_reporting_client.h new file mode 100644 index 00000000000..4e3063c5d3a --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/common/testing/mock_reporting_client.h @@ -0,0 +1,22 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef DISCOVERY_COMMON_TESTING_MOCK_REPORTING_CLIENT_H_ +#define DISCOVERY_COMMON_TESTING_MOCK_REPORTING_CLIENT_H_ + +#include "discovery/common/reporting_client.h" +#include "gmock/gmock.h" + +namespace openscreen { +namespace discovery { + +class MockReportingClient : public ReportingClient { + MOCK_METHOD1(OnFatalError, void(Error error)); + MOCK_METHOD1(OnRecoverableError, void(Error error)); +}; + +} // namespace discovery +} // namespace openscreen + +#endif // DISCOVERY_COMMON_TESTING_MOCK_REPORTING_CLIENT_H_ diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/DEPS b/chromium/third_party/openscreen/src/discovery/dnssd/impl/DEPS new file mode 100644 index 00000000000..0c34a54b3d9 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/DEPS @@ -0,0 +1,6 @@ +# -*- Mode: Python; -*- + +include_rules = [ + '+discovery/dnssd/public', + '+discovery/mdns', +] diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.cc index ab64062197c..e15674e7b76 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.cc @@ -6,11 +6,14 @@ #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "discovery/dnssd/impl/constants.h" #include "discovery/dnssd/impl/instance_key.h" #include "discovery/dnssd/impl/service_key.h" #include "discovery/dnssd/public/dns_sd_instance_record.h" #include "discovery/mdns/mdns_records.h" +#include "discovery/mdns/public/mdns_constants.h" namespace openscreen { namespace discovery { @@ -42,62 +45,54 @@ DomainName GetInstanceDomainName(const std::string& instance, return DomainName{std::move(labels)}; } +inline DomainName GetInstanceDomainName(const InstanceKey& key) { + return GetInstanceDomainName(key.instance_id(), key.service_id(), + key.domain_id()); +} + MdnsRecord CreatePtrRecord(const DnsSdInstanceRecord& record, const DomainName& domain) { PtrRecordRdata data(domain); - - // TTL specified by RFC 6762 section 10. - constexpr std::chrono::seconds ttl(120); auto outer_domain = GetPtrDomainName(record.service_id(), record.domain_id()); return MdnsRecord(std::move(outer_domain), DnsType::kPTR, DnsClass::kIN, - RecordType::kShared, ttl, std::move(data)); + RecordType::kShared, kPtrRecordTtl, std::move(data)); } MdnsRecord CreateSrvRecord(const DnsSdInstanceRecord& record, const DomainName& domain) { uint16_t port = record.port(); - - // TTL specified by RFC 6762 section 10. - constexpr std::chrono::seconds ttl(120); SrvRecordRdata data(0, 0, port, domain); return MdnsRecord(domain, DnsType::kSRV, DnsClass::kIN, RecordType::kUnique, - ttl, std::move(data)); + kSrvRecordTtl, std::move(data)); } absl::optional<MdnsRecord> CreateARecord(const DnsSdInstanceRecord& record, const DomainName& domain) { - if (!record.address_v4().has_value()) { + if (!record.address_v4()) { return absl::nullopt; } - // TTL specified by RFC 6762 section 10. - constexpr std::chrono::seconds ttl(120); - ARecordRdata data(record.address_v4().value().address); + ARecordRdata data(record.address_v4().address); return MdnsRecord(domain, DnsType::kA, DnsClass::kIN, RecordType::kUnique, - ttl, std::move(data)); + kARecordTtl, std::move(data)); } absl::optional<MdnsRecord> CreateAAAARecord(const DnsSdInstanceRecord& record, const DomainName& domain) { - if (!record.address_v6().has_value()) { + if (!record.address_v6()) { return absl::nullopt; } - // TTL specified by RFC 6762 section 10. - constexpr std::chrono::seconds ttl(120); - AAAARecordRdata data(record.address_v6().value().address); + AAAARecordRdata data(record.address_v6().address); return MdnsRecord(domain, DnsType::kAAAA, DnsClass::kIN, RecordType::kUnique, - ttl, std::move(data)); + kAAAARecordTtl, std::move(data)); } MdnsRecord CreateTxtRecord(const DnsSdInstanceRecord& record, const DomainName& domain) { TxtRecordRdata data(record.txt().GetData()); - - // TTL specified by RFC 6762 section 10. - constexpr std::chrono::seconds ttl(75 * 60); return MdnsRecord(domain, DnsType::kTXT, DnsClass::kIN, RecordType::kUnique, - ttl, std::move(data)); + kTXTRecordTtl, std::move(data)); } } // namespace @@ -121,8 +116,9 @@ ErrorOr<DnsSdTxtRecord> CreateFromDnsTxt(const TxtRecordRdata& txt_data) { std::string key = text.substr(0, index_of_eq); std::string value = text.substr(index_of_eq + 1); absl::Span<const uint8_t> data( - reinterpret_cast<const uint8_t*>(value.c_str()), value.size()); - const auto set_result = txt.SetValue(key, data); + reinterpret_cast<const uint8_t*>(value.data()), value.size()); + const auto set_result = + txt.SetValue(key, std::vector<uint8_t>(data.begin(), data.end())); if (!set_result.ok()) { return set_result; } @@ -137,10 +133,19 @@ ErrorOr<DnsSdTxtRecord> CreateFromDnsTxt(const TxtRecordRdata& txt_data) { return txt; } +DomainName GetDomainName(const InstanceKey& key) { + return GetInstanceDomainName(key.instance_id(), key.service_id(), + key.domain_id()); +} + +DomainName GetDomainName(const MdnsRecord& record) { + return IsPtrRecord(record) + ? absl::get<PtrRecordRdata>(record.rdata()).ptr_domain() + : record.name(); +} + DnsQueryInfo GetInstanceQueryInfo(const InstanceKey& key) { - auto domain = GetInstanceDomainName(key.instance_id(), key.service_id(), - key.domain_id()); - return {std::move(domain), DnsType::kANY, DnsClass::kANY}; + return {GetDomainName(key), DnsType::kANY, DnsClass::kANY}; } DnsQueryInfo GetPtrQueryInfo(const ServiceKey& key) { @@ -149,7 +154,12 @@ DnsQueryInfo GetPtrQueryInfo(const ServiceKey& key) { } bool HasValidDnsRecordAddress(const MdnsRecord& record) { - return InstanceKey::CreateFromRecord(record).is_value(); + return HasValidDnsRecordAddress(GetDomainName(record)); +} + +bool HasValidDnsRecordAddress(const DomainName& domain) { + return InstanceKey::TryCreate(domain).is_value() && + IsInstanceValid(domain.labels()[0]); } bool IsPtrRecord(const MdnsRecord& record) { @@ -157,8 +167,7 @@ bool IsPtrRecord(const MdnsRecord& record) { } std::vector<MdnsRecord> GetDnsRecords(const DnsSdInstanceRecord& record) { - auto domain = GetInstanceDomainName(record.instance_id(), record.service_id(), - record.domain_id()); + auto domain = GetInstanceDomainName(InstanceKey(record)); std::vector<MdnsRecord> records{CreatePtrRecord(record, domain), CreateSrvRecord(record, domain), diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.h b/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.h index d1c34c6d3d0..08eaed116f4 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.h @@ -30,14 +30,22 @@ ErrorOr<DnsSdTxtRecord> CreateFromDnsTxt(const TxtRecordRdata& txt); bool IsPtrRecord(const MdnsRecord& record); -// Checks that the instance, service, and domain ids in this MdnsRecord are -// valid. +// Checks that the instance, service, and domain ids in this instance are valid. bool HasValidDnsRecordAddress(const MdnsRecord& record); +bool HasValidDnsRecordAddress(const DomainName& domain); //*** Conversions to DNS entities from DNS-SD Entities *** +// Returns the Domain Name associated with this InstanceKey. +DomainName GetDomainName(const InstanceKey& key); + +// Returns the domain name associated with this MdnsRecord. In the case of a PTR +// record, this is the target domain, and it is the named domain in all other +// cases. +DomainName GetDomainName(const MdnsRecord& record); + // Returns the query required to get all instance information about the service -// instances described by the provided ServiceKey. +// instances described by the provided InstanceKey. DnsQueryInfo GetInstanceQueryInfo(const InstanceKey& key); // Returns the query required to get all service information that matches the diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer_unittest.cc index adc0abe70a2..b1f0fb245c4 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer_unittest.cc @@ -41,9 +41,11 @@ TEST(DnsSdConversionLayerTest, TestCreateTxtValidKeyValue) { // EXPECT_STREQ is causing memory leaks std::string expected = "value"; - ASSERT_EQ(record.value().GetValue("name").value().size(), expected.size()); + ASSERT_TRUE(record.value().GetValue("name").is_value()); + const std::vector<uint8_t>& value = record.value().GetValue("name").value(); + ASSERT_EQ(value.size(), expected.size()); for (size_t i = 0; i < expected.size(); i++) { - EXPECT_EQ(expected[i], record.value().GetValue("name").value()[i]); + EXPECT_EQ(expected[i], value[i]); } } @@ -96,8 +98,12 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsPtr) { DnsSdTxtRecord txt; DnsSdInstanceRecord instance_record( FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName, - FakeDnsRecordFactory::kDomainName, FakeDnsRecordFactory::kV4Endpoint, - FakeDnsRecordFactory::kV6Endpoint, txt); + FakeDnsRecordFactory::kDomainName, + IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets), + FakeDnsRecordFactory::kPortNum}, + IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets), + FakeDnsRecordFactory::kPortNum}, + txt); std::vector<MdnsRecord> records = GetDnsRecords(instance_record); auto it = std::find_if(records.begin(), records.end(), [](const MdnsRecord& record) { @@ -129,8 +135,12 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsSrv) { DnsSdTxtRecord txt; DnsSdInstanceRecord instance_record( FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName, - FakeDnsRecordFactory::kDomainName, FakeDnsRecordFactory::kV4Endpoint, - FakeDnsRecordFactory::kV6Endpoint, txt); + FakeDnsRecordFactory::kDomainName, + IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets), + FakeDnsRecordFactory::kPortNum}, + IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets), + FakeDnsRecordFactory::kPortNum}, + txt); std::vector<MdnsRecord> records = GetDnsRecords(instance_record); auto it = std::find_if(records.begin(), records.end(), [](const MdnsRecord& record) { @@ -151,15 +161,19 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsSrv) { const auto& rdata = absl::get<SrvRecordRdata>(it->rdata()); EXPECT_EQ(rdata.priority(), 0); EXPECT_EQ(rdata.weight(), 0); - EXPECT_EQ(rdata.port(), FakeDnsRecordFactory::kV4Endpoint.port); + EXPECT_EQ(rdata.port(), FakeDnsRecordFactory::kPortNum); } TEST(DnsSdConversionLayerTest, GetDnsRecordsAPresent) { DnsSdTxtRecord txt; DnsSdInstanceRecord instance_record( FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName, - FakeDnsRecordFactory::kDomainName, FakeDnsRecordFactory::kV4Endpoint, - FakeDnsRecordFactory::kV6Endpoint, txt); + FakeDnsRecordFactory::kDomainName, + IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets), + FakeDnsRecordFactory::kPortNum}, + IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets), + FakeDnsRecordFactory::kPortNum}, + txt); std::vector<MdnsRecord> records = GetDnsRecords(instance_record); auto it = std::find_if(records.begin(), records.end(), [](const MdnsRecord& record) { @@ -178,15 +192,18 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsAPresent) { EXPECT_EQ(it->name().labels()[3], FakeDnsRecordFactory::kDomainName); const auto& rdata = absl::get<ARecordRdata>(it->rdata()); - EXPECT_EQ(rdata.ipv4_address(), FakeDnsRecordFactory::kV4Address); + EXPECT_EQ(rdata.ipv4_address(), + IPAddress(FakeDnsRecordFactory::kV4AddressOctets)); } TEST(DnsSdConversionLayerTest, GetDnsRecordsANotPresent) { DnsSdTxtRecord txt; - DnsSdInstanceRecord instance_record(FakeDnsRecordFactory::kInstanceName, - FakeDnsRecordFactory::kServiceName, - FakeDnsRecordFactory::kDomainName, - FakeDnsRecordFactory::kV6Endpoint, txt); + DnsSdInstanceRecord instance_record( + FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName, + FakeDnsRecordFactory::kDomainName, + IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets), + FakeDnsRecordFactory::kPortNum}, + txt); std::vector<MdnsRecord> records = GetDnsRecords(instance_record); auto it = std::find_if(records.begin(), records.end(), [](const MdnsRecord& record) { @@ -199,8 +216,12 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsAAAAPresent) { DnsSdTxtRecord txt; DnsSdInstanceRecord instance_record( FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName, - FakeDnsRecordFactory::kDomainName, FakeDnsRecordFactory::kV4Endpoint, - FakeDnsRecordFactory::kV6Endpoint, txt); + FakeDnsRecordFactory::kDomainName, + IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets), + FakeDnsRecordFactory::kPortNum}, + IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets), + FakeDnsRecordFactory::kPortNum}, + txt); std::vector<MdnsRecord> records = GetDnsRecords(instance_record); auto it = std::find_if(records.begin(), records.end(), [](const MdnsRecord& record) { @@ -219,15 +240,18 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsAAAAPresent) { EXPECT_EQ(it->name().labels()[3], FakeDnsRecordFactory::kDomainName); const auto& rdata = absl::get<AAAARecordRdata>(it->rdata()); - EXPECT_EQ(rdata.ipv6_address(), FakeDnsRecordFactory::kV6Address); + EXPECT_EQ(rdata.ipv6_address(), + IPAddress(FakeDnsRecordFactory::kV6AddressHextets)); } TEST(DnsSdConversionLayerTest, GetDnsRecordsAAAANotPresent) { DnsSdTxtRecord txt; - DnsSdInstanceRecord instance_record(FakeDnsRecordFactory::kInstanceName, - FakeDnsRecordFactory::kServiceName, - FakeDnsRecordFactory::kDomainName, - FakeDnsRecordFactory::kV4Endpoint, txt); + DnsSdInstanceRecord instance_record( + FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName, + FakeDnsRecordFactory::kDomainName, + IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets), + FakeDnsRecordFactory::kPortNum}, + txt); std::vector<MdnsRecord> records = GetDnsRecords(instance_record); auto it = std::find_if(records.begin(), records.end(), [](const MdnsRecord& record) { @@ -238,14 +262,17 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsAAAANotPresent) { TEST(DnsSdConversionLayerTest, GetDnsRecordsTxt) { DnsSdTxtRecord txt; - auto value = - absl::Span<const uint8_t>(reinterpret_cast<const uint8_t*>("value"), 5); + std::vector<uint8_t> value{'v', 'a', 'l', 'u', 'e'}; txt.SetValue("name", value); txt.SetFlag("boolean", true); DnsSdInstanceRecord instance_record( FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName, - FakeDnsRecordFactory::kDomainName, FakeDnsRecordFactory::kV4Endpoint, - FakeDnsRecordFactory::kV6Endpoint, txt); + FakeDnsRecordFactory::kDomainName, + IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets), + FakeDnsRecordFactory::kPortNum}, + IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets), + FakeDnsRecordFactory::kPortNum}, + txt); std::vector<MdnsRecord> records = GetDnsRecords(instance_record); auto it = std::find_if(records.begin(), records.end(), [](const MdnsRecord& record) { diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_unittest.cc index 400fe9e7275..fb1543b6b61 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_unittest.cc @@ -4,7 +4,7 @@ #include "discovery/dnssd/impl/dns_data.h" -#include <chrono> +#include <chrono> // NOLINT #include "discovery/mdns/testing/mdns_test_util.h" #include "gmock/gmock.h" @@ -66,24 +66,26 @@ class DnsDataTesting : public DnsData { } }; -static const IPAddress v4_address = - IPAddress(std::array<uint8_t, 4>{{192, 168, 0, 0}}); -static const IPAddress v6_address = IPAddress(std::array<uint8_t, 16>{ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}}); -static const std::string instance_name = "instance"; -static const std::string service_name = "_srv-name._udp"; -static const std::string domain_name = "local"; -static const InstanceKey key = {instance_name, service_name, domain_name}; -static constexpr uint16_t port_num = uint16_t{80}; +namespace { + +const uint8_t kV4AddressOctets[4] = {192, 168, 0, 0}; +const uint16_t kV6AddressHextets[8] = {0x0102, 0x0304, 0x0506, 0x0708, + 0x090a, 0x0b0c, 0x0d0e, 0x0f10}; +const char kInstanceName[] = "instance"; +const char kServiceName[] = "_srv-name._udp"; +const char kDomainName[] = "local"; +constexpr uint16_t kServicePort = uint16_t{80}; + +} // namespace DnsDataTesting CreateFullyPopulatedData() { - InstanceKey instance{instance_name, service_name, domain_name}; + InstanceKey instance{kInstanceName, kServiceName, kDomainName}; DnsDataTesting data(instance); - DomainName target{instance_name, "_srv-name", "_udp", domain_name}; - SrvRecordRdata srv(0, 0, port_num, target); + DomainName target{kInstanceName, "_srv-name", "_udp", kDomainName}; + SrvRecordRdata srv(0, 0, kServicePort, target); TxtRecordRdata txt = MakeTxtRecord({"name=value", "boolValue"}); - ARecordRdata a(v4_address); - AAAARecordRdata aaaa(v6_address); + ARecordRdata a{IPAddress(kV4AddressOctets)}; + AAAARecordRdata aaaa{IPAddress(kV6AddressHextets)}; data.set_srv(srv); data.set_txt(txt); @@ -93,8 +95,8 @@ DnsDataTesting CreateFullyPopulatedData() { return data; } -MdnsRecord CreateFullyPopulatedRecord(uint16_t port = port_num) { - DomainName target{instance_name, "_srv-name", "_udp", domain_name}; +MdnsRecord CreateFullyPopulatedRecord(uint16_t port = kServicePort) { + DomainName target{kInstanceName, "_srv-name", "_udp", kDomainName}; auto type = DnsType::kSRV; auto clazz = DnsClass::kIN; auto record_type = RecordType::kShared; @@ -110,15 +112,15 @@ TEST(DnsSdDnsDataTests, TestConvertDnsDataCorrectly) { ASSERT_TRUE(result.is_value()); DnsSdInstanceRecord record = result.value(); - ASSERT_TRUE(record.address_v4().has_value()); - ASSERT_TRUE(record.address_v6().has_value()); - EXPECT_EQ(record.instance_id(), instance_name); - EXPECT_EQ(record.service_id(), service_name); - EXPECT_EQ(record.domain_id(), domain_name); - EXPECT_EQ(record.address_v4().value().port, port_num); - EXPECT_EQ(record.address_v4().value().address, v4_address); - EXPECT_EQ(record.address_v6().value().port, port_num); - EXPECT_EQ(record.address_v6().value().address, v6_address); + ASSERT_TRUE(record.address_v4()); + ASSERT_TRUE(record.address_v6()); + EXPECT_EQ(record.instance_id(), kInstanceName); + EXPECT_EQ(record.service_id(), kServiceName); + EXPECT_EQ(record.domain_id(), kDomainName); + EXPECT_EQ(record.address_v4().port, kServicePort); + EXPECT_EQ(record.address_v4().address, IPAddress(kV4AddressOctets)); + EXPECT_EQ(record.address_v6().port, kServicePort); + EXPECT_EQ(record.address_v6().address, IPAddress(kV6AddressHextets)); EXPECT_FALSE(record.txt().IsEmpty()); } @@ -156,10 +158,10 @@ TEST(DnsSdDnsDataTests, TestConvertDnsDataOneAddress) { ASSERT_TRUE(result.is_value()); DnsSdInstanceRecord record = result.value(); - EXPECT_FALSE(record.address_v6().has_value()); - ASSERT_TRUE(record.address_v4().has_value()); - EXPECT_EQ(record.address_v4().value().port, port_num); - EXPECT_EQ(record.address_v4().value().address, v4_address); + EXPECT_FALSE(record.address_v6()); + ASSERT_TRUE(record.address_v4()); + EXPECT_EQ(record.address_v4().port, kServicePort); + EXPECT_EQ(record.address_v4().address, IPAddress(kV4AddressOctets)); // Address v6. data = CreateFullyPopulatedData(); @@ -168,10 +170,10 @@ TEST(DnsSdDnsDataTests, TestConvertDnsDataOneAddress) { ASSERT_TRUE(result.is_value()); record = result.value(); - EXPECT_FALSE(record.address_v4().has_value()); - ASSERT_TRUE(record.address_v6().has_value()); - EXPECT_EQ(record.address_v6().value().port, port_num); - EXPECT_EQ(record.address_v6().value().address, v6_address); + EXPECT_FALSE(record.address_v4()); + ASSERT_TRUE(record.address_v6()); + EXPECT_EQ(record.address_v6().port, kServicePort); + EXPECT_EQ(record.address_v6().address, IPAddress(kV6AddressHextets)); } TEST(DnsSdDnsDataTests, TestConvertDnsDataBadTxt) { @@ -183,19 +185,19 @@ TEST(DnsSdDnsDataTests, TestConvertDnsDataBadTxt) { // ApplyDataRecordChange tests. TEST(DnsSdDnsDataTests, TestApplyRecordChanges) { - MdnsRecord record = CreateFullyPopulatedRecord(port_num); - InstanceKey instance{instance_name, service_name, domain_name}; + MdnsRecord record = CreateFullyPopulatedRecord(kServicePort); + InstanceKey instance{kInstanceName, kServiceName, kDomainName}; DnsDataTesting data(instance); EXPECT_TRUE( data.ApplyDataRecordChange(record, RecordChangedEvent::kCreated).ok()); ASSERT_TRUE(data.srv().has_value()); - EXPECT_EQ(data.srv().value().port(), port_num); + EXPECT_EQ(data.srv().value().port(), kServicePort); record = CreateFullyPopulatedRecord(234); EXPECT_FALSE( data.ApplyDataRecordChange(record, RecordChangedEvent::kCreated).ok()); ASSERT_TRUE(data.srv().has_value()); - EXPECT_EQ(data.srv().value().port(), port_num); + EXPECT_EQ(data.srv().value().port(), kServicePort); record = CreateFullyPopulatedRecord(345); EXPECT_TRUE( diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.cc index 791faaaf7eb..5995768b2a1 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.cc @@ -8,25 +8,31 @@ #include "absl/strings/str_split.h" #include "discovery/dnssd/impl/conversion_layer.h" #include "discovery/dnssd/impl/service_key.h" +#include "discovery/dnssd/public/dns_sd_instance_record.h" #include "discovery/mdns/mdns_records.h" #include "discovery/mdns/public/mdns_constants.h" namespace openscreen { namespace discovery { -InstanceKey::InstanceKey(const MdnsRecord& record) { - ErrorOr<InstanceKey> key = CreateFromRecord(record); - OSP_DCHECK(key.is_value()); - *this = std::move(key.value()); +InstanceKey::InstanceKey(const MdnsRecord& record) + : InstanceKey(GetDomainName(record)) {} + +InstanceKey::InstanceKey(const DomainName& domain) + : ServiceKey(domain), instance_id_(domain.labels()[0]) { + OSP_DCHECK(IsInstanceValid(instance_id_)); } +InstanceKey::InstanceKey(const DnsSdInstanceRecord& record) + : InstanceKey(record.instance_id(), + record.service_id(), + record.domain_id()) {} + InstanceKey::InstanceKey(absl::string_view instance, absl::string_view service, absl::string_view domain) - : instance_id_(instance), service_id_(service), domain_id_(domain) { + : ServiceKey(service, domain), instance_id_(instance) { OSP_DCHECK(IsInstanceValid(instance_id_)); - OSP_DCHECK(IsServiceValid(service_id_)); - OSP_DCHECK(IsDomainValid(domain_id_)); } InstanceKey::InstanceKey(const InstanceKey& other) = default; @@ -35,41 +41,5 @@ InstanceKey::InstanceKey(InstanceKey&& other) = default; InstanceKey& InstanceKey::operator=(const InstanceKey& rhs) = default; InstanceKey& InstanceKey::operator=(InstanceKey&& rhs) = default; -bool InstanceKey::IsInstanceOf(const ServiceKey& service_key) const { - return service_id_ == service_key.service_id() && - domain_id_ == service_key.domain_id(); -} - -ErrorOr<InstanceKey> InstanceKey::CreateFromRecord(const MdnsRecord& record) { - const DomainName& names = - IsPtrRecord(record) - ? absl::get<PtrRecordRdata>(record.rdata()).ptr_domain() - : record.name(); - - if (names.labels().size() < 4) { - return Error::Code::kParameterInvalid; - } - - auto it = names.labels().begin(); - std::string instance_id = *it++; - if (!IsInstanceValid(instance_id)) { - return Error::Code::kParameterInvalid; - } - - std::string service_name = *it++; - std::string protocol = *it++; - std::string service_id = service_name.append(".").append(protocol); - if (!IsServiceValid(service_id)) { - return Error::Code::kParameterInvalid; - } - - std::string domain_id = absl::StrJoin(it, names.labels().end(), "."); - if (!IsDomainValid(domain_id)) { - return Error::Code::kParameterInvalid; - } - - return InstanceKey(instance_id, service_id, domain_id); -} - } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.h b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.h index 36f4731285f..311b6f72e8a 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.h @@ -9,27 +9,32 @@ #include <utility> #include "absl/strings/string_view.h" -#include "platform/base/error.h" +#include "discovery/dnssd/impl/service_key.h" namespace openscreen { namespace discovery { +class DnsSdInstanceRecord; +class DomainName; class MdnsRecord; class ServiceKey; // This class is intended to be used as the key of a std::unordered_map or // std::map when referencing data related to a specific service instance. -class InstanceKey { +class InstanceKey : public ServiceKey { public: // NOTE: The record provided must have valid instance, service, and domain // labels. explicit InstanceKey(const MdnsRecord& record); + explicit InstanceKey(const DomainName& domain); + explicit InstanceKey(const DnsSdInstanceRecord& record); // NOTE: The provided parameters must be valid instance, service and domain // ids. InstanceKey(absl::string_view instance, absl::string_view service, absl::string_view domain); + InstanceKey(const InstanceKey& other); InstanceKey(InstanceKey&& other); @@ -37,43 +42,29 @@ class InstanceKey { InstanceKey& operator=(InstanceKey&& rhs); const std::string& instance_id() const { return instance_id_; } - const std::string& service_id() const { return service_id_; } - const std::string& domain_id() const { return domain_id_; } - - // Represents whether this InstanceKey is an instance of the service provided. - bool IsInstanceOf(const ServiceKey& service_key) const; private: - static ErrorOr<InstanceKey> CreateFromRecord(const MdnsRecord& record); - std::string instance_id_; - std::string service_id_; - std::string domain_id_; template <typename H> friend H AbslHashValue(H h, const InstanceKey& key); friend bool operator<(const InstanceKey& lhs, const InstanceKey& rhs); - - // Validation method which needs the same code as CreateFromRecord(). Use a - // friend declaration to avoid duplicating this code while still keeping the - // factory private. - friend bool HasValidDnsRecordAddress(const MdnsRecord& record); }; template <typename H> H AbslHashValue(H h, const InstanceKey& key) { - return H::combine(std::move(h), key.service_id_, key.domain_id_, + return H::combine(std::move(h), key.service_id(), key.domain_id(), key.instance_id_); } inline bool operator<(const InstanceKey& lhs, const InstanceKey& rhs) { - int comp = lhs.domain_id_.compare(rhs.domain_id_); + int comp = lhs.domain_id().compare(rhs.domain_id()); if (comp != 0) { return comp < 0; } - comp = lhs.service_id_.compare(rhs.service_id_); + comp = lhs.service_id().compare(rhs.service_id()); if (comp != 0) { return comp < 0; } diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key_unittest.cc index 25dd052bb4b..29e199817b1 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key_unittest.cc @@ -45,20 +45,20 @@ TEST(DnsSdInstanceKeyTest, TestInstanceKeyEquals) { TEST(DnsSdInstanceKeyTest, TestIsInstanceOf) { ServiceKey ptr("_service._udp", "domain"); InstanceKey svc("instance", "_service._udp", "domain"); - EXPECT_TRUE(svc.IsInstanceOf(ptr)); + EXPECT_EQ(svc, ptr); svc = InstanceKey("other id", "_service._udp", "domain"); - EXPECT_TRUE(svc.IsInstanceOf(ptr)); + EXPECT_EQ(svc, ptr); svc = InstanceKey("instance", "_service._udp", "domain2"); - EXPECT_FALSE(svc.IsInstanceOf(ptr)); + EXPECT_FALSE(svc == ptr); ptr = ServiceKey("_service._udp", "domain2"); - EXPECT_TRUE(svc.IsInstanceOf(ptr)); + EXPECT_EQ(svc, ptr); svc = InstanceKey("instance", "_service2._udp", "domain"); - EXPECT_FALSE(svc.IsInstanceOf(ptr)); + EXPECT_NE(svc, ptr); ptr = ServiceKey("_service2._udp", "domain"); - EXPECT_TRUE(svc.IsInstanceOf(ptr)); + EXPECT_EQ(svc, ptr); } TEST(DnsSdInstanceKeyTest, InstanceKeyInMap) { diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.cc index c50c31d25d0..e8bdb9ea949 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.cc @@ -4,35 +4,219 @@ #include "discovery/dnssd/impl/publisher_impl.h" +#include <map> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "discovery/common/reporting_client.h" +#include "discovery/dnssd/impl/conversion_layer.h" +#include "discovery/dnssd/impl/instance_key.h" +#include "discovery/mdns/public/mdns_constants.h" +#include "platform/api/task_runner.h" #include "platform/base/error.h" namespace openscreen { namespace discovery { +namespace { -PublisherImpl::PublisherImpl(MdnsService* publisher) - : mdns_publisher_(publisher) {} +DnsSdInstanceRecord UpdateDomain(const DomainName& domain, + const DnsSdInstanceRecord& record) { + InstanceKey key(domain); + const IPEndpoint& v4 = record.address_v4(); + const IPEndpoint& v6 = record.address_v6(); + if (v4 && v6) { + return DnsSdInstanceRecord(key.instance_id(), key.service_id(), + key.domain_id(), v4, v6, record.txt()); + } else { + const IPEndpoint& endpoint = v4 ? v4 : v6; + return DnsSdInstanceRecord(key.instance_id(), key.service_id(), + key.domain_id(), endpoint, record.txt()); + } +} + +template <typename T> +inline typename std::map<DnsSdInstanceRecord, T>::iterator FindKey( + std::map<DnsSdInstanceRecord, T>* records, + const InstanceKey& key) { + return std::find_if(records->begin(), records->end(), + [&key](const std::pair<DnsSdInstanceRecord, T>& pair) { + return key == InstanceKey(pair.first); + }); +} + +template <typename T> +int EraseRecordsWithServiceId(std::map<DnsSdInstanceRecord, T>* records, + const std::string& service_id) { + int removed_count = 0; + for (auto it = records->begin(); it != records->end();) { + if (it->first.service_id() == service_id) { + removed_count++; + it = records->erase(it); + } else { + it++; + } + } + + return removed_count; +} + +} // namespace + +PublisherImpl::PublisherImpl(MdnsService* publisher, + ReportingClient* reporting_client, + TaskRunner* task_runner) + : mdns_publisher_(publisher), + reporting_client_(reporting_client), + task_runner_(task_runner) { + OSP_DCHECK(mdns_publisher_); + OSP_DCHECK(reporting_client_); + OSP_DCHECK(task_runner_); +} PublisherImpl::~PublisherImpl() = default; -Error PublisherImpl::Register(const DnsSdInstanceRecord& record) { - if (std::find(published_records_.begin(), published_records_.end(), record) != - published_records_.end()) { +Error PublisherImpl::Register(const DnsSdInstanceRecord& record, + Client* client) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + OSP_DCHECK(client != nullptr); + + if (published_records_.find(record) != published_records_.end()) { return Error::Code::kItemAlreadyExists; + } else if (pending_records_.find(record) != pending_records_.end()) { + return Error::Code::kOperationInProgress; } - published_records_.push_back(record); - for (const auto& mdns_record : GetDnsRecords(record)) { - mdns_publisher_->RegisterRecord(mdns_record); + InstanceKey key(record); + IPEndpoint endpoint = + record.address_v4() ? record.address_v4() : record.address_v6(); + pending_records_.emplace(record, client); + + OSP_DVLOG << "Registering instance '" << record.instance_id() << "'"; + + return mdns_publisher_->StartProbe(this, GetDomainName(key), + endpoint.address); +} + +Error PublisherImpl::UpdateRegistration(const DnsSdInstanceRecord& record) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + // Check if the record is still pending publication. + auto it = FindKey(&pending_records_, InstanceKey(record)); + + OSP_DVLOG << "Updating instance '" << record.instance_id() << "'"; + + // If it is a pending record, update it. Else, try to update a published + // record. + if (it != pending_records_.end()) { + // The instance, service, and domain ids have not changed, so only the + // remaining data needs to change. The ongoing probe does not need to be + // modified. + Client* const client = it->second; + pending_records_.erase(it); + pending_records_.emplace(record, client); + return Error::None(); + } else { + return UpdatePublishedRegistration(record); } - return Error::None(); } -size_t PublisherImpl::DeregisterAll(absl::string_view service) { - size_t removed_count = 0; +Error PublisherImpl::UpdatePublishedRegistration( + const DnsSdInstanceRecord& record) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + auto published_record_it = FindKey(&published_records_, InstanceKey(record)); + + // Check preconditions called out in header. Specifically, the updated record + // must be making changes to an already published record. + if (published_record_it == published_records_.end() || + published_record_it->first == record) { + return Error::Code::kParameterInvalid; + } + + // Get all records which have changed. By design, there an only be one record + // of each DnsType, so use that here to simplify this step. + // First in each pair is the old records, second is the new record. + std::map<DnsType, + std::pair<absl::optional<MdnsRecord>, absl::optional<MdnsRecord>>> + changed_records; + const std::vector<MdnsRecord> old_records = + GetDnsRecords(published_record_it->second); + const DnsSdInstanceRecord updated_record = UpdateDomain( + GetDomainName(InstanceKey(published_record_it->second)), record); + const std::vector<MdnsRecord> new_records = GetDnsRecords(updated_record); + + // Populate the first part of each pair in |changed_records|. + for (size_t i = 0; i < old_records.size(); i++) { + const auto key = old_records[i].dns_type(); + OSP_DCHECK(changed_records.find(key) == changed_records.end()); + auto value = std::make_pair(std::move(old_records[i]), absl::nullopt); + changed_records.emplace(key, std::move(value)); + } + + // Populate the second part of each pair in |changed_records|. + for (size_t i = 0; i < new_records.size(); i++) { + const auto key = new_records[i].dns_type(); + auto find_it = changed_records.find(key); + if (find_it == changed_records.end()) { + std::pair<absl::optional<MdnsRecord>, absl::optional<MdnsRecord>> value( + absl::nullopt, std::move(new_records[i])); + changed_records.emplace(key, std::move(value)); + } else { + find_it->second.second = std::move(new_records[i]); + } + } + + // Apply changes called out in |changed_records|. + // TODO(crbug.com/openscreen/114): Trace each below call so multiple errors + // can be seen. + Error total_result = Error::None(); + for (const auto& pair : changed_records) { + OSP_DCHECK(pair.second.first != absl::nullopt || + pair.second.second != absl::nullopt); + if (pair.second.first == absl::nullopt) { + auto error = mdns_publisher_->RegisterRecord(pair.second.second.value()); + if (!error.ok()) { + total_result = error; + } + } else if (pair.second.second == absl::nullopt) { + auto error = mdns_publisher_->UnregisterRecord(pair.second.first.value()); + if (!error.ok()) { + total_result = error; + } + } else if (pair.second.first.value() != pair.second.second.value()) { + auto error = mdns_publisher_->UpdateRegisteredRecord( + pair.second.first.value(), pair.second.second.value()); + if (!error.ok()) { + total_result = error; + } + } + } + + // Replace the old records with the new ones. + published_records_.erase(published_record_it); + published_records_.emplace(record, std::move(updated_record)); + + return total_result; +} + +ErrorOr<int> PublisherImpl::DeregisterAll(const std::string& service) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + OSP_DVLOG << "Deregistering all instances"; + + int removed_count = 0; + + // TODO(crbug.com/openscreen/114): Trace each below call so multiple errors + // can be seen. + Error error = Error::None(); for (auto it = published_records_.begin(); it != published_records_.end();) { - if (it->service_id() == service) { - for (const auto& mdns_record : GetDnsRecords(*it)) { - mdns_publisher_->DeregisterRecord(mdns_record); + if (it->second.service_id() == service) { + for (const auto& mdns_record : GetDnsRecords(it->second)) { + auto publisher_error = mdns_publisher_->UnregisterRecord(mdns_record); + if (!publisher_error.ok()) { + error = publisher_error; + } } removed_count++; it = published_records_.erase(it); @@ -41,7 +225,54 @@ size_t PublisherImpl::DeregisterAll(absl::string_view service) { } } - return removed_count; + removed_count += EraseRecordsWithServiceId(&pending_records_, service); + + if (!error.ok()) { + return error; + } else { + return removed_count; + } +} + +void PublisherImpl::OnDomainFound(const DomainName& requested_name, + const DomainName& confirmed_name) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + OSP_DVLOG << "Domain successfully claimed: '" << confirmed_name.ToString() + << "' based on requested name: '" << requested_name.ToString() + << "'"; + + auto it = FindKey(&pending_records_, InstanceKey(requested_name)); + + if (it == pending_records_.end()) { + // This will be hit if the record was deregistered before the probe phase + // was completed. + return; + } + + DnsSdInstanceRecord requested_record = std::move(it->first); + DnsSdInstanceRecord publication = requested_record; + Client* const client = it->second; + pending_records_.erase(it); + + InstanceKey requested_key(requested_record); + + if (requested_name != confirmed_name) { + OSP_DCHECK(HasValidDnsRecordAddress(confirmed_name)); + publication = UpdateDomain(confirmed_name, requested_record); + } + + for (const auto& mdns_record : GetDnsRecords(publication)) { + Error result = mdns_publisher_->RegisterRecord(mdns_record); + if (!result.ok()) { + reporting_client_->OnRecoverableError( + Error(Error::Code::kRecordPublicationError, result.ToString())); + } + } + + auto pair = published_records_.emplace(std::move(requested_record), + std::move(publication)); + client->OnInstanceClaimed(pair.first->first, pair.first->second); } } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.h b/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.h index 65d3235db75..5383bdde434 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.h @@ -9,24 +9,46 @@ #include "discovery/dnssd/impl/conversion_layer.h" #include "discovery/dnssd/public/dns_sd_instance_record.h" #include "discovery/dnssd/public/dns_sd_publisher.h" +#include "discovery/mdns/mdns_domain_confirmed_provider.h" #include "discovery/mdns/public/mdns_service.h" namespace openscreen { namespace discovery { -class PublisherImpl : public DnsSdPublisher { +class ReportingClient; + +class PublisherImpl : public DnsSdPublisher, + public MdnsDomainConfirmedProvider { public: - PublisherImpl(MdnsService* publisher); + PublisherImpl(MdnsService* publisher, + ReportingClient* reporting_client, + TaskRunner* task_runner); ~PublisherImpl() override; // DnsSdPublisher overrides. - Error Register(const DnsSdInstanceRecord& record) override; - size_t DeregisterAll(absl::string_view service) override; + Error Register(const DnsSdInstanceRecord& record, Client* client) override; + Error UpdateRegistration(const DnsSdInstanceRecord& record) override; + ErrorOr<int> DeregisterAll(const std::string& service) override; private: - std::vector<DnsSdInstanceRecord> published_records_; + Error UpdatePublishedRegistration(const DnsSdInstanceRecord& record); + + // MdnsDomainConfirmedProvider overrides. + void OnDomainFound(const DomainName& requested_name, + const DomainName& confirmed_name) override; + + // The set of records which will be published once the mDNS Probe phase + // completes. + std::map<DnsSdInstanceRecord, Client* const> pending_records_; + + // Maps from the requested record to the record which was published after + // the mDNS Probe phase was completed. The only difference between these + // records should be the instance name. + std::map<DnsSdInstanceRecord, DnsSdInstanceRecord> published_records_; MdnsService* const mdns_publisher_; + ReportingClient* const reporting_client_; + TaskRunner* const task_runner_; friend class PublisherTesting; }; diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl_unittest.cc index fd30c567198..70c4d45425a 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl_unittest.cc @@ -4,16 +4,28 @@ #include "discovery/dnssd/impl/publisher_impl.h" +#include <utility> #include <vector> +#include "discovery/common/testing/mock_reporting_client.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "platform/test/fake_clock.h" +#include "platform/test/fake_task_runner.h" namespace openscreen { namespace discovery { +namespace { using testing::_; using testing::Return; +using testing::StrictMock; + +class MockClient : public DnsSdPublisher::Client { + public: + MOCK_METHOD2(OnInstanceClaimed, + void(const DnsSdInstanceRecord&, const DnsSdInstanceRecord&)); +}; class MockMdnsService : public MdnsService { public: @@ -21,40 +33,71 @@ class MockMdnsService : public MdnsService { DnsType dns_type, DnsClass dns_class, MdnsRecordChangedCallback* callback) override { - OSP_UNIMPLEMENTED(); + FAIL(); } void StopQuery(const DomainName& name, DnsType dns_type, DnsClass dns_class, MdnsRecordChangedCallback* callback) override { - OSP_UNIMPLEMENTED(); + FAIL(); } - MOCK_METHOD1(RegisterRecord, void(const MdnsRecord& record)); - MOCK_METHOD1(DeregisterRecord, void(const MdnsRecord& record)); + void ReinitializeQueries(const DomainName& name) override { FAIL(); } + + MOCK_METHOD3(StartProbe, + Error(MdnsDomainConfirmedProvider*, DomainName, IPAddress)); + MOCK_METHOD2(UpdateRegisteredRecord, + Error(const MdnsRecord&, const MdnsRecord&)); + MOCK_METHOD1(RegisterRecord, Error(const MdnsRecord& record)); + MOCK_METHOD1(UnregisterRecord, Error(const MdnsRecord& record)); }; -class PublisherTesting : public PublisherImpl { +class PublisherImplTest : public testing::Test { public: - PublisherTesting() : PublisherImpl(&mock_service_) {} + PublisherImplTest() + : clock_(Clock::now()), + task_runner_(&clock_), + publisher_(&mock_service_, &reporting_client_, &task_runner_) {} + + MockMdnsService* mdns_service() { return &mock_service_; } + TaskRunner* task_runner() { return &task_runner_; } + PublisherImpl* publisher() { return &publisher_; } - MockMdnsService& mdns_service() { return mock_service_; } + // Calls PublisherImpl::OnDomainFound() through the public interface it + // implements. + void CallOnDomainFound(const DomainName& domain, const DomainName& domain2) { + static_cast<MdnsDomainConfirmedProvider&>(publisher_) + .OnDomainFound(domain, domain2); + } private: - MockMdnsService mock_service_; + FakeClock clock_; + FakeTaskRunner task_runner_; + StrictMock<MockMdnsService> mock_service_; + StrictMock<MockReportingClient> reporting_client_; + PublisherImpl publisher_; }; -TEST(DnsSdPublisherImplTests, TestRegisterAndDeregister) { - PublisherTesting publisher; - IPAddress address = IPAddress(std::array<uint8_t, 4>{{192, 168, 0, 0}}); - DnsSdInstanceRecord record("instance", "_service._udp", "domain", - {address, 80}, {}); +TEST_F(PublisherImplTest, TestRegistrationAndDegrestration) { + IPAddress address = IPAddress(192, 168, 0, 0); + const DomainName domain{"instance", "_service", "_udp", "domain"}; + const DomainName domain2{"instance2", "_service", "_udp", "domain"}; + const DnsSdInstanceRecord record("instance", "_service._udp", "domain", + {address, 80}, {}); + const DnsSdInstanceRecord record2("instance2", "_service._udp", "domain", + {address, 80}, {}); + MockClient client; + + EXPECT_CALL(*mdns_service(), StartProbe(publisher(), domain, _)).Times(1); + publisher()->Register(record, &client); + testing::Mock::VerifyAndClearExpectations(mdns_service()); int seen = 0; - EXPECT_CALL(publisher.mdns_service(), RegisterRecord(_)) + EXPECT_CALL(*mdns_service(), RegisterRecord(_)) .Times(4) - .WillRepeatedly([&seen, &address](const MdnsRecord& record) mutable { + .WillRepeatedly([&seen, &address, + &domain2](const MdnsRecord& record) mutable -> Error { if (record.dns_type() == DnsType::kA) { const ARecordRdata& data = absl::get<ARecordRdata>(record.rdata()); if (data.ipv4_address() == address) { @@ -66,15 +109,24 @@ TEST(DnsSdPublisherImplTests, TestRegisterAndDeregister) { if (data.port() == 80) { seen++; } - }; + } + + if (record.dns_type() != DnsType::kPTR) { + EXPECT_EQ(record.name(), domain2); + } + return Error::None(); }); - publisher.Register(record); + EXPECT_CALL(client, OnInstanceClaimed(record, record2)); + CallOnDomainFound(domain, domain2); EXPECT_EQ(seen, 2); + testing::Mock::VerifyAndClearExpectations(mdns_service()); + testing::Mock::VerifyAndClearExpectations(&client); seen = 0; - EXPECT_CALL(publisher.mdns_service(), DeregisterRecord(_)) + EXPECT_CALL(*mdns_service(), UnregisterRecord(_)) .Times(4) - .WillRepeatedly([&seen, &address](const MdnsRecord& record) mutable { + .WillRepeatedly([&seen, + &address](const MdnsRecord& record) mutable -> Error { if (record.dns_type() == DnsType::kA) { const ARecordRdata& data = absl::get<ARecordRdata>(record.rdata()); if (data.ipv4_address() == address) { @@ -86,11 +138,75 @@ TEST(DnsSdPublisherImplTests, TestRegisterAndDeregister) { if (data.port() == 80) { seen++; } - }; + } + return Error::None(); }); - publisher.DeregisterAll("_service._udp"); + publisher()->DeregisterAll("_service._udp"); EXPECT_EQ(seen, 2); } +TEST_F(PublisherImplTest, TestUpdate) { + IPAddress address = IPAddress(192, 168, 0, 0); + DomainName domain{"instance", "_service", "_udp", "domain"}; + DnsSdTxtRecord txt; + txt.SetFlag("id", true); + DnsSdInstanceRecord record("instance", "_service._udp", "domain", + {address, 80}, std::move(txt)); + MockClient client; + + // Update a non-existent record + EXPECT_FALSE(publisher()->UpdateRegistration(record).ok()); + + // Update a record during the probing phase + EXPECT_CALL(*mdns_service(), StartProbe(publisher(), domain, _)).Times(1); + EXPECT_EQ(publisher()->Register(record, &client), Error::None()); + testing::Mock::VerifyAndClearExpectations(mdns_service()); + + IPAddress address2 = IPAddress(1, 2, 3, 4, 5, 6, 7, 8); + DnsSdTxtRecord txt2; + txt2.SetFlag("id2", true); + DnsSdInstanceRecord record2("instance", "_service._udp", "domain", + {address2, 80}, std::move(txt2)); + EXPECT_EQ(publisher()->UpdateRegistration(record2), Error::None()); + + bool seen_v6 = false; + EXPECT_CALL(*mdns_service(), RegisterRecord(_)) + .Times(4) + .WillRepeatedly([&seen_v6](const MdnsRecord& record) mutable -> Error { + EXPECT_NE(record.dns_type(), DnsType::kA); + if (record.dns_type() == DnsType::kAAAA) { + seen_v6 = true; + } + return Error::None(); + }); + EXPECT_CALL(client, OnInstanceClaimed(record2, record2)); + CallOnDomainFound(domain, domain); + EXPECT_TRUE(seen_v6); + testing::Mock::VerifyAndClearExpectations(mdns_service()); + testing::Mock::VerifyAndClearExpectations(&client); + + // Update a record once it has been published. + EXPECT_CALL(*mdns_service(), RegisterRecord(_)) + .WillOnce([](const MdnsRecord& record) -> Error { + EXPECT_EQ(record.dns_type(), DnsType::kA); + return Error::None(); + }); + EXPECT_CALL(*mdns_service(), UnregisterRecord(_)) + .WillOnce([](const MdnsRecord& record) -> Error { + EXPECT_EQ(record.dns_type(), DnsType::kAAAA); + return Error::None(); + }); + EXPECT_CALL(*mdns_service(), UpdateRegisteredRecord(_, _)) + .WillOnce( + [](const MdnsRecord& record, const MdnsRecord& record2) -> Error { + EXPECT_EQ(record.dns_type(), DnsType::kTXT); + EXPECT_EQ(record2.dns_type(), DnsType::kTXT); + return Error::None(); + }); + EXPECT_EQ(publisher()->UpdateRegistration(record), Error::None()); + testing::Mock::VerifyAndClearExpectations(mdns_service()); +} + +} // namespace } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.cc index 79f899ecfa4..37f4d59d0d7 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.cc @@ -7,6 +7,7 @@ #include <string> #include <vector> +#include "platform/api/task_runner.h" #include "util/logging.h" namespace openscreen { @@ -17,47 +18,50 @@ static constexpr char kLocalDomain[] = "local"; } // namespace -QuerierImpl::QuerierImpl(MdnsService* mdns_querier) - : mdns_querier_(mdns_querier) { +QuerierImpl::QuerierImpl(MdnsService* mdns_querier, TaskRunner* task_runner) + : mdns_querier_(mdns_querier), task_runner_(task_runner) { OSP_DCHECK(mdns_querier_); + OSP_DCHECK(task_runner_); } QuerierImpl::~QuerierImpl() = default; -void QuerierImpl::StartQuery(absl::string_view service, Callback* callback) { +void QuerierImpl::StartQuery(const std::string& service, Callback* callback) { OSP_DCHECK(callback); + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + OSP_DVLOG << "Starting query for service '" << service << "'"; ServiceKey key(service, kLocalDomain); - DnsQueryInfo query = GetPtrQueryInfo(key); if (!IsQueryRunning(key)) { - callback_map_[key] = {}; - StartDnsQuery(query); + callback_map_[key] = {callback}; + StartDnsQuery(std::move(key)); } else { - const std::vector<InstanceKey> keys = GetMatchingInstances(key); - for (const auto& key : keys) { - auto it = received_records_.find(key); - if (it == received_records_.end()) { - continue; - } - - ErrorOr<DnsSdInstanceRecord> record = it->second.CreateRecord(); - if (record.is_value()) { - callback->OnInstanceCreated(record.value()); + callback_map_[key].push_back(callback); + + for (auto& kvp : received_records_) { + if (kvp.first == key) { + ErrorOr<DnsSdInstanceRecord> record = kvp.second.CreateRecord(); + if (record.is_value()) { + callback->OnInstanceCreated(record.value()); + } } } } - callback_map_[key].push_back(callback); } -bool QuerierImpl::IsQueryRunning(absl::string_view service) const { +bool QuerierImpl::IsQueryRunning(const std::string& service) const { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); return IsQueryRunning(ServiceKey(service, kLocalDomain)); } -void QuerierImpl::StopQuery(absl::string_view service, Callback* callback) { +void QuerierImpl::StopQuery(const std::string& service, Callback* callback) { OSP_DCHECK(callback); + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + OSP_DVLOG << "Stopping query for service '" << service << "'"; ServiceKey key(service, kLocalDomain); - DnsQueryInfo query = GetPtrQueryInfo(key); auto callback_it = callback_map_.find(key); if (callback_it == callback_map_.end()) { return; @@ -68,15 +72,42 @@ void QuerierImpl::StopQuery(absl::string_view service, Callback* callback) { if (it != callbacks->end()) { callbacks->erase(it); if (callbacks->empty()) { - EraseInstancesOf(key); callback_map_.erase(callback_it); - StopDnsQuery(query); + StopDnsQuery(std::move(key)); } } } +void QuerierImpl::ReinitializeQueries(const std::string& service) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + OSP_DVLOG << "Re-initializing query for service '" << service << "'"; + + const ServiceKey key(service, kLocalDomain); + + // Stop instance-specific queries and erase all instance data received so far. + std::vector<InstanceKey> keys_to_remove; + for (const auto& pair : received_records_) { + if (key == pair.first) { + keys_to_remove.push_back(pair.first); + } + } + for (InstanceKey& ik : keys_to_remove) { + StopDnsQuery(std::move(ik), false); + } + + // Restart top-level queries. + mdns_querier_->ReinitializeQueries(GetPtrQueryInfo(key).name); +} + void QuerierImpl::OnRecordChanged(const MdnsRecord& record, RecordChangedEvent event) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + OSP_DVLOG << "Record with name '" << record.name().ToString() + << "' and type '" << record.dns_type() + << "' has received change of type '" << event << "'"; + IsPtrRecord(record) ? HandlePtrRecordChange(record, event) : HandleNonPtrRecordChange(record, event); } @@ -88,15 +119,12 @@ Error QuerierImpl::HandlePtrRecordChange(const MdnsRecord& record, return Error::Code::kParameterInvalid; } - InstanceKey key(record); - DnsQueryInfo query = GetInstanceQueryInfo(key); switch (event) { case RecordChangedEvent::kCreated: - StartDnsQuery(query); + StartDnsQuery(InstanceKey(record)); return Error::None(); case RecordChangedEvent::kExpired: - StopDnsQuery(query); - EraseInstancesOf(ServiceKey(key)); + StopDnsQuery(InstanceKey(record)); return Error::None(); case RecordChangedEvent::kUpdated: return Error::Code::kOperationInvalid; @@ -165,54 +193,65 @@ void QuerierImpl::NotifyCallbacks( } } -void QuerierImpl::EraseInstancesOf(const ServiceKey& key) { - std::vector<InstanceKey> keys = GetMatchingInstances(key); - const auto it = callback_map_.find(key); - std::vector<Callback*> callbacks; - if (it != callback_map_.end()) { - callbacks = it->second; +void QuerierImpl::StartDnsQuery(InstanceKey key) { + auto pair = received_records_.emplace(key, DnsData(key)); + if (!pair.second) { + // This means that a query is already ongoing. + return; } - for (const auto& key : keys) { - auto recieved_record = received_records_.find(key); - if (recieved_record == received_records_.end()) { - continue; - } + DnsQueryInfo query = GetInstanceQueryInfo(key); + mdns_querier_->StartQuery(query.name, query.dns_type, query.dns_class, this); +} + +void QuerierImpl::StopDnsQuery(InstanceKey key, bool should_inform_callbacks) { + // If the instance is not being queried for, return. + auto record_it = received_records_.find(key); + if (record_it == received_records_.end()) { + return; + } - ErrorOr<DnsSdInstanceRecord> instance_record = - recieved_record->second.CreateRecord(); - if (instance_record.is_value()) { - for (Callback* callback : callbacks) { + // If the instance has enough associated data that an instance was provided to + // the higher layer, call the deleted callback for all associated callbacks. + ErrorOr<DnsSdInstanceRecord> instance_record = + record_it->second.CreateRecord(); + if (should_inform_callbacks && instance_record.is_value()) { + const auto it = callback_map_.find(key); + if (it != callback_map_.end()) { + for (Callback* callback : it->second) { callback->OnInstanceDeleted(instance_record.value()); } } - - received_records_.erase(recieved_record); } -} -std::vector<InstanceKey> QuerierImpl::GetMatchingInstances( - const ServiceKey& key) { - // Because only one or two PTR queries are expected at a time, expect >=1/2 of - // the records to be associated with a given PTR. They can't be removed in - // less than O(n) time, so just iterate across them all. - std::vector<InstanceKey> keys; - for (auto it = received_records_.begin(); it != received_records_.end(); - it++) { - if (it->first.IsInstanceOf(key)) { - keys.push_back(it->first); - } - } + // Erase the key to mark the instance as no longer being queried for. + received_records_.erase(record_it); - return keys; + // Call to the mDNS layer to stop the query. + DnsQueryInfo query = GetInstanceQueryInfo(key); + mdns_querier_->StopQuery(query.name, query.dns_type, query.dns_class, this); } -void QuerierImpl::StartDnsQuery(const DnsQueryInfo& query) { +void QuerierImpl::StartDnsQuery(ServiceKey key) { + DnsQueryInfo query = GetPtrQueryInfo(key); mdns_querier_->StartQuery(query.name, query.dns_type, query.dns_class, this); } -void QuerierImpl::StopDnsQuery(const DnsQueryInfo& query) { +void QuerierImpl::StopDnsQuery(ServiceKey key) { + DnsQueryInfo query = GetPtrQueryInfo(key); mdns_querier_->StopQuery(query.name, query.dns_type, query.dns_class, this); + + // Stop any ongoing instance-specific queries. + std::vector<InstanceKey> keys_to_remove; + for (const auto& pair : received_records_) { + const bool key_is_service_from_query = (key == pair.first); + if (key_is_service_from_query) { + keys_to_remove.push_back(pair.first); + } + } + for (auto it = keys_to_remove.begin(); it != keys_to_remove.end(); it++) { + StopDnsQuery(std::move(*it)); + } } } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.h b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.h index a39c91b58e9..f64fc1ada58 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.h @@ -27,19 +27,19 @@ namespace discovery { class QuerierImpl : public DnsSdQuerier, public MdnsRecordChangedCallback { public: - // |querier| must outlive the QuerierImpl instance constructed. - explicit QuerierImpl(MdnsService* querier); + // |querier| and |task_runner| must outlive the QuerierImpl instance + // constructed. + QuerierImpl(MdnsService* querier, TaskRunner* task_runner); ~QuerierImpl() override; - bool IsQueryRunning(absl::string_view service) const; + bool IsQueryRunning(const std::string& service) const; // DnsSdQuerier overrides. - void StartQuery(absl::string_view service, Callback* callback) override; - void StopQuery(absl::string_view service, Callback* callback) override; + void StartQuery(const std::string& service, Callback* callback) override; + void StopQuery(const std::string& service, Callback* callback) override; + void ReinitializeQueries(const std::string& service) override; // MdnsRecordChangedCallback overrides. - // TODO(rwkeane): Ensure this is run on the TaskRunner thread once the - // underlying mDNS implementation can be overridden. void OnRecordChanged(const MdnsRecord& record, RecordChangedEvent event) override; @@ -58,12 +58,10 @@ class QuerierImpl : public DnsSdQuerier, public MdnsRecordChangedCallback { } // Initiates or terminates queries on the mdns_querier_ object. - void StartDnsQuery(const DnsQueryInfo& query); - void StopDnsQuery(const DnsQueryInfo& query); - - // Erases all instance records describing services matching the provided key - // and informs all callbacks associated with the given key of their deletion. - void EraseInstancesOf(const ServiceKey& service); + void StartDnsQuery(InstanceKey key); + void StartDnsQuery(ServiceKey key); + void StopDnsQuery(InstanceKey key, bool should_inform_callbacks = true); + void StopDnsQuery(ServiceKey key); // Calls the appropriate callback method based on the provided Instance Record // values. @@ -71,12 +69,12 @@ class QuerierImpl : public DnsSdQuerier, public MdnsRecordChangedCallback { const ErrorOr<DnsSdInstanceRecord>& old_record, const ErrorOr<DnsSdInstanceRecord>& new_record); - // Returns all InstanceKeys received so far which represent instances of - // the service described by the provided ServiceKey. - std::vector<InstanceKey> GetMatchingInstances(const ServiceKey& key); - // Map from a specific service instance to the data received so far about - // that instance. + // that instance. The keys in this map are the instances for which an + // associated PTR record has been received, and the values are the set of + // non-PTR records received which describe that service (if any). Note that, + // with this definition, it is possible for a InstanceKey to be mapped to an + // empty DnsData if the instance has no associated records yet. std::unordered_map<InstanceKey, DnsData, absl::Hash<InstanceKey>> received_records_; @@ -85,6 +83,7 @@ class QuerierImpl : public DnsSdQuerier, public MdnsRecordChangedCallback { std::map<ServiceKey, std::vector<Callback*>> callback_map_; MdnsService* const mdns_querier_; + TaskRunner* const task_runner_; friend class QuerierImplTesting; }; diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl_unittest.cc index 327eeccc8f3..9da8a4543dc 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl_unittest.cc @@ -5,6 +5,8 @@ #include "discovery/dnssd/impl/querier_impl.h" #include <memory> +#include <string> +#include <utility> #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" @@ -13,6 +15,8 @@ #include "discovery/mdns/testing/mdns_test_util.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "platform/test/fake_clock.h" +#include "platform/test/fake_task_runner.h" #include "util/logging.h" namespace openscreen { @@ -36,9 +40,15 @@ class MockMdnsService : public MdnsService { StopQuery, void(const DomainName&, DnsType, DnsClass, MdnsRecordChangedCallback*)); - void RegisterRecord(const MdnsRecord& record) override { FAIL(); } + MOCK_METHOD1(ReinitializeQueries, void(const DomainName& name)); - void DeregisterRecord(const MdnsRecord& record) override { FAIL(); } + // Unused. + MOCK_METHOD3(StartProbe, + Error(MdnsDomainConfirmedProvider*, DomainName, IPAddress)); + MOCK_METHOD1(RegisterRecord, Error(const MdnsRecord&)); + MOCK_METHOD1(UnregisterRecord, Error(const MdnsRecord&)); + MOCK_METHOD2(UpdateRegisteredRecord, + Error(const MdnsRecord&, const MdnsRecord&)); }; SrvRecordRdata CreateSrvRecord() { @@ -52,8 +62,8 @@ ARecordRdata CreateARecord() { } AAAARecordRdata CreateAAAARecord() { - return AAAARecordRdata( - IPAddress{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + return AAAARecordRdata(IPAddress(0x0102, 0x0304, 0x0506, 0x0708, 0x090a, + 0x0b0c, 0x0d0e, 0x0f10)); } MdnsRecord CreatePtrRecord(const std::string& instance, @@ -94,7 +104,7 @@ using testing::StrictMock; class DnsDataAccessor { public: - DnsDataAccessor(DnsData* data) : data_(data) {} + explicit DnsDataAccessor(DnsData* data) : data_(data) {} void set_srv(absl::optional<SrvRecordRdata> record) { data_->srv_ = record; } void set_txt(absl::optional<TxtRecordRdata> record) { data_->txt_ = record; } @@ -116,7 +126,10 @@ class DnsDataAccessor { class QuerierImplTesting : public QuerierImpl { public: - QuerierImplTesting() : QuerierImpl(&mock_service_) {} + QuerierImplTesting() + : QuerierImpl(&mock_service_, &task_runner_), + clock_(Clock::now()), + task_runner_(&clock_) {} MockMdnsService* service() { return &mock_service_; } @@ -140,6 +153,8 @@ class QuerierImplTesting : public QuerierImpl { } private: + FakeClock clock_; + FakeTaskRunner task_runner_; StrictMock<MockMdnsService> mock_service_; }; @@ -153,6 +168,7 @@ class DnsSdQuerierImplTest : public testing::Test { .Times(1); querier.StartQuery(service, &callback); EXPECT_TRUE(querier.IsQueryRunning(service)); + testing::Mock::VerifyAndClearExpectations(querier.service()); EXPECT_CALL(*querier.service(), StartQuery(_, DnsType::kPTR, DnsClass::kANY, _)) @@ -201,6 +217,9 @@ TEST_F(DnsSdQuerierImplTest, TestStopQueryClearsRecords) { EXPECT_CALL(*querier.service(), StopQuery(_, DnsType::kPTR, DnsClass::kANY, _)) .Times(1); + EXPECT_CALL(*querier.service(), + StopQuery(_, DnsType::kANY, DnsClass::kANY, _)) + .Times(1); querier.StopQuery(service, &callback); EXPECT_FALSE(querier.GetDnsData(instance, service, domain).has_value()); } @@ -221,6 +240,7 @@ TEST_F(DnsSdQuerierImplTest, TestCreateDeletePtrRecord) { StartQuery(_, DnsType::kANY, DnsClass::kANY, _)) .Times(1); querier.OnRecordChanged(ptr, RecordChangedEvent::kCreated); + testing::Mock::VerifyAndClearExpectations(querier.service()); EXPECT_CALL(*querier.service(), StopQuery(_, DnsType::kANY, DnsClass::kANY, _)) @@ -230,6 +250,12 @@ TEST_F(DnsSdQuerierImplTest, TestCreateDeletePtrRecord) { TEST_F(DnsSdQuerierImplTest, CallbackCalledWhenPtrDeleted) { auto ptr = CreatePtrRecord(instance, service, domain); + EXPECT_CALL(*querier.service(), + StartQuery(_, DnsType::kANY, DnsClass::kANY, _)) + .Times(1); + querier.OnRecordChanged(ptr, RecordChangedEvent::kCreated); + testing::Mock::VerifyAndClearExpectations(querier.service()); + DnsDataAccessor dns_data = querier.CreateDnsData(instance, service, domain); dns_data.set_srv(CreateSrvRecord()); dns_data.set_txt(MakeTxtRecord({})); @@ -237,11 +263,6 @@ TEST_F(DnsSdQuerierImplTest, CallbackCalledWhenPtrDeleted) { dns_data.set_aaaa(CreateAAAARecord()); ASSERT_TRUE(dns_data.CanCreateInstance()); - EXPECT_CALL(*querier.service(), - StartQuery(_, DnsType::kANY, DnsClass::kANY, _)) - .Times(1); - querier.OnRecordChanged(ptr, RecordChangedEvent::kCreated); - EXPECT_CALL(callback, OnInstanceDeleted(_)).Times(1); EXPECT_CALL(*querier.service(), StopQuery(_, DnsType::kANY, DnsClass::kANY, _)) @@ -276,9 +297,11 @@ TEST_F(DnsSdQuerierImplTest, BothNewAndOldValidRecords) { EXPECT_CALL(callback, OnInstanceUpdated(_)).Times(1); querier.OnRecordChanged(a_record, RecordChangedEvent::kCreated); + testing::Mock::VerifyAndClearExpectations(&callback); EXPECT_CALL(callback, OnInstanceUpdated(_)).Times(1); querier.OnRecordChanged(a_record, RecordChangedEvent::kUpdated); + testing::Mock::VerifyAndClearExpectations(&callback); auto aaaa_rdata = CreateAAAARecord(); MdnsRecord aaaa_record(kDomainName, DnsType::kAAAA, DnsClass::kIN, @@ -287,9 +310,11 @@ TEST_F(DnsSdQuerierImplTest, BothNewAndOldValidRecords) { EXPECT_CALL(callback, OnInstanceUpdated(_)).Times(1); querier.OnRecordChanged(aaaa_record, RecordChangedEvent::kUpdated); + testing::Mock::VerifyAndClearExpectations(&callback); EXPECT_CALL(callback, OnInstanceUpdated(_)).Times(1); querier.OnRecordChanged(a_record, RecordChangedEvent::kExpired); + testing::Mock::VerifyAndClearExpectations(&callback); } TEST_F(DnsSdQuerierImplTest, OnlyNewRecordValid) { @@ -321,5 +346,35 @@ TEST_F(DnsSdQuerierImplTest, OnlyOldRecordValid) { querier.OnRecordChanged(a_record, RecordChangedEvent::kExpired); } +TEST_F(DnsSdQuerierImplTest, HardRefresh) { + const std::string service2 = "_service2._udp"; + + DnsDataAccessor dns_data = querier.CreateDnsData(instance, service, domain); + dns_data.set_srv(CreateSrvRecord()); + dns_data.set_txt(MakeTxtRecord({})); + dns_data.set_a(CreateARecord()); + dns_data.set_aaaa(CreateAAAARecord()); + DnsDataAccessor dns_data2 = querier.CreateDnsData(instance, service2, domain); + dns_data2.set_srv(CreateSrvRecord()); + + EXPECT_CALL(callback, OnInstanceCreated(_)).Times(1); + querier.StartQuery(service, &callback); + EXPECT_TRUE(querier.IsQueryRunning(service)); + + const DomainName ptr_domain{"_service", "_udp", "local"}; + const DomainName instance_domain{"instance", "_service", "_udp", "local"}; + EXPECT_CALL(*querier.service(), ReinitializeQueries(ptr_domain)); + EXPECT_CALL(*querier.service(), StopQuery(instance_domain, _, _, _)); + querier.ReinitializeQueries(service); + testing::Mock::VerifyAndClearExpectations(querier.service()); + + absl::optional<DnsDataAccessor> data = + querier.GetDnsData(instance, service, domain); + EXPECT_EQ(data, absl::nullopt); + data = querier.GetDnsData(instance, service2, domain); + EXPECT_NE(data, absl::nullopt); + EXPECT_TRUE(querier.IsQueryRunning(service)); +} + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.cc index e1caaabb049..ba8ce1518db 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.cc @@ -6,22 +6,60 @@ #include <utility> +#include "discovery/common/config.h" #include "discovery/mdns/public/mdns_service.h" +#include "platform/api/task_runner.h" namespace openscreen { namespace discovery { +namespace { + +MdnsService::SupportedNetworkAddressFamily GetSupportedEndpointTypes( + const InterfaceInfo& interface) { + MdnsService::SupportedNetworkAddressFamily supported_types = + MdnsService::kNoAddressFamily; + if (interface.GetIpAddressV4()) { + supported_types = supported_types | MdnsService::kUseIpV4Multicast; + } + if (interface.GetIpAddressV6()) { + supported_types = supported_types | MdnsService::kUseIpV6Multicast; + } + return supported_types; +} + +} // namespace // static -std::unique_ptr<DnsSdService> DnsSdService::Create(TaskRunner* task_runner) { - return std::make_unique<ServiceImpl>(task_runner); +SerialDeletePtr<DnsSdService> CreateDnsSdService( + TaskRunner* task_runner, + ReportingClient* reporting_client, + const Config& config) { + return SerialDeletePtr<DnsSdService>( + task_runner, new ServiceImpl(task_runner, reporting_client, config)); } -ServiceImpl::ServiceImpl(TaskRunner* task_runner) - : mdns_service_(MdnsService::Create(task_runner)), - querier_(mdns_service_.get()), - publisher_(mdns_service_.get()) {} +ServiceImpl::ServiceImpl(TaskRunner* task_runner, + ReportingClient* reporting_client, + const Config& config) + : task_runner_(task_runner), + mdns_service_( + MdnsService::Create(task_runner, + reporting_client, + config, + config.interface.index, + GetSupportedEndpointTypes(config.interface))) { + if (config.enable_querying) { + querier_ = std::make_unique<QuerierImpl>(mdns_service_.get(), task_runner_); + } + if (config.enable_publication) { + publisher_ = std::make_unique<PublisherImpl>( + mdns_service_.get(), reporting_client, task_runner_); + } +} -ServiceImpl::~ServiceImpl() = default; +ServiceImpl::~ServiceImpl() { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); +} } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.h b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.h index 829cceb9284..3640407e46d 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.h @@ -10,32 +10,34 @@ #include "discovery/dnssd/impl/publisher_impl.h" #include "discovery/dnssd/impl/querier_impl.h" #include "discovery/dnssd/public/dns_sd_service.h" +#include "platform/base/interface_info.h" namespace openscreen { -namespace platform { class TaskRunner; -} - namespace discovery { class MdnsService; class ServiceImpl final : public DnsSdService { public: - explicit ServiceImpl(TaskRunner* task_runner); + ServiceImpl(TaskRunner* task_runner, + ReportingClient* reporting_client, + const Config& config); ~ServiceImpl() override; // DnsSdService overrides. - DnsSdQuerier* Querier() override { return &querier_; } - DnsSdPublisher* Publisher() override { return &publisher_; } + DnsSdQuerier* GetQuerier() override { return querier_.get(); } + DnsSdPublisher* GetPublisher() override { return publisher_.get(); } private: + TaskRunner* const task_runner_; + std::unique_ptr<MdnsService> mdns_service_; - QuerierImpl querier_; - PublisherImpl publisher_; + std::unique_ptr<QuerierImpl> querier_; + std::unique_ptr<PublisherImpl> publisher_; }; } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.cc index 5172ec4a9f9..e5af9d3ba89 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.cc @@ -7,7 +7,6 @@ #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "discovery/dnssd/impl/conversion_layer.h" -#include "discovery/dnssd/impl/instance_key.h" #include "discovery/mdns/mdns_records.h" #include "discovery/mdns/public/mdns_constants.h" @@ -17,15 +16,18 @@ namespace discovery { // The InstanceKey ctor used below cares about the Instance ID of the // MdnsRecord, while this class doesn't, so it's possible that the InstanceKey // ctor will fail for this reason. That's not a problem though - A failure when -// creating an InstanceKey would mean that the record we recieved was invalid, +// creating an InstanceKey would mean that the record we received was invalid, // so there is no reason to continue processing. -ServiceKey::ServiceKey(const MdnsRecord& record) - : ServiceKey(InstanceKey(record)) {} +ServiceKey::ServiceKey(const MdnsRecord& record) { + ErrorOr<ServiceKey> key = TryCreate(record); + OSP_DCHECK(key.is_value()); + *this = std::move(key.value()); +} -ServiceKey::ServiceKey(const InstanceKey& key) - : ServiceKey(key.service_id(), key.domain_id()) { - OSP_DCHECK(IsServiceValid(service_id_)); - OSP_DCHECK(IsDomainValid(domain_id_)); +ServiceKey::ServiceKey(const DomainName& domain) { + ErrorOr<ServiceKey> key = TryCreate(domain); + OSP_DCHECK(key.is_value()); + *this = std::move(key.value()); } ServiceKey::ServiceKey(absl::string_view service, absl::string_view domain) @@ -41,5 +43,36 @@ ServiceKey::ServiceKey(ServiceKey&& other) = default; ServiceKey& ServiceKey::operator=(const ServiceKey& rhs) = default; ServiceKey& ServiceKey::operator=(ServiceKey&& rhs) = default; +// static +ErrorOr<ServiceKey> ServiceKey::TryCreate(const MdnsRecord& record) { + return TryCreate(GetDomainName(record)); +} + +// static +ErrorOr<ServiceKey> ServiceKey::TryCreate(const DomainName& names) { + // Size must be at least 4, because the minimum valid label is of the form + // <instance>.<service type>.<protocol>.<domain> + if (names.labels().size() < 4) { + return Error::Code::kParameterInvalid; + } + + // Skip the InstanceId. + auto it = ++names.labels().begin(); + + std::string service_name = *it++; + const std::string protocol = *it++; + const std::string service_id = service_name.append(".").append(protocol); + if (!IsServiceValid(service_id)) { + return Error::Code::kParameterInvalid; + } + + const std::string domain_id = absl::StrJoin(it, names.labels().end(), "."); + if (!IsDomainValid(domain_id)) { + return Error::Code::kParameterInvalid; + } + + return ServiceKey(service_id, domain_id); +} + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.h b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.h index bdb5e0ee003..4e6ff7fae8d 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.h @@ -9,24 +9,24 @@ #include <utility> #include "absl/strings/string_view.h" +#include "platform/base/error.h" namespace openscreen { namespace discovery { -class InstanceKey; +class DomainName; class MdnsRecord; // This class is intended to be used as the key of a std::unordered_map or a // std::map when referencing data related to a service type class ServiceKey { public: - // NOTE: The record provided must have valid service, domain, and instance - // labels. + // NOTE: The record provided must have valid service domain labels. explicit ServiceKey(const MdnsRecord& record); + explicit ServiceKey(const DomainName& domain); // NOTE: The provided service and domain labels must be valid. ServiceKey(absl::string_view service, absl::string_view domain); - explicit ServiceKey(const InstanceKey& key); ServiceKey(const ServiceKey& other); ServiceKey(ServiceKey&& other); @@ -37,6 +37,9 @@ class ServiceKey { const std::string& domain_id() const { return domain_id_; } private: + static ErrorOr<ServiceKey> TryCreate(const MdnsRecord& record); + static ErrorOr<ServiceKey> TryCreate(const DomainName& domain); + std::string service_id_; std::string domain_id_; @@ -44,6 +47,12 @@ class ServiceKey { friend H AbslHashValue(H h, const ServiceKey& key); friend bool operator<(const ServiceKey& lhs, const ServiceKey& rhs); + + // Validation method which needs the same code as CreateFromRecord(). Use a + // friend declaration to avoid duplicating this code while still keeping the + // factory private. + friend bool HasValidDnsRecordAddress(const MdnsRecord& record); + friend bool HasValidDnsRecordAddress(const DomainName& domain); }; // Hashing functions to allow for using with absl::Hash<...>. diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.cc b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.cc index e0aaea9b2d4..90506a6bf2d 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.cc @@ -4,14 +4,15 @@ #include "discovery/dnssd/public/dns_sd_instance_record.h" -#include "absl/strings/ascii.h" +#include <cctype> + #include "util/logging.h" namespace openscreen { namespace discovery { namespace { -bool IsValidUtf8(absl::string_view string) { +bool IsValidUtf8(const std::string& string) { for (size_t i = 0; i < string.size(); i++) { if (string[i] >> 5 == 0x06) { // 110xxxxx 10xxxxxx if (i + 1 >= string.size() || string[++i] >> 6 != 0x02) { @@ -34,7 +35,7 @@ bool IsValidUtf8(absl::string_view string) { return true; } -bool HasControlCharacters(absl::string_view string) { +bool HasControlCharacters(const std::string& string) { for (auto ch : string) { if ((ch >= 0x0 && ch <= 0x1F /* Ascii control characters */) || ch == 0x7F /* DEL character */) { @@ -55,6 +56,7 @@ DnsSdInstanceRecord::DnsSdInstanceRecord(std::string instance_id, std::move(service_id), std::move(domain_id), std::move(txt)) { + OSP_DCHECK(endpoint); if (endpoint.address.IsV4()) { address_v4_ = std::move(endpoint); } else if (endpoint.address.IsV6()) { @@ -74,6 +76,8 @@ DnsSdInstanceRecord::DnsSdInstanceRecord(std::string instance_id, std::move(service_id), std::move(domain_id), std::move(txt)) { + OSP_CHECK(ipv4_endpoint); + OSP_CHECK(ipv6_endpoint); OSP_CHECK(ipv4_endpoint.address.IsV4()); OSP_CHECK(ipv6_endpoint.address.IsV6()); @@ -94,18 +98,11 @@ DnsSdInstanceRecord::DnsSdInstanceRecord(std::string instance_id, OSP_DCHECK(IsDomainValid(domain_id_)); } -bool DnsSdInstanceRecord::operator==(const DnsSdInstanceRecord& other) const { - return instance_id_ == other.instance_id_ && - service_id_ == other.service_id_ && domain_id_ == other.domain_id_ && - address_v4_ == other.address_v4_ && address_v6_ == other.address_v6_ && - txt_ == other.txt_; -} - uint16_t DnsSdInstanceRecord::port() const { - if (address_v4_.has_value()) { - return address_v4_.value().port; - } else if (address_v6_.has_value()) { - return address_v6_.value().port; + if (address_v4_) { + return address_v4_.port; + } else if (address_v6_) { + return address_v6_.port; } else { OSP_NOTREACHED(); return 0; @@ -113,7 +110,7 @@ uint16_t DnsSdInstanceRecord::port() const { } // static -bool IsInstanceValid(absl::string_view instance) { +bool IsInstanceValid(const std::string& instance) { // According to RFC6763, Instance names must: // - Be encoded in Net-Unicode (which required UTF-8 formatting). // - NOT contain ASCII control characters @@ -124,7 +121,7 @@ bool IsInstanceValid(absl::string_view instance) { } // static -bool IsServiceValid(absl::string_view service) { +bool IsServiceValid(const std::string& service) { // According to RFC6763, the service name "consists of a pair of DNS labels". // "The first label of the pair is an underscore character followed by the // Service Name" and "The second label is either '_tcp' [...] or '_udp'". @@ -138,7 +135,7 @@ bool IsServiceValid(absl::string_view service) { return false; } - const absl::string_view protocol = service.substr(service.size() - 5); + const std::string protocol = service.substr(service.size() - 5); if (protocol != "._udp" && protocol != "._tcp") { return false; } @@ -156,9 +153,9 @@ bool IsServiceValid(absl::string_view service) { return false; } last_char_hyphen = true; - } else if (absl::ascii_isalpha(service[i])) { + } else if (std::isalpha(service[i])) { seen_letter = true; - } else if (!absl::ascii_isdigit(service[i])) { + } else if (!std::isdigit(service[i])) { return false; } } @@ -167,7 +164,7 @@ bool IsServiceValid(absl::string_view service) { } // static -bool IsDomainValid(absl::string_view domain) { +bool IsDomainValid(const std::string& domain) { // As RFC6763 Section 4.1.3 provides no validation requirements for the domain // section, the following validations are used: // - All labels must be no longer than 63 characters @@ -191,5 +188,32 @@ bool IsDomainValid(absl::string_view domain) { return !HasControlCharacters(domain) && IsValidUtf8(domain); } +bool operator<(const DnsSdInstanceRecord& lhs, const DnsSdInstanceRecord& rhs) { + int comp = lhs.instance_id_.compare(rhs.instance_id_); + if (comp != 0) { + return comp < 0; + } + + comp = lhs.service_id_.compare(rhs.service_id_); + if (comp != 0) { + return comp < 0; + } + + comp = lhs.domain_id_.compare(rhs.domain_id_); + if (comp != 0) { + return comp < 0; + } + + if (lhs.address_v4_ != rhs.address_v4_) { + return lhs.address_v4_ < rhs.address_v4_; + } + + if (lhs.address_v6_ != rhs.address_v6_) { + return lhs.address_v6_ < rhs.address_v6_; + } + + return lhs.txt_ < rhs.txt_; +} + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.h b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.h index 5461165fdde..9dbb8ffa6ac 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.h @@ -5,17 +5,15 @@ #ifndef DISCOVERY_DNSSD_PUBLIC_DNS_SD_INSTANCE_RECORD_H_ #define DISCOVERY_DNSSD_PUBLIC_DNS_SD_INSTANCE_RECORD_H_ -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" #include "discovery/dnssd/public/dns_sd_txt_record.h" #include "platform/base/ip_address.h" namespace openscreen { namespace discovery { -bool IsInstanceValid(absl::string_view instance); -bool IsServiceValid(absl::string_view service); -bool IsDomainValid(absl::string_view domain); +bool IsInstanceValid(const std::string& instance); +bool IsServiceValid(const std::string& service); +bool IsDomainValid(const std::string& domain); // Represents the data stored in DNS records of types SRV, TXT, A, and AAAA class DnsSdInstanceRecord { @@ -45,10 +43,10 @@ class DnsSdInstanceRecord { // Returns the domain id for this DNS-SD record. const std::string& domain_id() const { return domain_id_; } - // Returns the addess associated with this DNS-SD record. In any valid record, - // at least one will be set. - const absl::optional<IPEndpoint>& address_v4() const { return address_v4_; } - const absl::optional<IPEndpoint>& address_v6() const { return address_v6_; } + // Returns the address associated with this DNS-SD record. In any valid + // record, at least one will be set. + const IPEndpoint& address_v4() const { return address_v4_; } + const IPEndpoint& address_v6() const { return address_v6_; } // Returns the TXT record associated with this DNS-SD record const DnsSdTxtRecord& txt() const { return txt_; } @@ -56,12 +54,6 @@ class DnsSdInstanceRecord { // Returns the port associated with this instance record. uint16_t port() const; - bool operator==(const DnsSdInstanceRecord& other) const; - - inline bool operator!=(const DnsSdInstanceRecord& other) const { - return !(*this == other); - } - private: DnsSdInstanceRecord(std::string instance_id, std::string service_id, @@ -71,11 +63,41 @@ class DnsSdInstanceRecord { std::string instance_id_; std::string service_id_; std::string domain_id_; - absl::optional<IPEndpoint> address_v4_; - absl::optional<IPEndpoint> address_v6_; + IPEndpoint address_v4_; + IPEndpoint address_v6_; DnsSdTxtRecord txt_; + + friend bool operator<(const DnsSdInstanceRecord& lhs, + const DnsSdInstanceRecord& rhs); }; +bool operator<(const DnsSdInstanceRecord& lhs, const DnsSdInstanceRecord& rhs); + +inline bool operator>(const DnsSdInstanceRecord& lhs, + const DnsSdInstanceRecord& rhs) { + return rhs < lhs; +} + +inline bool operator<=(const DnsSdInstanceRecord& lhs, + const DnsSdInstanceRecord& rhs) { + return !(rhs > lhs); +} + +inline bool operator>=(const DnsSdInstanceRecord& lhs, + const DnsSdInstanceRecord& rhs) { + return !(rhs < lhs); +} + +inline bool operator==(const DnsSdInstanceRecord& lhs, + const DnsSdInstanceRecord& rhs) { + return lhs <= rhs && lhs >= rhs; +} + +inline bool operator!=(const DnsSdInstanceRecord& lhs, + const DnsSdInstanceRecord& rhs) { + return !(lhs == rhs); +} + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_publisher.h b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_publisher.h index 638a30f782d..d271b684c7c 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_publisher.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_publisher.h @@ -5,7 +5,6 @@ #ifndef DISCOVERY_DNSSD_PUBLIC_DNS_SD_PUBLISHER_H_ #define DISCOVERY_DNSSD_PUBLIC_DNS_SD_PUBLISHER_H_ -#include "absl/strings/string_view.h" #include "discovery/dnssd/public/dns_sd_instance_record.h" #include "platform/base/error.h" @@ -14,6 +13,20 @@ namespace discovery { class DnsSdPublisher { public: + class Client { + public: + virtual ~Client() = default; + + // Callback called when a record is successfully claimed and published via + // the Register() method. These records are expected to only differ in + // the DnsSdInstanceRecord::instance_id() field, or to be equal. This + // callback is purely for informational purposes and the caller is not + // required to act on it. + virtual void OnInstanceClaimed( + const DnsSdInstanceRecord& requested_record, + const DnsSdInstanceRecord& claimed_record) = 0; + }; + virtual ~DnsSdPublisher() = default; // Publishes the PTR, SRV, TXT, A, and AAAA records provided in the @@ -22,12 +35,21 @@ class DnsSdPublisher { // NOTE: Some embedders may return errors on other conditions (for instance, // android will return an error if the resulting TXT record has values not // encodable with UTF8). - virtual Error Register(const DnsSdInstanceRecord& record) = 0; + virtual Error Register(const DnsSdInstanceRecord& record, Client* client) = 0; + + // Updates the TXT, A, and AAAA records associated with the provided record, + // if any changes have occurred. The instance and domain names must match + // those of a previously published record. If either this is not true, no + // changes have occurred, or additional embedder-specific requirements have + // been violated, an error is returned. Else, Error::None is returned. + virtual Error UpdateRegistration(const DnsSdInstanceRecord& record) = 0; // Unpublishes any PTR, SRV, TXT, A, and AAAA records associated with this - // service id. If no such records are published, this operation will be a - // no-op. Returns the number of records which were removed. - virtual size_t DeregisterAll(absl::string_view service) = 0; + // service id, where the service id is the second part of the + // <instance>.<service>.<domain> domain name as described in RFC 6763. If no + // such records are published, this operation will be a no-op. Returns the + // number of records which were removed, or an error code on error. + virtual ErrorOr<int> DeregisterAll(const std::string& service) = 0; }; } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_querier.h b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_querier.h index 5a35f62d2d4..083d17d53cd 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_querier.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_querier.h @@ -5,7 +5,6 @@ #ifndef DISCOVERY_DNSSD_PUBLIC_DNS_SD_QUERIER_H_ #define DISCOVERY_DNSSD_PUBLIC_DNS_SD_QUERIER_H_ -#include "absl/strings/string_view.h" #include "discovery/dnssd/public/dns_sd_instance_record.h" namespace openscreen { @@ -39,12 +38,17 @@ class DnsSdQuerier { // NOTE: The provided service value is expected to be valid, as defined by the // IsServiceValid() method. // NOTE: The callback must be called on the TaskRunner thread. - virtual void StartQuery(absl::string_view service, Callback* cb) = 0; + virtual void StartQuery(const std::string& service, Callback* cb) = 0; // Stops an already running query. // NOTE: The provided service value is expected to be valid, as defined by the // IsServiceValid() method. - virtual void StopQuery(absl::string_view service, Callback* cb) = 0; + virtual void StopQuery(const std::string& service, Callback* cb) = 0; + + // Re-initializes the process of service discovery for the provided service + // id. All ongoing queries for this domain are restarted and any previously + // received query results are discarded. + virtual void ReinitializeQueries(const std::string& service) = 0; }; } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_service.h b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_service.h index 65cc83bef18..f0fe130bf88 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_service.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_service.h @@ -5,8 +5,14 @@ #ifndef DISCOVERY_DNSSD_PUBLIC_DNS_SD_SERVICE_H_ #define DISCOVERY_DNSSD_PUBLIC_DNS_SD_SERVICE_H_ +#include <functional> #include <memory> +#include "platform/base/error.h" +#include "platform/base/interface_info.h" +#include "platform/base/ip_address.h" +#include "util/serial_delete_ptr.h" + namespace openscreen { struct IPEndpoint; @@ -14,8 +20,10 @@ class TaskRunner; namespace discovery { +struct Config; class DnsSdPublisher; class DnsSdQuerier; +class ReportingClient; // This class provides a wrapper around DnsSdQuerier and DnsSdPublisher to // allow for an embedder-overridable factory method below. @@ -23,17 +31,13 @@ class DnsSdService { public: virtual ~DnsSdService() = default; - // Creates a new DnsSdService instance, to be owned by the caller. On failure, - // return nullptr. - static std::unique_ptr<DnsSdService> Create(TaskRunner* task_runner); - // Returns the DnsSdQuerier owned by this DnsSdService. If queries are not // supported, returns nullptr. - virtual DnsSdQuerier* Querier() = 0; + virtual DnsSdQuerier* GetQuerier() = 0; // Returns the DnsSdPublisher owned by this DnsSdService. If publishing is not // supported, returns nullptr. - virtual DnsSdPublisher* Publisher() = 0; + virtual DnsSdPublisher* GetPublisher() = 0; }; } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.cc b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.cc index 546bfb7386a..c981a119537 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.cc @@ -4,22 +4,53 @@ #include "discovery/dnssd/public/dns_sd_txt_record.h" -#include "absl/strings/ascii.h" +#include <cctype> namespace openscreen { namespace discovery { +// static +bool DnsSdTxtRecord::IsValidTxtValue(const std::string& key, + const std::vector<uint8_t>& value) { + // The max length of any individual TXT record is 255 bytes. + if (key.size() + value.size() + 1 /* for equals */ > 255) { + return false; + } + + return IsKeyValid(key); +} + +// static +bool DnsSdTxtRecord::IsValidTxtValue(const std::string& key, uint8_t value) { + return IsValidTxtValue(key, std::vector<uint8_t>{value}); +} + +// static +bool DnsSdTxtRecord::IsValidTxtValue(const std::string& key, + const std::string& value) { + return IsValidTxtValue(key, std::vector<uint8_t>(value.begin(), value.end())); +} + Error DnsSdTxtRecord::SetValue(const std::string& key, - const absl::Span<const uint8_t>& value) { - if (!IsKeyValuePairValid(key, value)) { + std::vector<uint8_t> value) { + if (!IsValidTxtValue(key, value)) { return Error::Code::kParameterInvalid; } - key_value_txt_[key] = std::vector<uint8_t>(value.begin(), value.end()); + key_value_txt_[key] = std::move(value); ClearFlag(key); return Error::None(); } +Error DnsSdTxtRecord::SetValue(const std::string& key, uint8_t value) { + return SetValue(key, std::vector<uint8_t>{value}); +} + +Error DnsSdTxtRecord::SetValue(const std::string& key, + const std::string& value) { + return SetValue(key, std::vector<uint8_t>(value.begin(), value.end())); +} + Error DnsSdTxtRecord::SetFlag(const std::string& key, bool value) { if (!IsKeyValid(key)) { return Error::Code::kParameterInvalid; @@ -34,7 +65,7 @@ Error DnsSdTxtRecord::SetFlag(const std::string& key, bool value) { return Error::None(); } -ErrorOr<absl::Span<const uint8_t>> DnsSdTxtRecord::GetValue( +ErrorOr<DnsSdTxtRecord::ValueRef> DnsSdTxtRecord::GetValue( const std::string& key) const { if (!IsKeyValid(key)) { return Error::Code::kParameterInvalid; @@ -42,7 +73,7 @@ ErrorOr<absl::Span<const uint8_t>> DnsSdTxtRecord::GetValue( auto it = key_value_txt_.find(key); if (it != key_value_txt_.end()) { - return absl::Span<const uint8_t>(it->second.data(), it->second.size()); + return std::cref(it->second); } return Error::Code::kItemNotFound; @@ -74,7 +105,8 @@ Error DnsSdTxtRecord::ClearFlag(const std::string& key) { return Error::None(); } -bool DnsSdTxtRecord::IsKeyValid(const std::string& key) const { +// static +bool DnsSdTxtRecord::IsKeyValid(const std::string& key) { // The max length of any individual TXT record is 255 bytes. if (key.size() > 255) { return false; @@ -115,22 +147,23 @@ std::vector<std::vector<uint8_t>> DnsSdTxtRecord::GetData() const { return data; } -bool DnsSdTxtRecord::IsKeyValuePairValid( - const std::string& key, - const absl::Span<const uint8_t>& value) const { - // The max length of any individual TXT record is 255 bytes. - if (key.size() + value.size() + 1 /* for equals */ > 255) { - return false; - } - - return IsKeyValid(key); -} - bool DnsSdTxtRecord::CaseInsensitiveComparison::operator()( const std::string& lhs, const std::string& rhs) const { - return std::less<std::string>()(absl::AsciiStrToLower(lhs), - absl::AsciiStrToLower(rhs)); + if (lhs.size() != rhs.size()) { + return lhs < rhs; + } + + for (size_t i = 0; i < lhs.size(); i++) { + int lhs_char = tolower(lhs[i]); + int rhs_char = tolower(rhs[i]); + + if (lhs_char != rhs_char) { + return lhs_char < rhs_char; + } + } + + return false; } } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.h b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.h index 0fcfdf25878..b589d68bc35 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.h @@ -5,14 +5,12 @@ #ifndef DISCOVERY_DNSSD_PUBLIC_DNS_SD_TXT_RECORD_H_ #define DISCOVERY_DNSSD_PUBLIC_DNS_SD_TXT_RECORD_H_ +#include <functional> #include <map> #include <set> #include <string> #include <vector> -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "absl/types/span.h" #include "platform/base/error.h" namespace openscreen { @@ -20,14 +18,23 @@ namespace discovery { class DnsSdTxtRecord { public: + using ValueRef = std::reference_wrapper<const std::vector<uint8_t>>; + + // Returns whether the provided key value pair is valid for a TXT record. + static bool IsValidTxtValue(const std::string& key, + const std::vector<uint8_t>& value); + static bool IsValidTxtValue(const std::string& key, const std::string& value); + static bool IsValidTxtValue(const std::string& key, uint8_t value); + // Sets the value currently stored in this DNS-SD TXT record. Returns error // if the provided key is already set or if either the key or value is // invalid, and Error::None() otherwise. Keys are case-insensitive. Setting a // value or flag which was already set will overwrite the previous one, and // setting a value with a key which was previously associated with a flag // erases the flag's value and vice versa. - Error SetValue(const std::string& key, - const absl::Span<const uint8_t>& value); + Error SetValue(const std::string& key, std::vector<uint8_t> value); + Error SetValue(const std::string& key, uint8_t value); + Error SetValue(const std::string& key, const std::string& value); Error SetFlag(const std::string& key, bool value); // Reads the value associated with the provided key, or an error if the key @@ -36,7 +43,7 @@ class DnsSdTxtRecord { // NOTE: If GetValue is called on a key assigned to a flag, an ItemNotFound // error will be returned. If GetFlag is called on a key assigned to a value, // 'false' will be returned. - ErrorOr<absl::Span<const uint8_t>> GetValue(const std::string& key) const; + ErrorOr<ValueRef> GetValue(const std::string& key) const; ErrorOr<bool> GetFlag(const std::string& key) const; // Clears an existing TxtRecord value associated with the given key. If the @@ -58,24 +65,13 @@ class DnsSdTxtRecord { // quotes). std::vector<std::vector<uint8_t>> GetData() const; - inline bool operator==(const DnsSdTxtRecord& other) const { - return key_value_txt_ == other.key_value_txt_ && - boolean_txt_ == other.boolean_txt_; - } - - inline bool operator!=(const DnsSdTxtRecord& other) const { - return !(*this == other); - } - private: struct CaseInsensitiveComparison { bool operator()(const std::string& lhs, const std::string& rhs) const; }; // Validations for keys and (key, value) pairs. - bool IsKeyValid(const std::string& key) const; - bool IsKeyValuePairValid(const std::string& key, - const absl::Span<const uint8_t>& value) const; + static bool IsKeyValid(const std::string& key); // Set of (key, value) pairs associated with this TXT record. // NOTE: The same string name can only occur in one of key_value_txt_, @@ -88,8 +84,38 @@ class DnsSdTxtRecord { // NOTE: The same string name can only occur in one of key_value_txt_, // boolean_txt_. std::set<std::string, CaseInsensitiveComparison> boolean_txt_; + + friend bool operator<(const DnsSdTxtRecord& lhs, const DnsSdTxtRecord& rhs); }; +inline bool operator<(const DnsSdTxtRecord& lhs, const DnsSdTxtRecord& rhs) { + if (lhs.boolean_txt_ != rhs.boolean_txt_) { + return lhs.boolean_txt_ < rhs.boolean_txt_; + } + + return lhs.key_value_txt_ < rhs.key_value_txt_; +} + +inline bool operator>(const DnsSdTxtRecord& lhs, const DnsSdTxtRecord& rhs) { + return rhs < lhs; +} + +inline bool operator<=(const DnsSdTxtRecord& lhs, const DnsSdTxtRecord& rhs) { + return !(rhs > lhs); +} + +inline bool operator>=(const DnsSdTxtRecord& lhs, const DnsSdTxtRecord& rhs) { + return !(rhs < lhs); +} + +inline bool operator==(const DnsSdTxtRecord& lhs, const DnsSdTxtRecord& rhs) { + return lhs <= rhs && lhs >= rhs; +} + +inline bool operator!=(const DnsSdTxtRecord& lhs, const DnsSdTxtRecord& rhs) { + return !(lhs == rhs); +} + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record_unittest.cc index 4951bc19f42..f3239ca3d00 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record_unittest.cc @@ -13,43 +13,64 @@ namespace dnssd { TEST(TxtRecordTest, TestCaseInsensitivity) { DnsSdTxtRecord txt; - uint8_t data[]{'a', 'b', 'c'}; + std::vector<uint8_t> data{'a', 'b', 'c'}; EXPECT_TRUE(txt.SetValue("key", data).ok()); EXPECT_TRUE(txt.GetValue("KEY").is_value()); EXPECT_TRUE(txt.SetFlag("KEY2", true).ok()); - EXPECT_TRUE(txt.GetFlag("key2").is_value()); + ASSERT_TRUE(txt.GetFlag("key2").is_value()); EXPECT_TRUE(txt.GetFlag("key2").value()); } TEST(TxtRecordTest, TestEmptyValue) { DnsSdTxtRecord txt; - EXPECT_TRUE(txt.SetValue("key", {}).ok()); - EXPECT_TRUE(txt.GetValue("key").is_value()); - EXPECT_EQ(txt.GetValue("key").value().size(), size_t{0}); + EXPECT_TRUE(txt.SetValue("key", std::vector<uint8_t>{}).ok()); + ASSERT_TRUE(txt.GetValue("key").is_value()); + EXPECT_EQ(txt.GetValue("key").value().get().size(), size_t{0}); + + EXPECT_TRUE(txt.SetValue("key2", "").ok()); + ASSERT_TRUE(txt.GetValue("key2").is_value()); + EXPECT_EQ(txt.GetValue("key2").value().get().size(), size_t{0}); } TEST(TxtRecordTest, TestSetAndGetValue) { DnsSdTxtRecord txt; - uint8_t data[]{'a', 'b', 'c'}; + std::vector<uint8_t> data{'a', 'b', 'c'}; EXPECT_TRUE(txt.SetValue("key", data).ok()); - EXPECT_TRUE(txt.GetValue("key").is_value()); - EXPECT_EQ(txt.GetValue("key").value().size(), size_t{3}); - EXPECT_EQ(txt.GetValue("key").value()[0], 'a'); - EXPECT_EQ(txt.GetValue("key").value()[1], 'b'); - EXPECT_EQ(txt.GetValue("key").value()[2], 'c'); - - uint8_t data2[]{'a', 'b'}; + ASSERT_TRUE(txt.GetValue("key").is_value()); + const std::vector<uint8_t>& value = txt.GetValue("key").value(); + ASSERT_EQ(value.size(), size_t{3}); + EXPECT_EQ(value[0], 'a'); + EXPECT_EQ(value[1], 'b'); + EXPECT_EQ(value[2], 'c'); + + std::vector<uint8_t> data2{'a', 'b'}; EXPECT_TRUE(txt.SetValue("key", data2).ok()); - EXPECT_TRUE(txt.GetValue("key").is_value()); - EXPECT_EQ(txt.GetValue("key").value().size(), size_t{2}); - EXPECT_EQ(txt.GetValue("key").value()[0], 'a'); - EXPECT_EQ(txt.GetValue("key").value()[1], 'b'); + ASSERT_TRUE(txt.GetValue("key").is_value()); + const std::vector<uint8_t>& value2 = txt.GetValue("key").value(); + EXPECT_EQ(value2.size(), size_t{2}); + EXPECT_EQ(value2[0], 'a'); + EXPECT_EQ(value2[1], 'b'); + + EXPECT_TRUE(txt.SetValue("key", "abc").ok()); + ASSERT_TRUE(txt.GetValue("key").is_value()); + const std::vector<uint8_t>& value3 = txt.GetValue("key").value(); + ASSERT_EQ(value.size(), size_t{3}); + EXPECT_EQ(value3[0], 'a'); + EXPECT_EQ(value3[1], 'b'); + EXPECT_EQ(value3[2], 'c'); + + EXPECT_TRUE(txt.SetValue("key", "ab").ok()); + ASSERT_TRUE(txt.GetValue("key").is_value()); + const std::vector<uint8_t>& value4 = txt.GetValue("key").value(); + EXPECT_EQ(value4.size(), size_t{2}); + EXPECT_EQ(value4[0], 'a'); + EXPECT_EQ(value4[1], 'b'); } TEST(TxtRecordTest, TestClearValue) { DnsSdTxtRecord txt; - uint8_t data[]{'a', 'b', 'c'}; + std::vector<uint8_t> data{'a', 'b', 'c'}; EXPECT_TRUE(txt.SetValue("key", data).ok()); txt.ClearValue("key"); @@ -59,11 +80,11 @@ TEST(TxtRecordTest, TestClearValue) { TEST(TxtRecordTest, TestSetAndGetFlag) { DnsSdTxtRecord txt; EXPECT_TRUE(txt.SetFlag("key", true).ok()); - EXPECT_TRUE(txt.GetFlag("key").is_value()); + ASSERT_TRUE(txt.GetFlag("key").is_value()); EXPECT_TRUE(txt.GetFlag("key").value()); EXPECT_TRUE(txt.SetFlag("key", false).ok()); - EXPECT_TRUE(txt.GetFlag("key").is_value()); + ASSERT_TRUE(txt.GetFlag("key").is_value()); EXPECT_FALSE(txt.GetFlag("key").value()); } @@ -72,12 +93,13 @@ TEST(TxtRecordTest, TestClearFlag) { EXPECT_TRUE(txt.SetFlag("key", true).ok()); txt.ClearFlag("key"); + ASSERT_TRUE(txt.GetFlag("key").is_value()); EXPECT_FALSE(txt.GetFlag("key").value()); } TEST(TxtRecordTest, TestGettingWrongRecordTypeFails) { DnsSdTxtRecord txt; - uint8_t data[]{'a', 'b', 'c'}; + std::vector<uint8_t> data{'a', 'b', 'c'}; EXPECT_TRUE(txt.SetValue("key", data).ok()); EXPECT_TRUE(txt.SetFlag("key2", true).ok()); EXPECT_FALSE(txt.GetValue("key2").is_value()); @@ -85,14 +107,14 @@ TEST(TxtRecordTest, TestGettingWrongRecordTypeFails) { TEST(TxtRecordTest, TestClearWrongRecordTypeFails) { DnsSdTxtRecord txt; - uint8_t data[]{'a', 'b', 'c'}; + std::vector<uint8_t> data{'a', 'b', 'c'}; EXPECT_TRUE(txt.SetValue("key", data).ok()); EXPECT_TRUE(txt.SetFlag("key2", true).ok()); } TEST(TxtRecordTest, TestGetDataWorks) { DnsSdTxtRecord txt; - uint8_t data[]{'a', 'b', 'c'}; + std::vector<uint8_t> data{'a', 'b', 'c'}; EXPECT_TRUE(txt.SetValue("key", data).ok()); EXPECT_TRUE(txt.SetFlag("bool", true).ok()); std::vector<std::vector<uint8_t>> results = txt.GetData(); diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/testing/DEPS b/chromium/third_party/openscreen/src/discovery/dnssd/testing/DEPS new file mode 100644 index 00000000000..62d4cc14c82 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/dnssd/testing/DEPS @@ -0,0 +1,7 @@ +# -*- Mode: Python; -*- + +include_rules = [ + '+discovery/dnssd/public', + '+discovery/dnssd/impl', + '+discovery/mdns', +] diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.cc b/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.cc index 462f949f853..61069284184 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.cc @@ -4,6 +4,8 @@ #include "discovery/dnssd/testing/fake_dns_record_factory.h" +#include <utility> + namespace openscreen { namespace discovery { @@ -19,38 +21,29 @@ MdnsRecord FakeDnsRecordFactory::CreateFullyPopulatedSrvRecord(uint16_t port) { } // static -const IPAddress FakeDnsRecordFactory::kV4Address = IPAddress(192, 168, 0, 0); - -// static -const IPAddress FakeDnsRecordFactory::kV6Address = - IPAddress(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16); - -// static -const IPEndpoint FakeDnsRecordFactory::kV4Endpoint{ - FakeDnsRecordFactory::kV4Address, FakeDnsRecordFactory::kPortNum}; +constexpr uint16_t FakeDnsRecordFactory::kPortNum; // static -const IPEndpoint FakeDnsRecordFactory::kV6Endpoint{ - FakeDnsRecordFactory::kV6Address, FakeDnsRecordFactory::kPortNum}; +const uint8_t FakeDnsRecordFactory::kV4AddressOctets[4] = {192, 168, 0, 0}; // static -const std::string FakeDnsRecordFactory::kInstanceName = "instance"; +const uint16_t FakeDnsRecordFactory::kV6AddressHextets[8] = { + 0x0102, 0x0304, 0x0506, 0x0708, 0x090a, 0x0b0c, 0x0d0e, 0x0f10}; // static -const std::string FakeDnsRecordFactory::kServiceName = "_srv-name._udp"; +const char FakeDnsRecordFactory::kInstanceName[] = "instance"; // static -const std::string FakeDnsRecordFactory::kServiceNameProtocolPart = "_udp"; +const char FakeDnsRecordFactory::kServiceName[] = "_srv-name._udp"; // static -const std::string FakeDnsRecordFactory::kServiceNameServicePart = "_srv-name"; +const char FakeDnsRecordFactory::kServiceNameProtocolPart[] = "_udp"; // static -const std::string FakeDnsRecordFactory::kDomainName = "local"; +const char FakeDnsRecordFactory::kServiceNameServicePart[] = "_srv-name"; // static -const InstanceKey FakeDnsRecordFactory::kKey = - InstanceKey(kInstanceName, kServiceName, kDomainName); +const char FakeDnsRecordFactory::kDomainName[] = "local"; } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.h b/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.h index f8bddcdb1ff..473ec68ebc1 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.h @@ -5,11 +5,11 @@ #ifndef DISCOVERY_DNSSD_TESTING_FAKE_DNS_RECORD_FACTORY_H_ #define DISCOVERY_DNSSD_TESTING_FAKE_DNS_RECORD_FACTORY_H_ -#include <chrono> -#include <string> +#include <stdint.h> + +#include <chrono> // NOLINT #include "discovery/dnssd/impl/constants.h" -#include "discovery/dnssd/impl/instance_key.h" #include "discovery/mdns/mdns_records.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -20,16 +20,13 @@ namespace discovery { class FakeDnsRecordFactory { public: static constexpr uint16_t kPortNum = 80; - static const IPAddress kV4Address; - static const IPAddress kV6Address; - static const IPEndpoint kV4Endpoint; - static const IPEndpoint kV6Endpoint; - static const std::string kInstanceName; - static const std::string kServiceName; - static const std::string kServiceNameProtocolPart; - static const std::string kServiceNameServicePart; - static const std::string kDomainName; - static const InstanceKey kKey; + static const uint8_t kV4AddressOctets[4]; + static const uint16_t kV6AddressHextets[8]; + static const char kInstanceName[]; + static const char kServiceName[]; + static const char kServiceNameProtocolPart[]; + static const char kServiceNameServicePart[]; + static const char kDomainName[]; static MdnsRecord CreateFullyPopulatedSrvRecord(uint16_t port = kPortNum); }; diff --git a/chromium/third_party/openscreen/src/discovery/mdns/DEPS b/chromium/third_party/openscreen/src/discovery/mdns/DEPS new file mode 100644 index 00000000000..309d03f4532 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/mdns/DEPS @@ -0,0 +1,5 @@ +# -*- Mode: Python; -*- + +include_rules = [ + '+discovery/mdns/public', +] diff --git a/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/multi_answer.bin b/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/multi_answer.bin new file mode 100644 index 00000000000..24ae31b1c40 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/multi_answer.bin @@ -0,0 +1,20 @@ +0010 0000 0000 0060 0000 0000 4047 5637 +4780 f537 5627 6796 3656 40f5 4736 0750 +c6f6 3616 c600 00f2 0810 0000 0050 0052 +5047 5637 4723 90f5 3756 2767 9636 5623 +40f5 4736 0760 c6f6 3616 c623 0000 6000 +8000 0004 1080 4756 3737 1646 6647 a0f5 +3756 2767 9636 1646 560c a100 1208 1000 +0000 5000 3200 1000 2000 3030 e656 77b0 +f5e6 5677 3756 2767 9636 5640 f557 4607 +60c6 f636 16c6 4300 7047 5667 6637 3747 +a0f5 3756 2767 9636 5616 370c a100 0108 +1000 0000 5000 b140 4756 3747 9026 2756 +1646 d3e6 f677 b086 56c6 c6f6 e277 f627 +c646 7047 5637 4666 7647 a0f5 3756 2767 +1637 9636 560c a100 1008 ff00 0000 5000 +400c 8a10 100c 2d00 c108 1000 0000 5000 +0100 1000 2000 3000 4000 5000 6000 7000 +8070 4756 4666 7637 47b0 f537 5627 3637 +1667 9636 560c a100 c008 ff00 0000 5000 +200c b7 diff --git a/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/multi_question.bin b/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/multi_question.bin new file mode 100644 index 00000000000..2fffbdf6bfc --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/multi_question.bin @@ -0,0 +1,7 @@ +0010 0000 0030 0000 0000 0000 4047 5637 +4780 f537 5627 6796 3656 40f5 4736 0750 +c6f6 3616 c600 00ff 0010 5047 5637 4723 +90f5 3756 2767 9636 5623 40f5 4736 0760 +c6f6 3616 c623 0000 ff00 1050 4756 3747 +3390 f537 5627 6796 3656 3340 f557 4607 +60c6 f636 16c6 3300 00ff 0010 diff --git a/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/probe.bin b/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/probe.bin new file mode 100644 index 00000000000..5792536a579 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/probe.bin @@ -0,0 +1,10 @@ +0010 0000 0010 0000 0050 0000 4047 5637 +4780 f537 5627 6796 3656 40f5 4736 0750 +c6f6 3616 c600 00ff 08ff 0cc0 0012 0810 +0000 0050 0032 0010 0020 0030 30e6 5677 +b0f5 e656 7737 5627 6796 3656 40f5 5746 +0760 c6f6 3616 c643 000c c000 0108 1000 +0000 5000 1000 0cc0 0010 08ff 0000 0050 +0040 0c8a 1010 0cc0 00c1 0810 0000 0050 +0001 0010 0020 0030 0040 0050 0060 0070 +0080 0cc0 00c0 08ff 0000 0050 0020 0cc3 diff --git a/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/ptr_response.bin b/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/ptr_response.bin new file mode 100644 index 00000000000..19a5727536b --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/ptr_response.bin @@ -0,0 +1,8 @@ +0010 0000 0000 0010 0000 0040 80f5 3756 +2767 9636 5640 f547 3607 50c6 f636 16c6 +0000 c008 ff00 0000 5000 7040 4756 3747 +0cc0 0cb2 0012 0810 0000 0050 0080 0010 +0020 0030 0cb2 0cb2 0001 0810 0000 0050 +0010 000c b200 1008 ff00 0000 5000 400c +8a10 100c b200 c108 1000 0000 5000 0100 +1000 2000 3000 4000 5000 6000 7000 80 diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_domain_confirmed_provider.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_domain_confirmed_provider.h new file mode 100644 index 00000000000..4585ff65a18 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_domain_confirmed_provider.h @@ -0,0 +1,28 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef DISCOVERY_MDNS_MDNS_DOMAIN_CONFIRMED_PROVIDER_H_ +#define DISCOVERY_MDNS_MDNS_DOMAIN_CONFIRMED_PROVIDER_H_ + +#include "discovery/mdns/mdns_records.h" + +namespace openscreen { +namespace discovery { + +class MdnsDomainConfirmedProvider { + public: + virtual ~MdnsDomainConfirmedProvider() = default; + + // Called once the probing phase has been completed, and a DomainName has + // been confirmed. The callee is expected to register records for the + // newly confirmed name in this callback. Note that the requested name and + // the confirmed name may differ if conflict resolution has occurred. + virtual void OnDomainFound(const DomainName& requested_name, + const DomainName& confirmed_name) = 0; +}; + +} // namespace discovery +} // namespace openscreen + +#endif // DISCOVERY_MDNS_MDNS_DOMAIN_CONFIRMED_PROVIDER_H_ diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.cc new file mode 100644 index 00000000000..dc911043b0c --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.cc @@ -0,0 +1,113 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "discovery/mdns/mdns_probe.h" + +#include <utility> + +#include "discovery/mdns/mdns_random.h" +#include "discovery/mdns/mdns_sender.h" +#include "discovery/mdns/public/mdns_constants.h" +#include "platform/api/task_runner.h" +#include "platform/api/time.h" + +namespace openscreen { +namespace discovery { + +MdnsProbe::MdnsProbe(DomainName target_name, IPAddress address) + : target_name_(std::move(target_name)), + address_(std::move(address)), + address_record_(CreateAddressRecord(target_name_, address_)) {} + +MdnsProbe::~MdnsProbe() = default; + +MdnsProbeImpl::Observer::~Observer() = default; + +MdnsProbeImpl::MdnsProbeImpl(MdnsSender* sender, + MdnsReceiver* receiver, + MdnsRandom* random_delay, + TaskRunner* task_runner, + ClockNowFunctionPtr now_function, + Observer* observer, + DomainName target_name, + IPAddress address) + : MdnsProbe(std::move(target_name), std::move(address)), + random_delay_(random_delay), + task_runner_(task_runner), + now_function_(now_function), + alarm_(now_function_, task_runner_), + sender_(sender), + receiver_(receiver), + observer_(observer) { + OSP_DCHECK(sender_); + OSP_DCHECK(receiver_); + OSP_DCHECK(random_delay_); + OSP_DCHECK(task_runner_); + OSP_DCHECK(observer_); + + receiver_->AddResponseCallback(this); + alarm_.ScheduleFromNow([this]() { ProbeOnce(); }, + random_delay_->GetInitialProbeDelay()); +} + +MdnsProbeImpl::~MdnsProbeImpl() { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + Stop(); +} + +void MdnsProbeImpl::ProbeOnce() { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + if (successful_probe_queries_++ < kProbeIterationCountBeforeSuccess) { + // MdnsQuerier cannot be used, because probe queries cannot use the cache, + // so instead directly send the query through the MdnsSender. + MdnsMessage probe_query(CreateMessageId(), MessageType::Query); + MdnsQuestion probe_question(target_name(), DnsType::kANY, DnsClass::kIN, + ResponseType::kUnicast); + probe_query.AddQuestion(probe_question); + probe_query.AddAuthorityRecord(address_record()); + sender_->SendMulticast(probe_query); + + alarm_.ScheduleFromNow([this]() { ProbeOnce(); }, + kDelayBetweenProbeQueries); + } else { + Stop(); + observer_->OnProbeSuccess(this); + } +} + +void MdnsProbeImpl::Stop() { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + if (is_running_) { + alarm_.Cancel(); + receiver_->RemoveResponseCallback(this); + is_running_ = false; + } +} + +void MdnsProbeImpl::Postpone(std::chrono::seconds delay) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + successful_probe_queries_ = 0; + alarm_.Cancel(); + alarm_.ScheduleFromNow([this]() { ProbeOnce(); }, + std::chrono::duration_cast<Clock::duration>(delay)); +} + +void MdnsProbeImpl::OnMessageReceived(const MdnsMessage& message) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + OSP_DCHECK(message.type() == MessageType::Response); + + for (const auto& record : message.answers()) { + if (record.name() == target_name()) { + Stop(); + observer_->OnProbeFailure(this); + } + } +} + +} // namespace discovery +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.h new file mode 100644 index 00000000000..d918266d10d --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.h @@ -0,0 +1,129 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef DISCOVERY_MDNS_MDNS_PROBE_H_ +#define DISCOVERY_MDNS_MDNS_PROBE_H_ + +#include <vector> + +#include "discovery/mdns/mdns_receiver.h" +#include "discovery/mdns/mdns_records.h" +#include "platform/api/time.h" +#include "platform/base/ip_address.h" +#include "util/alarm.h" + +namespace openscreen { + +class TaskRunner; + +namespace discovery { + +class MdnsQuerier; +class MdnsRandom; +class MdnsSender; + +// Implements the probing method as described in RFC 6762 section 8.1 to claim a +// provided domain name. In place of the MdnsRecord(s) that will be published, a +// 'fake' mDNS record of type A or AAAA will be generated from provided endpoint +// variable with TTL 2 seconds. 0 or 1 seconds are not used because these +// constants are used as part of goodbye records, so poorly written receivers +// may handle these cases in unexpected ways. Caching of probe queries is not +// supported for mDNS probes (else, in a probe which failed, invalid records +// would be cached). If for some reason this did occur, though, it should be a +// non-issue because the probe record will expire after 2 seconds. +// +// During probe query conflict resolution, these fake records will be compared +// with the records provided by another mDNS endpoint. As 2 different mDNS +// endpoints of the same service type cannot have the same endpoint, these +// fake mDNS records should never match the real or fake records provided by +// the other mDNS endpoint, so lexicographic comparison as described in RFC +// 6762 section 8.2.1 can proceed as described. +class MdnsProbe : public MdnsReceiver::ResponseClient { + public: + // The observer class is responsible for returning the result of an ongoing + // probe query to the caller. + class Observer { + public: + virtual ~Observer(); + + // Called once the probing phase has been completed successfully. |probe| is + // expected to be stopped at the time of this call. + virtual void OnProbeSuccess(MdnsProbe* probe) = 0; + + // Called once the probing phase fails. |probe| is expected to be stopped at + // the time of this call. + virtual void OnProbeFailure(MdnsProbe* probe) = 0; + }; + + MdnsProbe(DomainName target_name, IPAddress address); + virtual ~MdnsProbe(); + + // Postpones the current probe operation by |delay|, after which the probing + // process is re-initialized. + virtual void Postpone(std::chrono::seconds delay) = 0; + + const DomainName& target_name() const { return target_name_; } + const IPAddress& address() const { return address_; } + const MdnsRecord address_record() const { return address_record_; } + + private: + const DomainName target_name_; + const IPAddress address_; + const MdnsRecord address_record_; +}; + +class MdnsProbeImpl : public MdnsProbe { + public: + // |sender|, |receiver|, |random_delay|, |task_runner|, and |observer| must + // all persist for the duration of this object's lifetime. + MdnsProbeImpl(MdnsSender* sender, + MdnsReceiver* receiver, + MdnsRandom* random_delay, + TaskRunner* task_runner, + ClockNowFunctionPtr now_function, + Observer* observer, + DomainName target_name, + IPAddress address); + MdnsProbeImpl(const MdnsProbeImpl& other) = delete; + MdnsProbeImpl(MdnsProbeImpl&& other) = delete; + ~MdnsProbeImpl() override; + + MdnsProbeImpl& operator=(const MdnsProbeImpl& other) = delete; + MdnsProbeImpl& operator=(MdnsProbeImpl&& other) = delete; + + // MdnsProbe overrides. + void Postpone(std::chrono::seconds delay) override; + + private: + friend class MdnsProbeTests; + + // Performs the probe query as described in the class-level comment. + void ProbeOnce(); + + // Stops this probe. + void Stop(); + + // MdnsReceiver::ResponseClient overrides. + void OnMessageReceived(const MdnsMessage& message) override; + + MdnsRandom* const random_delay_; + TaskRunner* const task_runner_; + ClockNowFunctionPtr now_function_; + + Alarm alarm_; + + // NOTE: Access to all below variables should only be done from the task + // runner thread. + MdnsSender* const sender_; + MdnsReceiver* const receiver_; + Observer* const observer_; + + int successful_probe_queries_ = 0; + bool is_running_ = true; +}; + +} // namespace discovery +} // namespace openscreen + +#endif // DISCOVERY_MDNS_MDNS_PROBE_H_ diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.cc new file mode 100644 index 00000000000..7ac93c72900 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.cc @@ -0,0 +1,255 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "discovery/mdns/mdns_probe_manager.h" + +#include <string> +#include <utility> + +#include "discovery/mdns/mdns_sender.h" +#include "platform/api/task_runner.h" + +namespace openscreen { +namespace discovery { +namespace { + +// The timespan by which to delay subsequent mDNS Probe queries for the same +// domain name when a simultaneous query from another host is detected, as +// described in RFC 6762 section 8.2 +constexpr std::chrono::seconds kSimultaneousProbeDelay = + std::chrono::seconds(1); + +DomainName CreateRetryDomainName(const DomainName& name, int attempt) { + OSP_DCHECK(name.labels().size()); + std::vector<std::string> labels = name.labels(); + std::string& label = labels[0]; + std::string attempts_str = std::to_string(attempt); + if (label.size() + attempts_str.size() >= kMaxLabelLength) { + label = label.substr(0, kMaxLabelLength - attempts_str.size()); + } + label.append(attempts_str); + + return DomainName(std::move(labels)); +} + +} // namespace + +MdnsProbeManager::~MdnsProbeManager() = default; + +MdnsProbeManagerImpl::MdnsProbeManagerImpl(MdnsSender* sender, + MdnsReceiver* receiver, + MdnsRandom* random_delay, + TaskRunner* task_runner, + ClockNowFunctionPtr now_function) + : sender_(sender), + receiver_(receiver), + random_delay_(random_delay), + task_runner_(task_runner), + now_function_(now_function) { + OSP_DCHECK(sender_); + OSP_DCHECK(receiver_); + OSP_DCHECK(task_runner_); + OSP_DCHECK(random_delay_); +} + +MdnsProbeManagerImpl::~MdnsProbeManagerImpl() = default; + +Error MdnsProbeManagerImpl::StartProbe(MdnsDomainConfirmedProvider* callback, + DomainName requested_name, + IPAddress address) { + // Check if |requested_name| is already being queried for. + if (FindOngoingProbe(requested_name) != ongoing_probes_.end()) { + return Error::Code::kOperationInProgress; + } + + // Check if |requested_name| is already claimed. + if (IsDomainClaimed(requested_name)) { + return Error::Code::kItemAlreadyExists; + } + + OSP_DVLOG << "Starting new mDNS Probe for domain '" + << requested_name.ToString() << "'"; + + // Begin a new probe. + auto probe = CreateProbe(requested_name, std::move(address)); + ongoing_probes_.emplace_back(std::move(probe), std::move(requested_name), + callback); + return Error::None(); +} + +Error MdnsProbeManagerImpl::StopProbe(const DomainName& requested_name) { + auto it = FindOngoingProbe(requested_name); + if (it == ongoing_probes_.end()) { + return Error::Code::kItemNotFound; + } + + ongoing_probes_.erase(it); + return Error::None(); +} + +bool MdnsProbeManagerImpl::IsDomainClaimed(const DomainName& domain) const { + return FindCompletedProbe(domain) != completed_probes_.end(); +} + +void MdnsProbeManagerImpl::RespondToProbeQuery(const MdnsMessage& message, + const IPEndpoint& src) { + OSP_DCHECK(!message.questions().empty()); + + const std::vector<MdnsQuestion>& questions = message.questions(); + MdnsMessage send_message(CreateMessageId(), MessageType::Response); + + // Iterate across all questions asked and all completed probes and add A or + // AAAA records associated with the endpoints for which the names match. + // |questions| is expected to be of size 1 and |completed_probes| should be + // small (generally size 1), so this should be fast. + for (const auto& question : questions) { + for (auto it = completed_probes_.begin(); it != completed_probes_.end(); + it++) { + if (question.name() == (*it)->target_name()) { + send_message.AddAnswer((*it)->address_record()); + break; + } + } + } + + if (!send_message.answers().empty()) { + sender_->SendMessage(send_message, src); + } else { + // If the name isn't already claimed, check to see if a probe is ongoing. If + // so, compare the address record for that probe with the one in the + // received message and resolve as specified in RFC 6762 section 8.2. + TiebreakSimultaneousProbes(message); + } +} + +void MdnsProbeManagerImpl::TiebreakSimultaneousProbes( + const MdnsMessage& message) { + OSP_DCHECK(!message.questions().empty()); + OSP_DCHECK(!message.authority_records().empty()); + + for (const auto& question : message.questions()) { + for (auto it = ongoing_probes_.begin(); it != ongoing_probes_.end(); it++) { + if (it->probe->target_name() == question.name()) { + // When a host is probing for a set of records with the same name, or a + // message is received containing multiple tiebreaker records answering + // a given probe question in the Question Section, the host's records + // and the tiebreaker records from the message are each sorted into + // order, and then compared pairwise, using the same comparison + // technique described above, until a difference is found. Because the + // probe object is guaranteed to only have the address record, only the + // lowest authority record is needed. + auto lowest_record_it = + std::min_element(message.authority_records().begin(), + message.authority_records().end()); + + // If this host finds that its own data is lexicographically later, it + // simply ignores the other host's probe. The other host will have + // receive this host's probe simultaneously, and will reject its own + // probe through this same calculation. + const MdnsRecord& probe_record = it->probe->address_record(); + if (probe_record > *lowest_record_it) { + break; + } + + // If the probe query is only of size one and the record received is + // equal to this record, then the received query is the same as what + // this probe is sending out. In this case, nothing needs to be done. + if (message.authority_records().size() == 1 && + !(probe_record < *lowest_record_it)) { + break; + } + + // At this point, one of the following must be true: + // - The query's lowest record is greater than this probe's record + // - The query's lowest record equals this probe's record but it also + // has additional records. + // In either case, the query must take priority over this probe. This + // host defers to the winning host by waiting one second, and then + // begins probing for this record again. See RFC 6762 section 8.2 for + // the logic behind waiting one second. + it->probe->Postpone(kSimultaneousProbeDelay); + break; + } + } + } +} + +void MdnsProbeManagerImpl::OnProbeSuccess(MdnsProbe* probe) { + auto it = FindOngoingProbe(probe); + if (it != ongoing_probes_.end()) { + DomainName target_name = it->probe->target_name(); + completed_probes_.push_back(std::move(it->probe)); + DomainName requested = std::move(it->requested_name); + MdnsDomainConfirmedProvider* callback = it->callback; + ongoing_probes_.erase(it); + callback->OnDomainFound(std::move(requested), std::move(target_name)); + } +} + +void MdnsProbeManagerImpl::OnProbeFailure(MdnsProbe* probe) { + auto ongoing_it = FindOngoingProbe(probe); + if (ongoing_it == ongoing_probes_.end()) { + // This means that the probe was canceled. + return; + } + + OSP_DVLOG << "Probe for domain '" + << CreateRetryDomainName(ongoing_it->requested_name, + ongoing_it->num_probes_failed) + .ToString() + << "' failed. Trying new domain..."; + + // Create a new probe with a modified domain name. + DomainName new_name = CreateRetryDomainName(ongoing_it->requested_name, + ++ongoing_it->num_probes_failed); + + // If this domain has already been claimed, skip ahead to knowing it's + // claimed. + auto completed_it = FindCompletedProbe(new_name); + if (completed_it != completed_probes_.end()) { + DomainName requested_name = std::move(ongoing_it->requested_name); + MdnsDomainConfirmedProvider* callback = ongoing_it->callback; + ongoing_probes_.erase(ongoing_it); + callback->OnDomainFound(requested_name, (*completed_it)->target_name()); + } else { + std::unique_ptr<MdnsProbe> new_probe = + CreateProbe(std::move(new_name), ongoing_it->probe->address()); + ongoing_it->probe = std::move(new_probe); + } +} + +std::vector<std::unique_ptr<MdnsProbe>>::const_iterator +MdnsProbeManagerImpl::FindCompletedProbe(const DomainName& name) const { + return std::find_if(completed_probes_.begin(), completed_probes_.end(), + [&name](const std::unique_ptr<MdnsProbe>& completed) { + return completed->target_name() == name; + }); +} + +std::vector<MdnsProbeManagerImpl::OngoingProbe>::iterator +MdnsProbeManagerImpl::FindOngoingProbe(const DomainName& name) { + return std::find_if(ongoing_probes_.begin(), ongoing_probes_.end(), + [&name](const OngoingProbe& ongoing) { + return ongoing.requested_name == name; + }); +} + +std::vector<MdnsProbeManagerImpl::OngoingProbe>::iterator +MdnsProbeManagerImpl::FindOngoingProbe(MdnsProbe* probe) { + return std::find_if(ongoing_probes_.begin(), ongoing_probes_.end(), + [&probe](const OngoingProbe& ongoing) { + return ongoing.probe.get() == probe; + }); +} + +MdnsProbeManagerImpl::OngoingProbe::OngoingProbe( + std::unique_ptr<MdnsProbe> probe, + DomainName name, + MdnsDomainConfirmedProvider* callback) + : probe(std::move(probe)), + requested_name(std::move(name)), + callback(callback) {} + +} // namespace discovery +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.h new file mode 100644 index 00000000000..6e8584ff7e9 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.h @@ -0,0 +1,150 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef DISCOVERY_MDNS_MDNS_PROBE_MANAGER_H_ +#define DISCOVERY_MDNS_MDNS_PROBE_MANAGER_H_ + +#include <memory> +#include <vector> + +#include "discovery/mdns/mdns_domain_confirmed_provider.h" +#include "discovery/mdns/mdns_probe.h" +#include "discovery/mdns/mdns_records.h" +#include "platform/base/error.h" +#include "platform/base/ip_address.h" + +namespace openscreen { + +class TaskRunner; + +namespace discovery { + +class MdnsQuerier; +class MdnsRandom; +class MdnsSender; + +// Interface for maintaining ownership of mDNS Domains. +class MdnsProbeManager { + public: + virtual ~MdnsProbeManager(); + + // Returns whether the provided domain name has been claimed as owned by this + // mDNS Probe Manager. + virtual bool IsDomainClaimed(const DomainName& domain) const = 0; + + // |message| is a message received from another host which contains a query + // from some domain. It is a considered a probe query for a specific domain if + // it contains a query for a specific domain which is answered by mDNS Records + // in the 'authority records' section of |message|. If a probe for the + // provided domain name is ongoing, an MdnsMessage is sent to the provided + // endpoint as described in RFC 6762 section 8.2 to allow for conflict + // resolution. If the requested name has already been claimed, a message to + // specify this will be sent as described in RFC 6762 section 8.1. The |src| + // argument is the address from which the message was originally sent, so that + // the response message may be sent as a unicast response. + virtual void RespondToProbeQuery(const MdnsMessage& message, + const IPEndpoint& src) = 0; +}; + +// This class is responsible for managing all ongoing probes for claiming domain +// names, as described in RFC 6762 Section 8.1's probing phase. If one such +// probe fails due to a conflict detection, this class will modify the domain +// name as described in RFC 6762 section 9 and re-initiate probing for the new +// name. +class MdnsProbeManagerImpl : public MdnsProbe::Observer, + public MdnsProbeManager { + public: + // |sender|, |receiver|, |random_delay|, and |task_runner|, must all persist + // for the duration of this object's lifetime. + MdnsProbeManagerImpl(MdnsSender* sender, + MdnsReceiver* receiver, + MdnsRandom* random_delay, + TaskRunner* task_runner, + ClockNowFunctionPtr now_function); + MdnsProbeManagerImpl(const MdnsProbeManager& other) = delete; + MdnsProbeManagerImpl(MdnsProbeManager&& other) = delete; + ~MdnsProbeManagerImpl() override; + + MdnsProbeManagerImpl& operator=(const MdnsProbeManagerImpl& other) = delete; + MdnsProbeManagerImpl& operator=(MdnsProbeManagerImpl&& other) = delete; + + // Starts probing for a valid domain name based on the given one. This may + // only be called once per MdnsProbe instance. |observer| must persist until + // a valid domain is discovered and the observer's OnDomainFound method is + // called. + // NOTE: |address| is used to generate a 'fake' address record to use for the + // probe query. See MdnsProbe::PerformProbeIteration() for further details. + Error StartProbe(MdnsDomainConfirmedProvider* callback, + DomainName requested_name, + IPAddress address); + + // Stops probing for the requested domain name. + Error StopProbe(const DomainName& requested_name); + + // MdnsDomainOwnershipManager overrides. + bool IsDomainClaimed(const DomainName& domain) const override; + void RespondToProbeQuery(const MdnsMessage& message, + const IPEndpoint& src) override; + + private: + friend class TestMdnsProbeManager; + + // Resolves simultaneous probe queries as described in RFC 6762 section 8.2. + void TiebreakSimultaneousProbes(const MdnsMessage& message); + + virtual std::unique_ptr<MdnsProbe> CreateProbe(DomainName name, + IPAddress address) { + return std::make_unique<MdnsProbeImpl>(sender_, receiver_, random_delay_, + task_runner_, now_function_, this, + std::move(name), std::move(address)); + } + + // Owns an in-progress MdnsProbe. When the probe starts, an instance of this + // struct is created. Upon successful completion of the probe, this instance + // is deleted and the owned |probe| instance is moved to |completed_probes|. + // Upon failure, the instance is updated with a new MdnsProbe object and this + // process is repeated. + struct OngoingProbe { + OngoingProbe(std::unique_ptr<MdnsProbe> probe, + DomainName name, + MdnsDomainConfirmedProvider* callback); + + // NOTE: unique_ptr objects are used to avoid issues when the container + // holding this object is resized. + std::unique_ptr<MdnsProbe> probe; + DomainName requested_name; + MdnsDomainConfirmedProvider* callback; + int num_probes_failed = 0; + }; + + // MdnsProbe::Observer overrides. + void OnProbeSuccess(MdnsProbe* probe) override; + void OnProbeFailure(MdnsProbe* probe) override; + + // Helpers to find ongoing and completed probes. + std::vector<std::unique_ptr<MdnsProbe>>::const_iterator FindCompletedProbe( + const DomainName& name) const; + std::vector<OngoingProbe>::iterator FindOngoingProbe(const DomainName& name); + std::vector<OngoingProbe>::iterator FindOngoingProbe(MdnsProbe* probe); + + MdnsSender* const sender_; + MdnsReceiver* const receiver_; + MdnsRandom* const random_delay_; + TaskRunner* const task_runner_; + ClockNowFunctionPtr now_function_; + + // The set of all probes which have completed successfully. This set is + // expected to remain small. unique_ptrs are used for storing the probes to + // avoid issues when the vector is resized. + std::vector<std::unique_ptr<MdnsProbe>> completed_probes_; + + // The set of all currently ongoing probes. This set is expected to remain + // small. + std::vector<OngoingProbe> ongoing_probes_; +}; + +} // namespace discovery +} // namespace openscreen + +#endif // DISCOVERY_MDNS_MDNS_PROBE_MANAGER_H_ diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager_unittest.cc new file mode 100644 index 00000000000..1420e550c23 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager_unittest.cc @@ -0,0 +1,355 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "discovery/mdns/mdns_probe_manager.h" + +#include <utility> + +#include "discovery/mdns/mdns_probe.h" +#include "discovery/mdns/mdns_querier.h" +#include "discovery/mdns/mdns_random.h" +#include "discovery/mdns/mdns_receiver.h" +#include "discovery/mdns/mdns_sender.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "platform/test/fake_clock.h" +#include "platform/test/fake_task_runner.h" +#include "platform/test/fake_udp_socket.h" + +using testing::_; +using testing::Invoke; +using testing::Return; +using testing::StrictMock; + +namespace openscreen { +namespace discovery { + +class MockDomainConfirmedProvider : public MdnsDomainConfirmedProvider { + public: + MOCK_METHOD2(OnDomainFound, void(const DomainName&, const DomainName&)); +}; + +class MockMdnsSender : public MdnsSender { + public: + explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {} + + MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message)); + MOCK_METHOD2(SendMessage, + Error(const MdnsMessage& message, const IPEndpoint& endpoint)); +}; + +class MockMdnsProbe : public MdnsProbe { + public: + MockMdnsProbe(DomainName target_name, IPAddress address) + : MdnsProbe(std::move(target_name), std::move(address)) {} + + MOCK_METHOD1(Postpone, void(std::chrono::seconds)); + MOCK_METHOD1(OnMessageReceived, void(const MdnsMessage&)); +}; + +class TestMdnsProbeManager : public MdnsProbeManagerImpl { + public: + using MdnsProbeManagerImpl::MdnsProbeManagerImpl; + + using MdnsProbeManagerImpl::OnProbeFailure; + using MdnsProbeManagerImpl::OnProbeSuccess; + + std::unique_ptr<MdnsProbe> CreateProbe(DomainName name, + IPAddress address) override { + return std::make_unique<StrictMock<MockMdnsProbe>>(std::move(name), + std::move(address)); + } + + StrictMock<MockMdnsProbe>* GetOngoingMockProbeByTarget( + const DomainName& target) { + const auto it = + std::find_if(ongoing_probes_.begin(), ongoing_probes_.end(), + [&target](const OngoingProbe& ongoing) { + return ongoing.probe->target_name() == target; + }); + if (it != ongoing_probes_.end()) { + return static_cast<StrictMock<MockMdnsProbe>*>(it->probe.get()); + } + return nullptr; + } + + StrictMock<MockMdnsProbe>* GetCompletedMockProbe(const DomainName& target) { + const auto it = FindCompletedProbe(target); + if (it != completed_probes_.end()) { + return static_cast<StrictMock<MockMdnsProbe>*>(it->get()); + } + return nullptr; + } + + bool HasOngoingProbe(const DomainName& target) { + return GetOngoingMockProbeByTarget(target) != nullptr; + } + + bool HasCompletedProbe(const DomainName& target) { + return GetCompletedMockProbe(target) != nullptr; + } + + size_t GetOngoingProbeCount() { return ongoing_probes_.size(); } + + size_t GetCompletedProbeCount() { return completed_probes_.size(); } +}; + +class MdnsProbeManagerTests : public testing::Test { + public: + MdnsProbeManagerTests() + : clock_(Clock::now()), + task_runner_(&clock_), + socket_(&task_runner_), + sender_(&socket_), + manager_(&sender_, + &receiver_, + &random_, + &task_runner_, + FakeClock::now) { + ExpectProbeStopped(name_); + ExpectProbeStopped(name2_); + ExpectProbeStopped(name_retry_); + } + + protected: + MdnsMessage CreateProbeQueryMessage(DomainName domain, + const IPAddress& address) { + MdnsMessage message(CreateMessageId(), MessageType::Query); + MdnsQuestion question(domain, DnsType::kANY, DnsClass::kANY, + ResponseType::kUnicast); + MdnsRecord record = CreateAddressRecord(std::move(domain), address); + message.AddQuestion(std::move(question)); + message.AddAuthorityRecord(std::move(record)); + return message; + } + + void ExpectProbeStopped(const DomainName& name) { + EXPECT_FALSE(manager_.HasOngoingProbe(name)); + EXPECT_FALSE(manager_.HasCompletedProbe(name)); + EXPECT_FALSE(manager_.IsDomainClaimed(name)); + } + + StrictMock<MockMdnsProbe>* ExpectProbeOngoing(const DomainName& name) { + // Get around limitations of using an assert in a function with a return + // value. + auto validate = [this, &name]() { + ASSERT_TRUE(manager_.HasOngoingProbe(name)); + EXPECT_FALSE(manager_.HasCompletedProbe(name)); + EXPECT_FALSE(manager_.IsDomainClaimed(name)); + }; + validate(); + + return manager_.GetOngoingMockProbeByTarget(name); + } + + StrictMock<MockMdnsProbe>* ExpectProbeCompleted(const DomainName& name) { + // Get around limitations of using an assert in a function with a return + // value. + auto validate = [this, &name]() { + EXPECT_FALSE(manager_.HasOngoingProbe(name)); + ASSERT_TRUE(manager_.HasCompletedProbe(name)); + EXPECT_TRUE(manager_.IsDomainClaimed(name)); + }; + validate(); + + return manager_.GetCompletedMockProbe(name); + } + + StrictMock<MockMdnsProbe>* SetUpCompletedProbe(const DomainName& name, + const IPAddress& address) { + EXPECT_TRUE(manager_.StartProbe(&callback_, name, address).ok()); + EXPECT_CALL(callback_, OnDomainFound(name, name)); + StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name); + manager_.OnProbeSuccess(ongoing_probe); + ExpectProbeCompleted(name); + testing::Mock::VerifyAndClearExpectations(ongoing_probe); + + return ongoing_probe; + } + + FakeClock clock_; + FakeTaskRunner task_runner_; + FakeUdpSocket socket_; + StrictMock<MockMdnsSender> sender_; + MdnsReceiver receiver_; + MdnsRandom random_; + StrictMock<TestMdnsProbeManager> manager_; + MockDomainConfirmedProvider callback_; + + const DomainName name_{"test", "_googlecast", "_tcp", "local"}; + const DomainName name_retry_{"test1", "_googlecast", "_tcp", "local"}; + const DomainName name2_{"test2", "_googlecast", "_tcp", "local"}; + + // When used to create address records A, B, C, A > B because comparison of + // the rdata in each results in the comparison of endpoints, for which + // address_b_ < address_a_. A < C because A is DnsType kA with value 1 and + // C is DnsType kAAAA with value 28. + const IPAddress address_a_{192, 168, 0, 0}; + const IPAddress address_b_{190, 160, 0, 0}; + const IPAddress address_c_{0x0102, 0x0304, 0x0506, 0x0708, + 0x090a, 0x0b0c, 0x0d0e, 0x0f10}; + const IPEndpoint endpoint_{{192, 168, 0, 0}, 80}; +}; + +TEST_F(MdnsProbeManagerTests, StartProbeBeginsProbeWhenNoneExistsOnly) { + EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok()); + ExpectProbeOngoing(name_); + EXPECT_FALSE(manager_.IsDomainClaimed(name2_)); + + EXPECT_FALSE(manager_.StartProbe(&callback_, name_, address_a_).ok()); + + EXPECT_CALL(callback_, OnDomainFound(name_, name_)); + StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_); + manager_.OnProbeSuccess(ongoing_probe); + EXPECT_FALSE(manager_.IsDomainClaimed(name2_)); + testing::Mock::VerifyAndClearExpectations(ongoing_probe); + + EXPECT_FALSE(manager_.StartProbe(&callback_, name_, address_a_).ok()); + + StrictMock<MockMdnsProbe>* completed_probe = ExpectProbeCompleted(name_); + EXPECT_EQ(ongoing_probe, completed_probe); + EXPECT_FALSE(manager_.IsDomainClaimed(name2_)); +} + +TEST_F(MdnsProbeManagerTests, StopProbeChangesOngoingProbesOnly) { + EXPECT_FALSE(manager_.StopProbe(name_).ok()); + ExpectProbeStopped(name_); + + EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok()); + ExpectProbeOngoing(name_); + + EXPECT_TRUE(manager_.StopProbe(name_).ok()); + ExpectProbeStopped(name_); + + SetUpCompletedProbe(name_, address_a_); + + EXPECT_FALSE(manager_.StopProbe(name_).ok()); + ExpectProbeCompleted(name_); +} + +TEST_F(MdnsProbeManagerTests, RespondToProbeQuerySendsNothingOnUnownedDomain) { + const MdnsMessage query = CreateProbeQueryMessage(name_, address_c_); + manager_.RespondToProbeQuery(query, endpoint_); +} + +TEST_F(MdnsProbeManagerTests, RespondToProbeQueryWorksForCompletedProbes) { + SetUpCompletedProbe(name_, address_a_); + + const MdnsMessage query = CreateProbeQueryMessage(name_, address_c_); + EXPECT_CALL(sender_, SendMessage(_, endpoint_)) + .WillOnce([this](const MdnsMessage& message, + const IPEndpoint& endpoint) -> Error { + EXPECT_EQ(message.answers().size(), size_t{1}); + EXPECT_EQ(message.answers()[0].dns_type(), DnsType::kA); + EXPECT_EQ(message.answers()[0].name(), this->name_); + return Error::None(); + }); + manager_.RespondToProbeQuery(query, endpoint_); +} + +TEST_F(MdnsProbeManagerTests, TiebreakProbeQueryWorksForSingleRecordQueries) { + EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok()); + StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_); + + // If the probe message received matches the currently running probe, do + // nothing. + MdnsMessage query = CreateProbeQueryMessage(name_, address_a_); + manager_.RespondToProbeQuery(query, endpoint_); + + // If the probe message received is less than the ongoing probe, ignore the + // incoming probe. + query = CreateProbeQueryMessage(name_, address_b_); + manager_.RespondToProbeQuery(query, endpoint_); + + // If the probe message received is greater than the ongoing probe, postpone + // the currently running probe. + query = CreateProbeQueryMessage(name_, address_c_); + EXPECT_CALL(*ongoing_probe, Postpone(_)).Times(1); + manager_.RespondToProbeQuery(query, endpoint_); +} + +TEST_F(MdnsProbeManagerTests, TiebreakProbeQueryWorksForMultiRecordQueries) { + EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok()); + StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_); + + // For the below tests, note that if records A, B, C are generated from + // addresses |address_a_|, |address_b_|, and |address_c_| respectively, + // then B < A < C. + // + // If the received records have one record less than the tested record, they + // are sorted and the lowest record is compared. + MdnsMessage query = CreateProbeQueryMessage(name_, address_b_); + query.AddAuthorityRecord(CreateAddressRecord(name_, address_c_)); + manager_.RespondToProbeQuery(query, endpoint_); + + query = CreateProbeQueryMessage(name_, address_c_); + query.AddAuthorityRecord(CreateAddressRecord(name_, address_b_)); + manager_.RespondToProbeQuery(query, endpoint_); + + query = CreateProbeQueryMessage(name_, address_a_); + query.AddAuthorityRecord(CreateAddressRecord(name_, address_b_)); + query.AddAuthorityRecord(CreateAddressRecord(name_, address_c_)); + manager_.RespondToProbeQuery(query, endpoint_); + + // If the probe message received has the same first record as what's being + // compared and the query has more records, the query wins. + query = CreateProbeQueryMessage(name_, address_a_); + query.AddAuthorityRecord(CreateAddressRecord(name_, address_c_)); + EXPECT_CALL(*ongoing_probe, Postpone(_)).Times(1); + manager_.RespondToProbeQuery(query, endpoint_); + testing::Mock::VerifyAndClearExpectations(ongoing_probe); + + query = CreateProbeQueryMessage(name_, address_c_); + query.AddAuthorityRecord(CreateAddressRecord(name_, address_a_)); + EXPECT_CALL(*ongoing_probe, Postpone(_)).Times(1); + manager_.RespondToProbeQuery(query, endpoint_); +} + +TEST_F(MdnsProbeManagerTests, ProbeSuccessAfterProbeRemovalNoOp) { + EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok()); + StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_); + EXPECT_TRUE(manager_.StopProbe(name_).ok()); + manager_.OnProbeSuccess(ongoing_probe); +} + +TEST_F(MdnsProbeManagerTests, ProbeFailureAfterProbeRemovalNoOp) { + EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok()); + StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_); + EXPECT_TRUE(manager_.StopProbe(name_).ok()); + manager_.OnProbeFailure(ongoing_probe); +} + +TEST_F(MdnsProbeManagerTests, ProbeFailureCallsCallbackWhenAlreadyClaimed) { + // This test first starts a probe with domain |name_retry_| so that when + // probe with domain |name_| fails, the newly generated domain with equal + // |name_retry_|. + StrictMock<MockMdnsProbe>* ongoing_probe = + SetUpCompletedProbe(name_retry_, address_a_); + + // Because |name_retry_| has already succeeded, the retry logic should skip + // over re-querying for |name_retry_| and jump right to success. + EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok()); + ongoing_probe = ExpectProbeOngoing(name_); + EXPECT_CALL(callback_, OnDomainFound(name_, name_retry_)); + manager_.OnProbeFailure(ongoing_probe); + ExpectProbeStopped(name_); + ExpectProbeCompleted(name_retry_); +} + +TEST_F(MdnsProbeManagerTests, ProbeFailureCreatesNewProbeIfNameUnclaimed) { + EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok()); + StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_); + manager_.OnProbeFailure(ongoing_probe); + ExpectProbeStopped(name_); + ongoing_probe = ExpectProbeOngoing(name_retry_); + EXPECT_EQ(ongoing_probe->target_name(), name_retry_); + + EXPECT_CALL(callback_, OnDomainFound(name_, name_retry_)); + manager_.OnProbeSuccess(ongoing_probe); + ExpectProbeCompleted(name_retry_); + ExpectProbeStopped(name_); +} + +} // namespace discovery +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_unittest.cc new file mode 100644 index 00000000000..d69e45b4b09 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_unittest.cc @@ -0,0 +1,128 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#include "discovery/mdns/mdns_probe.h" + +#include <memory> +#include <utility> + +#include "discovery/mdns/mdns_probe_manager.h" +#include "discovery/mdns/mdns_querier.h" +#include "discovery/mdns/mdns_random.h" +#include "discovery/mdns/mdns_receiver.h" +#include "discovery/mdns/mdns_sender.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "platform/test/fake_clock.h" +#include "platform/test/fake_task_runner.h" +#include "platform/test/fake_udp_socket.h" + +using testing::_; +using testing::Invoke; +using testing::Return; +using testing::StrictMock; + +namespace openscreen { +namespace discovery { + +class MockMdnsSender : public MdnsSender { + public: + explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {} + MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message)); + MOCK_METHOD2(SendMessage, + Error(const MdnsMessage& message, const IPEndpoint& endpoint)); +}; + +class MockObserver : public MdnsProbeImpl::Observer { + public: + MOCK_METHOD1(OnProbeSuccess, void(MdnsProbe*)); + MOCK_METHOD1(OnProbeFailure, void(MdnsProbe*)); +}; + +class MdnsProbeTests : public testing::Test { + public: + MdnsProbeTests() + : clock_(Clock::now()), + task_runner_(&clock_), + socket_(&task_runner_), + sender_(&socket_) { + EXPECT_EQ(task_runner_.delayed_task_count(), 0); + probe_ = CreateProbe(); + EXPECT_EQ(task_runner_.delayed_task_count(), 1); + } + + protected: + std::unique_ptr<MdnsProbeImpl> CreateProbe() { + return std::make_unique<MdnsProbeImpl>(&sender_, &receiver_, &random_, + &task_runner_, FakeClock::now, + &observer_, name_, address_v4_); + } + + MdnsMessage CreateMessage(const DomainName& domain) { + MdnsMessage message(0, MessageType::Response); + SrvRecordRdata rdata(0, 0, 80, domain); + MdnsRecord record(std::move(domain), DnsType::kSRV, DnsClass::kIN, + RecordType::kUnique, std::chrono::seconds(1), + std::move(rdata)); + message.AddAnswer(record); + return message; + } + + void OnMessageReceived(const MdnsMessage& message) { + probe_->OnMessageReceived(message); + } + + FakeClock clock_; + FakeTaskRunner task_runner_; + FakeUdpSocket socket_; + StrictMock<MockMdnsSender> sender_; + MdnsReceiver receiver_; + MdnsRandom random_; + StrictMock<MockObserver> observer_; + + std::unique_ptr<MdnsProbeImpl> probe_; + + const DomainName name_{"test", "_googlecast", "_tcp", "local"}; + const DomainName name2_{"test2", "_googlecast", "_tcp", "local"}; + + const IPAddress address_v4_{192, 168, 0, 0}; + const IPEndpoint endpoint_v4_{address_v4_, 80}; +}; + +TEST_F(MdnsProbeTests, TestNoCancelationFlow) { + EXPECT_CALL(sender_, SendMulticast(_)); + clock_.Advance(kDelayBetweenProbeQueries); + EXPECT_EQ(task_runner_.delayed_task_count(), 1); + testing::Mock::VerifyAndClearExpectations(&sender_); + + EXPECT_CALL(sender_, SendMulticast(_)); + clock_.Advance(kDelayBetweenProbeQueries); + EXPECT_EQ(task_runner_.delayed_task_count(), 1); + testing::Mock::VerifyAndClearExpectations(&sender_); + + EXPECT_CALL(sender_, SendMulticast(_)); + clock_.Advance(kDelayBetweenProbeQueries); + EXPECT_EQ(task_runner_.delayed_task_count(), 1); + testing::Mock::VerifyAndClearExpectations(&sender_); + + EXPECT_CALL(observer_, OnProbeSuccess(probe_.get())).Times(1); + clock_.Advance(kDelayBetweenProbeQueries); + EXPECT_EQ(task_runner_.delayed_task_count(), 0); +} + +TEST_F(MdnsProbeTests, CancelationWhenMatchingMessageReceived) { + EXPECT_CALL(observer_, OnProbeFailure(probe_.get())).Times(1); + OnMessageReceived(CreateMessage(name_)); +} + +TEST_F(MdnsProbeTests, TestNoCancelationOnUnrelatedMessages) { + OnMessageReceived(CreateMessage(name2_)); + + EXPECT_CALL(sender_, SendMulticast(_)); + clock_.Advance(kDelayBetweenProbeQueries); + EXPECT_EQ(task_runner_.delayed_task_count(), 1); + testing::Mock::VerifyAndClearExpectations(&sender_); +} + +} // namespace discovery +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.cc index e1ebde80f19..51aa30e7dfe 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.cc @@ -4,61 +4,375 @@ #include "discovery/mdns/mdns_publisher.h" +#include <chrono> +#include <cmath> + +#include "discovery/common/config.h" +#include "discovery/mdns/mdns_probe_manager.h" +#include "discovery/mdns/mdns_records.h" #include "discovery/mdns/mdns_sender.h" #include "platform/api/task_runner.h" +#include "platform/base/trivial_clock_traits.h" namespace openscreen { namespace discovery { +namespace { + +// Minimum delay between announcements of a given record in seconds. +constexpr std::chrono::seconds kMinAnnounceDelay{1}; + +// Intervals between successive announcements must increase by at least a +// factor of 2. +constexpr int kIntervalIncreaseFactor = 2; + +// TTL for a goodbye record in seconds. This constant is called out in RFC 6762 +// section 10.1. +constexpr std::chrono::seconds kGoodbyeTtl{0}; + +// Timespan between sending batches of announcement and goodbye records, in +// microseconds. +constexpr Clock::duration kDelayBetweenBatchedRecords = + std::chrono::milliseconds(20); + +inline MdnsRecord CreateGoodbyeRecord(const MdnsRecord& record) { + if (record.ttl() == kGoodbyeTtl) { + return record; + } + return MdnsRecord(record.name(), record.dns_type(), record.dns_class(), + record.record_type(), kGoodbyeTtl, record.rdata()); +} + +inline void ValidateRecord(const MdnsRecord& record) { + OSP_DCHECK(record.dns_type() != DnsType::kANY); + OSP_DCHECK(record.dns_class() != DnsClass::kANY); +} + +} // namespace -MdnsPublisher::MdnsPublisher(MdnsQuerier* querier, - MdnsSender* sender, - platform::TaskRunner* task_runner, - MdnsRandom* random_delay) - : querier_(querier), - sender_(sender), +MdnsPublisher::MdnsPublisher(MdnsSender* sender, + MdnsProbeManager* ownership_manager, + TaskRunner* task_runner, + ClockNowFunctionPtr now_function, + const Config& config) + : sender_(sender), + ownership_manager_(ownership_manager), task_runner_(task_runner), - random_delay_(random_delay) { - OSP_DCHECK(querier_); + now_function_(now_function), + max_announcement_attempts_(config.new_record_announcement_count) { + OSP_DCHECK(ownership_manager_); OSP_DCHECK(sender_); OSP_DCHECK(task_runner_); - OSP_DCHECK(random_delay_); + OSP_DCHECK_GE(max_announcement_attempts_, 0); } -MdnsPublisher::~MdnsPublisher() = default; +MdnsPublisher::~MdnsPublisher() { + if (batch_records_alarm_.has_value()) { + batch_records_alarm_.value().Cancel(); + ProcessRecordQueue(); + } +} Error MdnsPublisher::RegisterRecord(const MdnsRecord& record) { - // TODO(rwkeane): Implement this method. + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + if (record.dns_type() == DnsType::kNSEC) { + return Error::Code::kParameterInvalid; + } + ValidateRecord(record); + + if (!IsRecordNameClaimed(record)) { + return Error::Code::kParameterInvalid; + } + + const DomainName& name = record.name(); + auto it = records_.emplace(name, std::vector<RecordAnnouncerPtr>{}).first; + for (const RecordAnnouncerPtr& publisher : it->second) { + if (publisher->record() == record) { + return Error::Code::kItemAlreadyExists; + } + } + + OSP_DVLOG << "Registering record of type '" << record.dns_type() << "'"; + + it->second.push_back(CreateAnnouncer(record)); + return Error::None(); } +Error MdnsPublisher::UnregisterRecord(const MdnsRecord& record) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + if (record.dns_type() == DnsType::kNSEC) { + return Error::Code::kParameterInvalid; + } + ValidateRecord(record); + + OSP_DVLOG << "Unregistering record of type '" << record.dns_type() << "'"; + + return RemoveRecord(record, true); +} + Error MdnsPublisher::UpdateRegisteredRecord(const MdnsRecord& old_record, const MdnsRecord& new_record) { - // TODO(rwkeane): Implement this method. - return Error::None(); -} + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); -Error MdnsPublisher::UnregisterRecord(const MdnsRecord& record) { - // TODO(rwkeane): Implement this method. - return Error::None(); + if (old_record.dns_type() == DnsType::kNSEC) { + return Error::Code::kParameterInvalid; + } + + if (old_record.dns_type() == DnsType::kPTR) { + return Error::Code::kParameterInvalid; + } + + // Check that the old record and new record are compatible. + if (old_record.name() != new_record.name() || + old_record.dns_type() != new_record.dns_type() || + old_record.dns_class() != new_record.dns_class() || + old_record.record_type() != new_record.record_type()) { + return Error::Code::kParameterInvalid; + } + + OSP_DVLOG << "Updating record of type '" << new_record.dns_type() << "'"; + + // Remove the old record. Per RFC 6762 section 8.4, a goodbye message will not + // be sent, as all records which can be removed here are unique records, which + // will be overwritten during the announcement phase when the updated record + // is re-registered due to the cache-flush-bit's presence. + const Error remove_result = RemoveRecord(old_record, false); + if (!remove_result.ok()) { + return remove_result; + } + + // Register the new record. + return RegisterRecord(new_record); } -bool MdnsPublisher::IsExclusiveOwner(const DomainName& name) { - // TODO(rwkeane): Implement this method. - return false; +size_t MdnsPublisher::GetRecordCount() const { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + size_t count = 0; + for (const auto& pair : records_) { + count += pair.second.size(); + } + + return count; } bool MdnsPublisher::HasRecords(const DomainName& name, DnsType type, DnsClass clazz) { - // TODO(rwkeane): Implement this method. - return false; + return !GetRecords(name, type, clazz).empty(); } + std::vector<MdnsRecord::ConstRef> MdnsPublisher::GetRecords( const DomainName& name, DnsType type, DnsClass clazz) { - // TODO(rwkeane): Implement this method. - return std::vector<MdnsRecord::ConstRef>(); + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + std::vector<MdnsRecord::ConstRef> records; + auto it = records_.find(name); + if (it != records_.end()) { + for (const RecordAnnouncerPtr& announcer : it->second) { + OSP_DCHECK(announcer.get()); + const DnsType record_dns_type = announcer->record().dns_type(); + const DnsClass record_dns_class = announcer->record().dns_class(); + if ((type == DnsType::kANY || type == record_dns_type) && + (clazz == DnsClass::kANY || clazz == record_dns_class)) { + records.push_back(announcer->record()); + } + } + } + + return records; +} + +std::vector<MdnsRecord::ConstRef> MdnsPublisher::GetPtrRecords(DnsClass clazz) { + std::vector<MdnsRecord::ConstRef> records; + + // There should be few records associated with any given domain name, so it is + // simpler and less error prone to iterate across all records than to check + // the domain name against format '[^.]+\.(_tcp)|(_udp)\..*'' + for (auto it = records_.begin(); it != records_.end(); it++) { + for (const RecordAnnouncerPtr& announcer : it->second) { + OSP_DCHECK(announcer.get()); + const DnsType record_dns_type = announcer->record().dns_type(); + if (record_dns_type != DnsType::kPTR) { + continue; + } + + const DnsClass record_dns_class = announcer->record().dns_class(); + if ((clazz == DnsClass::kANY || clazz == record_dns_class)) { + records.push_back(announcer->record()); + } + } + } + + return records; +} + +Error MdnsPublisher::RemoveRecord(const MdnsRecord& record, + bool should_announce_deletion) { + const DomainName& name = record.name(); + + // Check for the domain and fail if it's not found. + const auto it = records_.find(name); + if (it == records_.end()) { + return Error::Code::kItemNotFound; + } + + // Check for the record to be removed. + const auto records_it = + std::find_if(it->second.begin(), it->second.end(), + [&record](const RecordAnnouncerPtr& publisher) { + return publisher->record() == record; + }); + if (records_it == it->second.end()) { + return Error::Code::kItemNotFound; + } + if (!should_announce_deletion) { + (*records_it)->DisableGoodbyeMessageTransmission(); + } + + it->second.erase(records_it); + if (it->second.empty()) { + records_.erase(it); + } + + return Error::None(); +} + +bool MdnsPublisher::IsRecordNameClaimed(const MdnsRecord& record) const { + const DomainName& name = + record.dns_type() == DnsType::kPTR + ? absl::get<PtrRecordRdata>(record.rdata()).ptr_domain() + : record.name(); + return ownership_manager_->IsDomainClaimed(name); +} + +MdnsPublisher::RecordAnnouncer::RecordAnnouncer( + MdnsRecord record, + MdnsPublisher* publisher, + TaskRunner* task_runner, + ClockNowFunctionPtr now_function, + int target_announcement_attempts) + : publisher_(publisher), + task_runner_(task_runner), + now_function_(now_function), + record_(std::move(record)), + alarm_(now_function_, task_runner_), + target_announcement_attempts_(target_announcement_attempts) { + OSP_DCHECK(publisher_); + OSP_DCHECK(task_runner_); + OSP_DCHECK(record_.ttl() != Clock::duration::zero()); + + QueueAnnouncement(); +} + +MdnsPublisher::RecordAnnouncer::~RecordAnnouncer() { + alarm_.Cancel(); + if (should_send_goodbye_message_) { + QueueGoodbye(); + } +} + +void MdnsPublisher::RecordAnnouncer::QueueGoodbye() { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + publisher_->QueueRecord(CreateGoodbyeRecord(record_)); +} + +void MdnsPublisher::RecordAnnouncer::QueueAnnouncement() { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + if (attempts_ >= target_announcement_attempts_) { + return; + } + + publisher_->QueueRecord(record_); + + const Clock::duration new_delay = GetNextAnnounceDelay(); + attempts_++; + alarm_.ScheduleFromNow([this]() { QueueAnnouncement(); }, new_delay); +} + +void MdnsPublisher::QueueRecord(MdnsRecord record) { + if (!batch_records_alarm_.has_value()) { + OSP_DCHECK(records_to_send_.empty()); + batch_records_alarm_.emplace(now_function_, task_runner_); + batch_records_alarm_.value().ScheduleFromNow( + [this]() { ProcessRecordQueue(); }, kDelayBetweenBatchedRecords); + } + + // Check that we aren't announcing and goodbye'ing a record in the same batch. + // We expect to be sending no more than 5 records at a time, so don't worry + // about iterating across this vector for each insert. + auto goodbye = CreateGoodbyeRecord(record); + auto existing_record_it = + std::find_if(records_to_send_.begin(), records_to_send_.end(), + [&goodbye](const MdnsRecord& record) { + return goodbye == CreateGoodbyeRecord(record); + }); + + // If we didn't find it, simply add it to the queue. Else, only send the + // goodbye record. + if (existing_record_it == records_to_send_.end()) { + records_to_send_.push_back(std::move(record)); + } else if (*existing_record_it == goodbye) { + // This means that the goodbye record is already queued to be sent. This + // means that there is no reason to also announce it, so exit early. + return; + } else if (record == goodbye) { + // This means that we are sending a goodbye record right as it would also + // be announced. Skip the announcement since the record is being + // unregistered. + *existing_record_it = std::move(record); + } else if (record == *existing_record_it) { + // This case shouldn't happen, but there is no work to do if it does. Log + // to surface that something weird is going on. + OSP_LOG_INFO << "Same record being announced multiple times."; + } else { + // This case should never occur. Support it just in case, but log to + // surface that something weird is happening. + OSP_LOG_INFO << "Updating the same record multiple times with multiple " + "TTL values."; + } +} + +void MdnsPublisher::ProcessRecordQueue() { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + if (records_to_send_.empty()) { + return; + } + + MdnsMessage message(CreateMessageId(), MessageType::Response); + for (auto it = records_to_send_.begin(); it != records_to_send_.end();) { + if (message.CanAddRecord(*it)) { + message.AddAnswer(std::move(*it++)); + } else if (message.answers().empty()) { + // This case should never happen, because it means a record is too large + // to fit into its own message. + OSP_LOG << "Encountered unreasonably large message in cache. Skipping " + << "known answer in suppressions..."; + it++; + } else { + sender_->SendMulticast(message); + message = MdnsMessage(CreateMessageId(), MessageType::Response); + } + } + + if (!message.answers().empty()) { + sender_->SendMulticast(message); + } + + batch_records_alarm_ = absl::nullopt; + records_to_send_.clear(); +} + +Clock::duration MdnsPublisher::RecordAnnouncer::GetNextAnnounceDelay() { + return std::chrono::duration_cast<Clock::duration>( + kMinAnnounceDelay * pow(kIntervalIncreaseFactor, attempts_)); } } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.h index 022a94df90f..4b418312104 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.h @@ -5,60 +5,193 @@ #ifndef DISCOVERY_MDNS_MDNS_PUBLISHER_H_ #define DISCOVERY_MDNS_MDNS_PUBLISHER_H_ +#include <map> +#include <memory> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" #include "discovery/mdns/mdns_records.h" #include "discovery/mdns/mdns_responder.h" +#include "util/alarm.h" namespace openscreen { -namespace platform { class TaskRunner; -} // namespace platform namespace discovery { -// TODO(rwkeane): Add API for claiming a DomainName as described in RFC 6762 -// Section 8.1's probing phase. +struct Config; +class MdnsProbeManager; +class MdnsRandom; +class MdnsSender; +class MdnsQuerier; + +// This class is responsible for both tracking what records have been registered +// to mDNS as well as publishing new mDNS records to the network. +// When a new record is published, it will be announced 8 times, starting at an +// interval of 1 second, with the interval doubling each successive +// announcement. This same announcement process is followed when an existing +// record is updated. When it is removed, a Goodbye message must be sent if the +// record is unique. +// +// Prior to publishing a record, the domain name for this service instance must +// be claimed using the ClaimExclusiveOwnership() function. This function probes +// the network to determine whether the chosen name exists, modifying the +// chosen name as described in RFC 6762 if a collision is found. +// +// NOTE: All MdnsPublisher instances must be run on the same task runner thread, +// due to the shared announce + goodbye message queue. class MdnsPublisher : public MdnsResponder::RecordHandler { public: - // |querier|, |sender|, |task_runner|, and |random_delay| must all persist for - // the duration of this object's lifetime - MdnsPublisher(MdnsQuerier* querier, - MdnsSender* sender, - platform::TaskRunner* task_runner, - MdnsRandom* random_delay); - ~MdnsPublisher(); + // |sender|, |ownership_manager|, and |task_runner| must all persist for the + // duration of this object's lifetime + MdnsPublisher(MdnsSender* sender, + MdnsProbeManager* ownership_manager, + TaskRunner* task_runner, + ClockNowFunctionPtr now_function, + const Config& config); + ~MdnsPublisher() override; // Registers a new mDNS record for advertisement by this service. For A, AAAA, // SRV, and TXT records, the domain name must have already been claimed by the // ClaimExclusiveOwnership() method and for PTR records the name being pointed - // to must have been claimed in the same fashion. + // to must have been claimed in the same fashion, but the domain name in the + // top-level MdnsRecord entity does not. + // NOTE: NSEC records cannot be registered, and doing so will return an error. Error RegisterRecord(const MdnsRecord& record); // Updates the existing record with name matching the name of the new record. + // NOTE: This method is not valid for PTR records. Error UpdateRegisteredRecord(const MdnsRecord& old_record, const MdnsRecord& new_record); - // Stops advertising the provided record. If no more records with the provided - // name are bing advertised after this call's completion, then ownership of - // the name is released. + // Stops advertising the provided record. Error UnregisterRecord(const MdnsRecord& record); + // Returns the total number of records currently registered; + size_t GetRecordCount() const; + OSP_DISALLOW_COPY_AND_ASSIGN(MdnsPublisher); private: + // Class responsible for sending announcement and goodbye messages for + // MdnsRecord instances when they are published, updated, or unpublished. The + // announcement messages will be sent |target_announcement_attempts| times, + // first at an interval of 1 second apart, and then with delay increasing by a + // factor of 2 with each successive announcement. + // NOTE: |publisher| must be the MdnsPublisher instance from which this + // instance was created. + class RecordAnnouncer { + public: + RecordAnnouncer(MdnsRecord record, + MdnsPublisher* publisher, + TaskRunner* task_runner, + ClockNowFunctionPtr now_function, + int max_announcement_attempts); + RecordAnnouncer(const RecordAnnouncer& other) = delete; + RecordAnnouncer(RecordAnnouncer&& other) noexcept = delete; + ~RecordAnnouncer(); + + RecordAnnouncer& operator=(const RecordAnnouncer& other) = delete; + RecordAnnouncer& operator=(RecordAnnouncer&& other) noexcept = delete; + + const MdnsRecord& record() const { return record_; } + + // Specifies whether goodbye messages should not be sent when this announcer + // is destroyed. This should only be called as part of the 'Update' flow, + // for records which should not send this message. + void DisableGoodbyeMessageTransmission() { + should_send_goodbye_message_ = false; + } + + private: + // Gets the delay required before the next announcement message is sent. + Clock::duration GetNextAnnounceDelay(); + + // When announce + goodbye messages are ready to be sent, they are queued + // up. Every 20ms, if there are any messages to send out, these records are + // batched up and sent out. + void QueueGoodbye(); + void QueueAnnouncement(); + + MdnsPublisher* const publisher_; + TaskRunner* const task_runner_; + const ClockNowFunctionPtr now_function_; + + // Whether or not goodbye messages should be sent. + bool should_send_goodbye_message_ = true; + + // Record to send. + const MdnsRecord record_; + + // Alarm used to cancel future resend attempts if this object is deleted. + Alarm alarm_; + + // Number of attempts at sending this record which have occurred so far. + int attempts_ = 0; + + // Number of times to announce a newly published record. + const int target_announcement_attempts_; + }; + + using RecordAnnouncerPtr = std::unique_ptr<RecordAnnouncer>; + + friend class MdnsPublisherTesting; + + // Creates a new published from the provided record. + RecordAnnouncerPtr CreateAnnouncer(MdnsRecord record) { + return std::make_unique<RecordAnnouncer>(std::move(record), this, + task_runner_, now_function_, + max_announcement_attempts_); + } + + // Removes the given record from the |records_| map. A goodbye record is only + // sent for this removal if |should_announce_deletion| is true. + Error RemoveRecord(const MdnsRecord& record, bool should_announce_deletion); + + // Returns whether the provided record has had its name claimed so far. + bool IsRecordNameClaimed(const MdnsRecord& record) const; + + // Processes the |records_to_send_| queue, sending out the records together as + // a single MdnsMessage. + void ProcessRecordQueue(); + + // Adds a new record to the |records_to_send_| queue or ensures that the + // record with lower ttl is present if it differs from an existing record by + // only that one field. + void QueueRecord(MdnsRecord record); + // MdnsResponder::RecordHandler overrides. - bool IsExclusiveOwner(const DomainName& name) override; bool HasRecords(const DomainName& name, DnsType type, DnsClass clazz) override; std::vector<MdnsRecord::ConstRef> GetRecords(const DomainName& name, DnsType type, DnsClass clazz) override; + std::vector<MdnsRecord::ConstRef> GetPtrRecords(DnsClass clazz) override; - MdnsQuerier* const querier_; MdnsSender* const sender_; - platform::TaskRunner* const task_runner_; - MdnsRandom* const random_delay_; + MdnsProbeManager* const ownership_manager_; + TaskRunner* const task_runner_; + ClockNowFunctionPtr now_function_; + + // Alarm to cancel batching of records when this class is destroyed, and + // instead send them immediately. Variable is only set when it is in use. + absl::optional<Alarm> batch_records_alarm_; + + // Number of times to announce a newly published record. + const int max_announcement_attempts_; + + // The queue for announce and goodbye records to be sent periodically. + std::vector<MdnsRecord> records_to_send_; + + // Stores mDNS records that have been published. The keys here are domain + // names for valid mDNS Records, and the values are the RecordAnnouncer + // entities associated with all published MdnsRecords for the keyed domain. + // These are responsible for publishing a specific MdnsRecord, announcing it + // when its created and sending a goodbye record when it's deleted. + std::map<DomainName, std::vector<RecordAnnouncerPtr>> records_; }; } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher_unittest.cc new file mode 100644 index 00000000000..c05019f7468 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher_unittest.cc @@ -0,0 +1,465 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "discovery/mdns/mdns_publisher.h" + +#include <chrono> // NOLINT +#include <vector> + +#include "discovery/common/config.h" +#include "discovery/mdns/mdns_probe_manager.h" +#include "discovery/mdns/mdns_sender.h" +#include "discovery/mdns/testing/mdns_test_util.h" +#include "platform/test/fake_task_runner.h" +#include "platform/test/fake_udp_socket.h" + +using testing::_; +using testing::Invoke; +using testing::Return; +using testing::StrictMock; + +namespace openscreen { +namespace discovery { +namespace { + +constexpr Clock::duration kAnnounceGoodbyeDelay = std::chrono::milliseconds(25); + +bool ContainsRecord(const std::vector<MdnsRecord::ConstRef>& records, + MdnsRecord record) { + return std::find_if(records.begin(), records.end(), + [&record](const MdnsRecord& ref) { + return ref == record; + }) != records.end(); +} + +} // namespace + +class MockMdnsSender : public MdnsSender { + public: + explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {} + + MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message)); + MOCK_METHOD2(SendMessage, + Error(const MdnsMessage& message, const IPEndpoint& endpoint)); +}; + +class MockProbeManager : public MdnsProbeManager { + public: + MOCK_CONST_METHOD1(IsDomainClaimed, bool(const DomainName&)); + MOCK_METHOD2(RespondToProbeQuery, + void(const MdnsMessage&, const IPEndpoint&)); +}; + +class MdnsPublisherTesting : public MdnsPublisher { + public: + using MdnsPublisher::GetPtrRecords; + using MdnsPublisher::GetRecords; + using MdnsPublisher::MdnsPublisher; + + bool IsNonPtrRecordPresent(const DomainName& name) { + auto it = records_.find(name); + if (it == records_.end()) { + return false; + } + + return std::find_if(it->second.begin(), it->second.end(), + [](const RecordAnnouncerPtr& announcer) { + return announcer->record().dns_type() != + DnsType::kPTR; + }) != it->second.end(); + } +}; + +class MdnsPublisherTest : public testing::Test { + public: + MdnsPublisherTest() + : clock_(Clock::now()), + task_runner_(&clock_), + socket_(&task_runner_), + sender_(&socket_), + publisher_(&sender_, + &probe_manager_, + &task_runner_, + FakeClock::now, + config_) {} + + ~MdnsPublisherTest() { + // Clear out any remaining calls in the task runner queue. + clock_.Advance( + std::chrono::duration_cast<Clock::duration>(std::chrono::seconds(1))); + } + + protected: + Error IsAnnounced(const MdnsRecord& original, const MdnsMessage& message) { + EXPECT_EQ(message.type(), MessageType::Response); + EXPECT_EQ(message.questions().size(), size_t{0}); + EXPECT_EQ(message.authority_records().size(), size_t{0}); + EXPECT_EQ(message.additional_records().size(), size_t{0}); + EXPECT_EQ(message.answers().size(), size_t{1}); + + const MdnsRecord& sent = message.answers()[0]; + EXPECT_EQ(original.name(), sent.name()); + EXPECT_EQ(original.dns_type(), sent.dns_type()); + EXPECT_EQ(original.dns_class(), sent.dns_class()); + EXPECT_EQ(original.record_type(), sent.record_type()); + EXPECT_EQ(original.rdata(), sent.rdata()); + EXPECT_EQ(original.ttl(), sent.ttl()); + return Error::None(); + } + + Error IsGoodbyeRecord(const MdnsRecord& original, + const MdnsMessage& message) { + EXPECT_EQ(message.type(), MessageType::Response); + EXPECT_EQ(message.questions().size(), size_t{0}); + EXPECT_EQ(message.authority_records().size(), size_t{0}); + EXPECT_EQ(message.additional_records().size(), size_t{0}); + EXPECT_EQ(message.answers().size(), size_t{1}); + + const MdnsRecord& sent = message.answers()[0]; + EXPECT_EQ(original.name(), sent.name()); + EXPECT_EQ(original.dns_type(), sent.dns_type()); + EXPECT_EQ(original.dns_class(), sent.dns_class()); + EXPECT_EQ(original.record_type(), sent.record_type()); + EXPECT_EQ(original.rdata(), sent.rdata()); + EXPECT_EQ(std::chrono::seconds(0), sent.ttl()); + return Error::None(); + } + + void CheckPublishedRecords(const DomainName& domain, + DnsType type, + std::vector<MdnsRecord> expected_records) { + EXPECT_EQ(publisher_.GetRecordCount(), expected_records.size()); + auto records = publisher_.GetRecords(domain, type, DnsClass::kIN); + for (const auto& record : expected_records) { + EXPECT_TRUE(ContainsRecord(records, record)); + } + } + + void TestUniqueRecordRegistrationWorkflow(MdnsRecord record, + MdnsRecord record2) { + EXPECT_CALL(probe_manager_, IsDomainClaimed(domain_)) + .WillRepeatedly(Return(true)); + DnsType type = record.dns_type(); + + // Check preconditions. + ASSERT_EQ(record.dns_type(), record2.dns_type()); + auto records = publisher_.GetRecords(domain_, type, DnsClass::kIN); + ASSERT_EQ(publisher_.GetRecordCount(), size_t{0}); + ASSERT_EQ(records.size(), size_t{0}); + ASSERT_NE(record, record2); + ASSERT_TRUE(records.empty()); + + // Register a new record. + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record](const MdnsMessage& message) -> Error { + return IsAnnounced(record, message); + }); + EXPECT_TRUE(publisher_.RegisterRecord(record).ok()); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + CheckPublishedRecords(domain_, type, {record}); + EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_)); + + // Re-register the same record. + EXPECT_FALSE(publisher_.RegisterRecord(record).ok()); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + CheckPublishedRecords(domain_, type, {record}); + EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_)); + + // Update a record that doesn't exist + EXPECT_FALSE(publisher_.UpdateRegisteredRecord(record2, record).ok()); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + CheckPublishedRecords(domain_, type, {record}); + EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_)); + + // Update an existing record. + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record2](const MdnsMessage& message) -> Error { + return IsAnnounced(record2, message); + }); + EXPECT_TRUE(publisher_.UpdateRegisteredRecord(record, record2).ok()); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + CheckPublishedRecords(domain_, type, {record2}); + EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_)); + + // Add back the original record + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record](const MdnsMessage& message) -> Error { + return IsAnnounced(record, message); + }); + EXPECT_TRUE(publisher_.RegisterRecord(record).ok()); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + CheckPublishedRecords(domain_, type, {record, record2}); + EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_)); + + // Delete an existing record. + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record2](const MdnsMessage& message) -> Error { + return IsGoodbyeRecord(record2, message); + }); + EXPECT_TRUE(publisher_.UnregisterRecord(record2).ok()); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + CheckPublishedRecords(domain_, type, {record}); + EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_)); + + // Delete a non-existing record. + EXPECT_FALSE(publisher_.UnregisterRecord(record2).ok()); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + CheckPublishedRecords(domain_, type, {record}); + EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_)); + + // Delete the last record + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record](const MdnsMessage& message) -> Error { + return IsGoodbyeRecord(record, message); + }); + EXPECT_TRUE(publisher_.UnregisterRecord(record).ok()); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + CheckPublishedRecords(domain_, type, {}); + EXPECT_FALSE(publisher_.IsNonPtrRecordPresent(domain_)); + } + + FakeClock clock_; + FakeTaskRunner task_runner_; + FakeUdpSocket socket_; + StrictMock<MockMdnsSender> sender_; + StrictMock<MockProbeManager> probe_manager_; + Config config_; + MdnsPublisherTesting publisher_; + + DomainName domain_{"instance", "_googlecast", "_tcp", "local"}; + DomainName ptr_domain_{"_googlecast", "_tcp", "local"}; +}; + +TEST_F(MdnsPublisherTest, ARecordRegistrationWorkflow) { + const MdnsRecord record1 = GetFakeARecord(domain_); + const MdnsRecord record2 = + GetFakeARecord(domain_, std::chrono::seconds(1000)); + TestUniqueRecordRegistrationWorkflow(record1, record2); +} + +TEST_F(MdnsPublisherTest, AAAARecordRegistrationWorkflow) { + const MdnsRecord record1 = GetFakeAAAARecord(domain_); + const MdnsRecord record2 = + GetFakeAAAARecord(domain_, std::chrono::seconds(1000)); + TestUniqueRecordRegistrationWorkflow(record1, record2); +} + +TEST_F(MdnsPublisherTest, TXTRecordRegistrationWorkflow) { + const MdnsRecord record1 = GetFakeTxtRecord(domain_); + const MdnsRecord record2 = + GetFakeTxtRecord(domain_, std::chrono::seconds(1000)); + TestUniqueRecordRegistrationWorkflow(record1, record2); +} + +TEST_F(MdnsPublisherTest, SRVRecordRegistrationWorkflow) { + const MdnsRecord record1 = GetFakeSrvRecord(domain_); + const MdnsRecord record2 = + GetFakeSrvRecord(domain_, std::chrono::seconds(1000)); + TestUniqueRecordRegistrationWorkflow(record1, record2); +} + +TEST_F(MdnsPublisherTest, PTRRecordRegistrationWorkflow) { + const MdnsRecord record = GetFakePtrRecord(domain_); + const MdnsRecord record2 = + GetFakePtrRecord(domain_, std::chrono::seconds(1000)); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(domain_)) + .WillRepeatedly(Return(true)); + DnsType type = DnsType::kPTR; + + // Check preconditions. + ASSERT_EQ(record.dns_type(), record2.dns_type()); + ASSERT_EQ(publisher_.GetRecordCount(), size_t{0}); + auto records = publisher_.GetRecords(domain_, type, DnsClass::kIN); + ASSERT_EQ(records.size(), size_t{0}); + records = publisher_.GetRecords(ptr_domain_, type, DnsClass::kIN); + ASSERT_EQ(records.size(), size_t{0}); + ASSERT_NE(record, record2); + ASSERT_TRUE(records.empty()); + ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{0}); + + // Register a new record. + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record](const MdnsMessage& message) -> Error { + return IsAnnounced(record, message); + }); + EXPECT_TRUE(publisher_.RegisterRecord(record).ok()); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + CheckPublishedRecords(ptr_domain_, type, {record}); + ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{1}); + + // Re-register the same record. + EXPECT_FALSE(publisher_.RegisterRecord(record).ok()); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + CheckPublishedRecords(ptr_domain_, type, {record}); + ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{1}); + + // Register a second record. + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record2](const MdnsMessage& message) -> Error { + return IsAnnounced(record2, message); + }); + EXPECT_TRUE(publisher_.RegisterRecord(record2).ok()); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + CheckPublishedRecords(ptr_domain_, type, {record, record2}); + ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{2}); + + // Delete an existing record. + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record2](const MdnsMessage& message) -> Error { + return IsGoodbyeRecord(record2, message); + }); + EXPECT_TRUE(publisher_.UnregisterRecord(record2).ok()); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + CheckPublishedRecords(ptr_domain_, type, {record}); + ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{1}); + + // Delete a non-existing record. + EXPECT_FALSE(publisher_.UnregisterRecord(record2).ok()); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + CheckPublishedRecords(ptr_domain_, type, {record}); + ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{1}); + + // Delete the last record + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record](const MdnsMessage& message) -> Error { + return IsGoodbyeRecord(record, message); + }); + EXPECT_TRUE(publisher_.UnregisterRecord(record).ok()); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + CheckPublishedRecords(ptr_domain_, type, {}); + ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{0}); +} + +TEST_F(MdnsPublisherTest, RegisteringUnownedRecordsFail) { + EXPECT_CALL(probe_manager_, IsDomainClaimed(domain_)) + .WillRepeatedly(Return(false)); + EXPECT_FALSE(publisher_.RegisterRecord(GetFakePtrRecord(domain_)).ok()); + EXPECT_FALSE(publisher_.RegisterRecord(GetFakeSrvRecord(domain_)).ok()); + EXPECT_FALSE(publisher_.RegisterRecord(GetFakeTxtRecord(domain_)).ok()); + EXPECT_FALSE(publisher_.RegisterRecord(GetFakeARecord(domain_)).ok()); + EXPECT_FALSE(publisher_.RegisterRecord(GetFakeAAAARecord(domain_)).ok()); +} + +TEST_F(MdnsPublisherTest, RegistrationAnnouncesEightTimes) { + EXPECT_CALL(probe_manager_, IsDomainClaimed(domain_)) + .WillRepeatedly(Return(true)); + constexpr Clock::duration kOneSecond = + std::chrono::duration_cast<Clock::duration>(std::chrono::seconds(1)); + + // First announce, at registration. + const MdnsRecord record = GetFakeARecord(domain_); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record](const MdnsMessage& message) -> Error { + return IsAnnounced(record, message); + }); + EXPECT_TRUE(publisher_.RegisterRecord(record).ok()); + clock_.Advance(kAnnounceGoodbyeDelay); + + // Second announce, at 2 seconds. + testing::Mock::VerifyAndClearExpectations(&sender_); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record](const MdnsMessage& message) -> Error { + return IsAnnounced(record, message); + }); + clock_.Advance(kOneSecond); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + + // Third announce, at 4 seconds. + clock_.Advance(kOneSecond); + clock_.Advance(kAnnounceGoodbyeDelay); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record](const MdnsMessage& message) -> Error { + return IsAnnounced(record, message); + }); + clock_.Advance(kOneSecond); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + + // Fourth announce, at 8 seconds. + clock_.Advance(kOneSecond * 3); + clock_.Advance(kAnnounceGoodbyeDelay); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record](const MdnsMessage& message) -> Error { + return IsAnnounced(record, message); + }); + clock_.Advance(kOneSecond); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + + // Fifth announce, at 16 seconds. + clock_.Advance(kOneSecond * 7); + clock_.Advance(kAnnounceGoodbyeDelay); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record](const MdnsMessage& message) -> Error { + return IsAnnounced(record, message); + }); + clock_.Advance(kOneSecond); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + + // Sixth announce, at 32 seconds. + clock_.Advance(kOneSecond * 15); + clock_.Advance(kAnnounceGoodbyeDelay); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record](const MdnsMessage& message) -> Error { + return IsAnnounced(record, message); + }); + clock_.Advance(kOneSecond); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + + // Seventh announce, at 64 seconds. + clock_.Advance(kOneSecond * 31); + clock_.Advance(kAnnounceGoodbyeDelay); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record](const MdnsMessage& message) -> Error { + return IsAnnounced(record, message); + }); + clock_.Advance(kOneSecond); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + + // Eighth announce, at 128 seconds. + clock_.Advance(kOneSecond * 63); + clock_.Advance(kAnnounceGoodbyeDelay); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record](const MdnsMessage& message) -> Error { + return IsAnnounced(record, message); + }); + clock_.Advance(kOneSecond); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + + // No more announcements + clock_.Advance(kOneSecond * 1024); + clock_.Advance(kAnnounceGoodbyeDelay); + testing::Mock::VerifyAndClearExpectations(&sender_); + + // Sends goodbye message when removed. + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &record](const MdnsMessage& message) -> Error { + return IsGoodbyeRecord(record, message); + }); + EXPECT_TRUE(publisher_.UnregisterRecord(record).ok()); + clock_.Advance(kAnnounceGoodbyeDelay); +} + +} // namespace discovery +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.cc index 3c4c635369e..a16b7ef873c 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.cc @@ -4,36 +4,230 @@ #include "discovery/mdns/mdns_querier.h" +#include <vector> + +#include "discovery/common/config.h" +#include "discovery/common/reporting_client.h" #include "discovery/mdns/mdns_random.h" #include "discovery/mdns/mdns_receiver.h" #include "discovery/mdns/mdns_sender.h" -#include "discovery/mdns/mdns_trackers.h" namespace openscreen { namespace discovery { +namespace { + +const std::vector<DnsType> kTranslatedNsecAnyQueryTypes = { + DnsType::kA, DnsType::kPTR, DnsType::kTXT, DnsType::kAAAA, DnsType::kSRV}; + +bool IsNegativeResponseFor(const MdnsRecord& record, DnsType type) { + if (record.dns_type() != DnsType::kNSEC) { + return false; + } + + const NsecRecordRdata& nsec = absl::get<NsecRecordRdata>(record.rdata()); + + // RFC 6762 section 6.1, the NSEC bit must NOT be set in the received NSEC + // record to indicate this is an mDNS NSEC record rather than a traditional + // DNS NSEC record. + if (std::find(nsec.types().begin(), nsec.types().end(), DnsType::kNSEC) != + nsec.types().end()) { + return false; + } + + return std::find_if(nsec.types().begin(), nsec.types().end(), + [type](DnsType stored_type) { + return stored_type == type || + stored_type == DnsType::kANY; + }) != nsec.types().end(); +} + +} // namespace + +MdnsQuerier::RecordTrackerLruCache::RecordTrackerLruCache( + MdnsQuerier* querier, + MdnsSender* sender, + MdnsRandom* random_delay, + TaskRunner* task_runner, + ClockNowFunctionPtr now_function, + ReportingClient* reporting_client, + const Config& config) + : querier_(querier), + sender_(sender), + random_delay_(random_delay), + task_runner_(task_runner), + now_function_(now_function), + reporting_client_(reporting_client), + config_(config) { + OSP_DCHECK(sender_); + OSP_DCHECK(random_delay_); + OSP_DCHECK(task_runner_); + OSP_DCHECK(reporting_client_); + OSP_DCHECK_GT(config_.querier_max_records_cached, 0); +} + +std::vector<std::reference_wrapper<const MdnsRecordTracker>> +MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name) { + return Find(name, DnsType::kANY, DnsClass::kANY); +} + +std::vector<std::reference_wrapper<const MdnsRecordTracker>> +MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name, + DnsType dns_type, + DnsClass dns_class) { + std::vector<RecordTrackerConstRef> results; + auto pair = records_.equal_range(name); + for (auto it = pair.first; it != pair.second; it++) { + const MdnsRecordTracker& tracker = *it->second; + if ((dns_type == DnsType::kANY || dns_type == tracker.dns_type()) && + (dns_class == DnsClass::kANY || dns_class == tracker.dns_class())) { + results.push_back(std::cref(tracker)); + } + } + + return results; +} + +int MdnsQuerier::RecordTrackerLruCache::Erase(const DomainName& domain, + TrackerApplicableCheck check) { + auto pair = records_.equal_range(domain); + int count = 0; + for (RecordMap::iterator it = pair.first; it != pair.second;) { + if (check(*it->second)) { + lru_order_.erase(it->second); + it = records_.erase(it); + count++; + } else { + it++; + } + } + + return count; +} + +int MdnsQuerier::RecordTrackerLruCache::ExpireSoon( + const DomainName& domain, + TrackerApplicableCheck check) { + auto pair = records_.equal_range(domain); + int count = 0; + for (RecordMap::iterator it = pair.first; it != pair.second; it++) { + if (check(*it->second)) { + MoveToEnd(it); + it->second->ExpireSoon(); + count++; + } + } + + return count; +} + +int MdnsQuerier::RecordTrackerLruCache::Update(const MdnsRecord& record, + TrackerApplicableCheck check) { + return Update(record, check, [](const MdnsRecordTracker& t) {}); +} + +int MdnsQuerier::RecordTrackerLruCache::Update( + const MdnsRecord& record, + TrackerApplicableCheck check, + TrackerChangeCallback on_rdata_update) { + auto pair = records_.equal_range(record.name()); + int count = 0; + for (RecordMap::iterator it = pair.first; it != pair.second; it++) { + if (check(*it->second)) { + auto result = it->second->Update(record); + + if (result.is_error()) { + reporting_client_->OnRecoverableError( + Error(Error::Code::kUpdateReceivedRecordFailure, + result.error().ToString())); + continue; + } + + count++; + if (result.value() == MdnsRecordTracker::UpdateType::kGoodbye) { + it->second->ExpireSoon(); + MoveToEnd(it); + } else { + MoveToBeginning(it); + if (result.value() == MdnsRecordTracker::UpdateType::kRdata) { + on_rdata_update(*it->second); + } + } + } + } + + return count; +} + +const MdnsRecordTracker& MdnsQuerier::RecordTrackerLruCache::StartTracking( + MdnsRecord record, + DnsType dns_type) { + auto expiration_callback = [this](const MdnsRecordTracker* tracker, + const MdnsRecord& record) { + querier_->OnRecordExpired(tracker, record); + }; + + while (lru_order_.size() >= + static_cast<size_t>(config_.querier_max_records_cached)) { + // This call erases one of the tracked records. + OSP_DVLOG << "Maximum cacheable record count exceeded (" + << config_.querier_max_records_cached << ")"; + lru_order_.back().ExpireNow(); + } + + auto name = record.name(); + lru_order_.emplace_front(std::move(record), dns_type, sender_, task_runner_, + now_function_, random_delay_, + std::move(expiration_callback)); + records_.emplace(std::move(name), lru_order_.begin()); + + return lru_order_.front(); +} + +void MdnsQuerier::RecordTrackerLruCache::MoveToBeginning( + MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) { + lru_order_.splice(lru_order_.begin(), lru_order_, it->second); + it->second = lru_order_.begin(); +} + +void MdnsQuerier::RecordTrackerLruCache::MoveToEnd( + MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) { + lru_order_.splice(lru_order_.end(), lru_order_, it->second); + it->second = --lru_order_.end(); +} MdnsQuerier::MdnsQuerier(MdnsSender* sender, MdnsReceiver* receiver, TaskRunner* task_runner, ClockNowFunctionPtr now_function, - MdnsRandom* random_delay) + MdnsRandom* random_delay, + ReportingClient* reporting_client, + Config config) : sender_(sender), receiver_(receiver), task_runner_(task_runner), now_function_(now_function), - random_delay_(random_delay) { + random_delay_(random_delay), + reporting_client_(reporting_client), + config_(std::move(config)), + records_(this, + sender_, + random_delay_, + task_runner_, + now_function_, + reporting_client_, + config_) { OSP_DCHECK(sender_); OSP_DCHECK(receiver_); OSP_DCHECK(task_runner_); OSP_DCHECK(now_function_); OSP_DCHECK(random_delay_); + OSP_DCHECK(reporting_client_); - receiver_->SetResponseCallback( - [this](const MdnsMessage& message) { OnMessageReceived(message); }); + receiver_->AddResponseCallback(this); } MdnsQuerier::~MdnsQuerier() { - receiver_->SetResponseCallback(nullptr); + receiver_->RemoveResponseCallback(this); } // NOTE: The code below is range loops instead of std:find_if, for better @@ -46,6 +240,7 @@ void MdnsQuerier::StartQuery(const DomainName& name, MdnsRecordChangedCallback* callback) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DCHECK(callback); + OSP_DCHECK(dns_type != DnsType::kNSEC); // Add a new callback if haven't seen it before auto callbacks_it = callbacks_.equal_range(name); @@ -63,12 +258,15 @@ void MdnsQuerier::StartQuery(const DomainName& name, // Notify the new callback with previously cached records. // NOTE: In the future, could allow callers to fetch cached records after // adding a callback, for example to prime the UI. - auto records_it = records_.equal_range(name); - for (auto entry = records_it.first; entry != records_it.second; ++entry) { - const MdnsRecord& record = entry->second->record(); - if ((dns_type == DnsType::kANY || dns_type == record.dns_type()) && - (dns_class == DnsClass::kANY || dns_class == record.dns_class())) { - callback->OnRecordChanged(record, RecordChangedEvent::kCreated); + const std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers = + records_.Find(name, dns_type, dns_class); + for (const MdnsRecordTracker& tracker : trackers) { + if (!tracker.is_negative_response()) { + MdnsRecord stored_record(name, tracker.dns_type(), tracker.dns_class(), + tracker.record_type(), tracker.ttl(), + tracker.rdata()); + callback->OnRecordChanged(std::move(stored_record), + RecordChangedEvent::kCreated); } } @@ -92,6 +290,7 @@ void MdnsQuerier::StopQuery(const DomainName& name, MdnsRecordChangedCallback* callback) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DCHECK(callback); + OSP_DCHECK(dns_type != DnsType::kNSEC); // Find and remove the callback. int callbacks_for_key = 0; @@ -132,154 +331,275 @@ void MdnsQuerier::StopQuery(const DomainName& name, // be configurable by the caller. } +void MdnsQuerier::ReinitializeQueries(const DomainName& name) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + // Get the ongoing queries and their callbacks. + std::vector<CallbackInfo> callbacks; + auto its = callbacks_.equal_range(name); + for (auto it = its.first; it != its.second; it++) { + callbacks.push_back(std::move(it->second)); + } + callbacks_.erase(name); + + // Remove all known questions and answers. + questions_.erase(name); + records_.Erase(name, [](const MdnsRecordTracker& tracker) { return true; }); + + // Restart the queries. + for (const auto& cb : callbacks) { + StartQuery(name, cb.dns_type, cb.dns_class, cb.callback); + } +} + void MdnsQuerier::OnMessageReceived(const MdnsMessage& message) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DCHECK(message.type() == MessageType::Response); - // TODO(crbug.com/openscreen/83): Drop answers and additional records if - // answer records do not answer any existing questions - // TODO(crbug.com/openscreen/83): Check authority records - // TODO(crbug.com/openscreen/84): Cap size of cache, to avoid memory blowups - // when publishers misbehave. - ProcessRecords(message.answers()); - ProcessRecords(message.additional_records()); + OSP_DVLOG << "Received mDNS Response message with " + << message.answers().size() << " answers and " + << message.additional_records().size() + << " additional records. Processing..."; + + // Add any records that are relevant for this querier. + bool found_relevant_records = false; + int processed_count = 0; + for (const MdnsRecord& record : message.answers()) { + if (ShouldAnswerRecordBeProcessed(record)) { + ProcessRecord(record); + OSP_DVLOG << "\tProcessing answer record for domain '" + << record.name().ToString() << "' of type '" + << record.dns_type() << "'..."; + found_relevant_records = true; + processed_count++; + } + } + + // If any of the message's answers are relevant, add all additional records. + // Else, since the message has already been received and parsed, use any + // individual records relevant to this querier to update the cache. + for (const MdnsRecord& record : message.additional_records()) { + if (found_relevant_records || ShouldAnswerRecordBeProcessed(record)) { + OSP_DVLOG << "\tProcessing additional record for domain '" + << record.name().ToString() << "' of type '" + << record.dns_type() << "'..."; + ProcessRecord(record); + processed_count++; + } + } + + OSP_DVLOG << "\tmDNS Response processed (" << processed_count + << " records accepted)!"; + + // TODO(crbug.com/openscreen/83): Check authority records. } -void MdnsQuerier::OnRecordExpired(const MdnsRecord& record) { - OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); +bool MdnsQuerier::ShouldAnswerRecordBeProcessed(const MdnsRecord& answer) { + // First, accept the record if it's associated with an ongoing question. + const auto questions_range = questions_.equal_range(answer.name()); + const auto it = std::find_if( + questions_range.first, questions_range.second, + [&answer](const auto& pair) { + return (pair.second->question().dns_type() == DnsType::kANY || + IsNegativeResponseFor(answer, + pair.second->question().dns_type()) || + pair.second->question().dns_type() == answer.dns_type()) && + (pair.second->question().dns_class() == DnsClass::kANY || + pair.second->question().dns_class() == answer.dns_class()); + }); + if (it != questions_range.second) { + return true; + } + + // If not, check if it corresponds to an already existing record. This is + // required because records which are already stored may either have been + // received in an additional records section, or are associated with a query + // which is no longer active. + std::vector<DnsType> types{answer.dns_type()}; + if (answer.dns_type() == DnsType::kNSEC) { + const auto& nsec_rdata = absl::get<NsecRecordRdata>(answer.rdata()); + types = nsec_rdata.types(); + } - ProcessCallbacks(record, RecordChangedEvent::kExpired); - - auto records_it = records_.equal_range(record.name()); - for (auto entry = records_it.first; entry != records_it.second; ++entry) { - MdnsRecordTracker* tracker = entry->second.get(); - const MdnsRecord& tracked_record = tracker->record(); - if (record.dns_type() == tracked_record.dns_type() && - record.dns_class() == tracked_record.dns_class() && - record.rdata() == tracked_record.rdata()) { - records_.erase(entry); - break; + for (DnsType type : types) { + std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers = + records_.Find(answer.name(), type, answer.dns_class()); + if (!trackers.empty()) { + return true; } } + + // In all other cases, the record isn't relevant. Drop it. + return false; +} + +void MdnsQuerier::OnRecordExpired(const MdnsRecordTracker* tracker, + const MdnsRecord& record) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + if (!tracker->is_negative_response()) { + ProcessCallbacks(record, RecordChangedEvent::kExpired); + } + + records_.Erase(record.name(), [tracker](const MdnsRecordTracker& it_tracker) { + return tracker == &it_tracker; + }); } -void MdnsQuerier::ProcessRecords(const std::vector<MdnsRecord>& records) { +void MdnsQuerier::ProcessRecord(const MdnsRecord& record) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); - for (const MdnsRecord& record : records) { + // Get the types which the received record is associated with. In most cases + // this will only be the type of the provided record, but in the case of + // NSEC records this will be all records which the record dictates the + // nonexistence of. + std::vector<DnsType> types; + const std::vector<DnsType>* types_ptr = &types; + if (record.dns_type() == DnsType::kNSEC) { + const auto& nsec_rdata = absl::get<NsecRecordRdata>(record.rdata()); + if (std::find(nsec_rdata.types().begin(), nsec_rdata.types().end(), + DnsType::kANY) != nsec_rdata.types().end()) { + types_ptr = &kTranslatedNsecAnyQueryTypes; + } else { + types_ptr = &nsec_rdata.types(); + } + } else { + types.push_back(record.dns_type()); + } + + // Apply the update for each type that the record is associated with. + for (DnsType dns_type : *types_ptr) { switch (record.record_type()) { case RecordType::kShared: { - ProcessSharedRecord(record); + ProcessSharedRecord(record, dns_type); break; } case RecordType::kUnique: { - ProcessUniqueRecord(record); + ProcessUniqueRecord(record, dns_type); break; } } } } -void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record) { +void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record, + DnsType dns_type) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DCHECK(record.record_type() == RecordType::kShared); - auto records_it = records_.equal_range(record.name()); - for (auto entry = records_it.first; entry != records_it.second; ++entry) { - MdnsRecordTracker* tracker = entry->second.get(); - const MdnsRecord& tracked_record = tracker->record(); - if (record.dns_type() == tracked_record.dns_type() && - record.dns_class() == tracked_record.dns_class() && - record.rdata() == tracked_record.rdata()) { - // Already have this shared record, update the existing one. - // This is a TTL only update since we've already checked that RDATA - // matches. No notification is necessary on a TTL only update. - // TODO(crbug.com/openscreen/87): Handle errors returned by Update(). - tracker->Update(record); - return; - } + // By design, NSEC records are never shared records. + if (record.dns_type() == DnsType::kNSEC) { + return; + } + + // For any records updated, this host already has this shared record. Since + // the RDATA matches, this is only a TTL update. + auto check = [&record](const MdnsRecordTracker& tracker) { + return record.dns_type() == tracker.dns_type() && + record.dns_class() == tracker.dns_class() && + record.rdata() == tracker.rdata(); + }; + auto updated_count = records_.Update(record, std::move(check)); + + if (!updated_count) { + // Have never before seen this shared record, insert a new one. + AddRecord(record, dns_type); + ProcessCallbacks(record, RecordChangedEvent::kCreated); } - // Have never before seen this shared record, insert a new one. - AddRecord(record); - ProcessCallbacks(record, RecordChangedEvent::kCreated); } -void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record) { +void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record, + DnsType dns_type) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DCHECK(record.record_type() == RecordType::kUnique); - int records_for_key = 0; - auto records_it = records_.equal_range(record.name()); - for (auto entry = records_it.first; entry != records_it.second; ++entry) { - const MdnsRecord& tracked_record = entry->second->record(); - if (record.dns_type() == tracked_record.dns_type() && - record.dns_class() == tracked_record.dns_class()) { - ++records_for_key; + std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers = + records_.Find(record.name(), dns_type, record.dns_class()); + size_t num_records_for_key = trackers.size(); + + // Have not seen any records with this key before. This case is expected the + // first time a record is received. + if (num_records_for_key == size_t{0}) { + const bool will_exist = record.dns_type() != DnsType::kNSEC; + AddRecord(record, dns_type); + if (will_exist) { + ProcessCallbacks(record, RecordChangedEvent::kCreated); } } - if (records_for_key == 0) { - // Have not seen any records with this key before. - AddRecord(record); - ProcessCallbacks(record, RecordChangedEvent::kCreated); - } else if (records_for_key == 1) { - // There's only one record with this key. - MdnsRecordTracker* tracker = records_it.first->second.get(); - // TODO(crbug.com/openscreen/87): Handle errors returned by Update(). - ErrorOr<MdnsRecordTracker::UpdateType> result = tracker->Update(record); - if (result.is_value()) { - switch (result.value()) { - case MdnsRecordTracker::UpdateType::kGoodbye: - tracker->ExpireSoon(); - break; - case MdnsRecordTracker::UpdateType::kTTLOnly: - // TTL has been updated. No action required. - break; - case MdnsRecordTracker::UpdateType::kRdata: - // If RDATA on the record is different, notify that the record has - // been updated. - ProcessCallbacks(record, RecordChangedEvent::kUpdated); - break; - } - } - } else { - // Multiple records with the same key. Expire all record with non-matching - // RDATA. Update the record with the matching RDATA if it exists, otherwise - // insert a new record. - bool is_new_record = true; - for (auto entry = records_it.first; entry != records_it.second; ++entry) { - MdnsRecordTracker* tracker = entry->second.get(); - const MdnsRecord& tracked_record = tracker->record(); - if (record.dns_type() == tracked_record.dns_type() && - record.dns_class() == tracked_record.dns_class()) { - if (record.rdata() == tracked_record.rdata()) { - is_new_record = false; - // TODO(crbug.com/openscreen/87): Handle errors returned by Update(). - ErrorOr<MdnsRecordTracker::UpdateType> result = - tracker->Update(record); - if (result.is_value()) { - switch (result.value()) { - case MdnsRecordTracker::UpdateType::kGoodbye: - tracker->ExpireSoon(); - break; - case MdnsRecordTracker::UpdateType::kTTLOnly: - // No notification is necessary on a TTL only update. - break; - case MdnsRecordTracker::UpdateType::kRdata: - // Not possible - we already checked that the RDATA matches. - OSP_NOTREACHED(); - break; - } - } - } else { - tracker->ExpireSoon(); - } - } + // There is exactly one tracker associated with this key. This is the expected + // case when a record matching this one has already been seen. + else if (num_records_for_key == size_t{1}) { + ProcessSinglyTrackedUniqueRecord(record, trackers[0]); + } + + // Multiple records with the same key. + else { + ProcessMultiTrackedUniqueRecord(record, dns_type); + } +} + +void MdnsQuerier::ProcessSinglyTrackedUniqueRecord( + const MdnsRecord& record, + const MdnsRecordTracker& tracker) { + const bool existed_previously = !tracker.is_negative_response(); + const bool will_exist = record.dns_type() != DnsType::kNSEC; + + // Calculate the callback to call on record update success while the old + // record still exists. + MdnsRecord record_for_callback = record; + if (existed_previously && !will_exist) { + record_for_callback = + MdnsRecord(record.name(), tracker.dns_type(), tracker.dns_class(), + tracker.record_type(), tracker.ttl(), tracker.rdata()); + } + + auto on_rdata_change = [this, r = std::move(record_for_callback), + existed_previously, + will_exist](const MdnsRecordTracker& tracker) { + // If RDATA on the record is different, notify that the record has + // been updated. + if (existed_previously && will_exist) { + ProcessCallbacks(r, RecordChangedEvent::kUpdated); + } else if (existed_previously) { + // Do not expire the tracker, because it still holds an NSEC record. + ProcessCallbacks(r, RecordChangedEvent::kExpired); + } else if (will_exist) { + ProcessCallbacks(r, RecordChangedEvent::kCreated); } + }; + + int updated_count = records_.Update( + record, [&tracker](const MdnsRecordTracker& t) { return &tracker == &t; }, + std::move(on_rdata_change)); + OSP_DCHECK_EQ(updated_count, 1); +} - if (is_new_record) { - // Did not find an existing record to update. - AddRecord(record); +void MdnsQuerier::ProcessMultiTrackedUniqueRecord(const MdnsRecord& record, + DnsType dns_type) { + auto update_check = [&record, dns_type](const MdnsRecordTracker& tracker) { + return tracker.dns_type() == dns_type && + tracker.dns_class() == record.dns_class() && + tracker.rdata() == record.rdata(); + }; + int update_count = records_.Update( + record, std::move(update_check), + [](const MdnsRecordTracker& tracker) { OSP_NOTREACHED(); }); + OSP_DCHECK_LE(update_count, 1); + + auto expire_check = [&record, dns_type](const MdnsRecordTracker& tracker) { + return tracker.dns_type() == dns_type && + tracker.dns_class() == record.dns_class() && + tracker.rdata() != record.rdata(); + }; + int expire_count = + records_.ExpireSoon(record.name(), std::move(expire_check)); + OSP_DCHECK_GE(expire_count, 1); + + // Did not find an existing record to update. + if (!update_count && !expire_count) { + AddRecord(record, dns_type); + if (record.dns_type() != DnsType::kNSEC) { ProcessCallbacks(record, RecordChangedEvent::kCreated); } } @@ -299,25 +619,45 @@ void MdnsQuerier::ProcessCallbacks(const MdnsRecord& record, callback_info.callback->OnRecordChanged(record, event); } } - - // TODO(crbug.com/openscreen/83): Update known answers for relevant questions. } void MdnsQuerier::AddQuestion(const MdnsQuestion& question) { - questions_.emplace(question.name(), - std::make_unique<MdnsQuestionTracker>( - std::move(question), sender_, task_runner_, - now_function_, random_delay_)); + auto tracker = std::make_unique<MdnsQuestionTracker>( + std::move(question), sender_, task_runner_, now_function_, random_delay_, + config_); + MdnsQuestionTracker* ptr = tracker.get(); + questions_.emplace(question.name(), std::move(tracker)); + + // Let all records associated with this question know that there is a new + // query that can be used for their refresh. + std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers = + records_.Find(question.name(), question.dns_type(), question.dns_class()); + for (const MdnsRecordTracker& tracker : trackers) { + // NOTE: When the pointed to object is deleted, its dtor removes itself + // from all associated records. + ptr->AddAssociatedRecord(&tracker); + } } -void MdnsQuerier::AddRecord(const MdnsRecord& record) { - auto expiration_callback = [this](const MdnsRecord& record) { - MdnsQuerier::OnRecordExpired(record); - }; - records_.emplace(record.name(), - std::make_unique<MdnsRecordTracker>( - std::move(record), sender_, task_runner_, now_function_, - random_delay_, expiration_callback)); +void MdnsQuerier::AddRecord(const MdnsRecord& record, DnsType type) { + // Add the new record. + const auto& tracker = records_.StartTracking(record, type); + + // Let all questions associated with this record know that there is a new + // record that answers them (for known answer suppression). + auto query_it = questions_.equal_range(record.name()); + for (auto entry = query_it.first; entry != query_it.second; ++entry) { + const MdnsQuestion& query = entry->second->question(); + const bool is_relevant_type = + type == DnsType::kANY || type == query.dns_type(); + const bool is_relevant_class = record.dns_class() == DnsClass::kANY || + record.dns_class() == query.dns_class(); + if (is_relevant_type && is_relevant_class) { + // NOTE: When the pointed to object is deleted, its dtor removes itself + // from all associated queries. + entry->second->AddAssociatedRecord(&tracker); + } + } } } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.h index 2e50978ffb2..1125815274d 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.h @@ -5,40 +5,44 @@ #ifndef DISCOVERY_MDNS_MDNS_QUERIER_H_ #define DISCOVERY_MDNS_MDNS_QUERIER_H_ +#include <list> #include <map> +#include "discovery/common/config.h" +#include "discovery/mdns/mdns_receiver.h" #include "discovery/mdns/mdns_record_changed_callback.h" #include "discovery/mdns/mdns_records.h" +#include "discovery/mdns/mdns_trackers.h" #include "platform/api/task_runner.h" namespace openscreen { namespace discovery { class MdnsRandom; -class MdnsReceiver; class MdnsSender; class MdnsQuestionTracker; class MdnsRecordTracker; +class ReportingClient; -class MdnsQuerier { +class MdnsQuerier : public MdnsReceiver::ResponseClient { public: - using ClockNowFunctionPtr = openscreen::platform::ClockNowFunctionPtr; - using TaskRunner = openscreen::platform::TaskRunner; - MdnsQuerier(MdnsSender* sender, MdnsReceiver* receiver, TaskRunner* task_runner, ClockNowFunctionPtr now_function, - MdnsRandom* random_delay); + MdnsRandom* random_delay, + ReportingClient* reporting_client, + Config config); MdnsQuerier(const MdnsQuerier& other) = delete; MdnsQuerier(MdnsQuerier&& other) noexcept = delete; MdnsQuerier& operator=(const MdnsQuerier& other) = delete; MdnsQuerier& operator=(MdnsQuerier&& other) noexcept = delete; - ~MdnsQuerier(); + ~MdnsQuerier() override; // Starts an mDNS query with the given name, DNS type, and DNS class. Updated // records are passed to |callback|. The caller must ensure |callback| // remains alive while it is registered with a query. + // NOTE: NSEC records cannot be queried for. void StartQuery(const DomainName& name, DnsType dns_type, DnsClass dns_class, @@ -52,6 +56,11 @@ class MdnsQuerier { DnsClass dns_class, MdnsRecordChangedCallback* callback); + // Re-initializes the process of service discovery for the provided domain + // name. All ongoing queries for this domain are restarted and any previously + // received query results are discarded. + void ReinitializeQueries(const DomainName& name); + private: struct CallbackInfo { MdnsRecordChangedCallback* const callback; @@ -59,25 +68,138 @@ class MdnsQuerier { const DnsClass dns_class; }; - // Callback passed to MdnsReceiver - void OnMessageReceived(const MdnsMessage& message); + // Represents a Least Recently Used cache of MdnsRecordTrackers. + class RecordTrackerLruCache { + public: + using RecordTrackerConstRef = + std::reference_wrapper<const MdnsRecordTracker>; + using TrackerApplicableCheck = + std::function<bool(const MdnsRecordTracker&)>; + using TrackerChangeCallback = std::function<void(const MdnsRecordTracker&)>; + + RecordTrackerLruCache(MdnsQuerier* querier, + MdnsSender* sender, + MdnsRandom* random_delay, + TaskRunner* task_runner, + ClockNowFunctionPtr now_function, + ReportingClient* reporting_client, + const Config& config); + + // Returns all trackers with the associated |name| such that its type + // represents a type corresponding to |dns_type| and class corresponding to + // |dns_class|. + std::vector<RecordTrackerConstRef> Find(const DomainName& name); + std::vector<RecordTrackerConstRef> Find(const DomainName& name, + DnsType dns_type, + DnsClass dns_class); + + // Calls ExpireSoon on all record trackers in the provided domain which + // match the provided applicability check. Returns the number of trackers + // marked for expiry. + int ExpireSoon(const DomainName& name, TrackerApplicableCheck check); + + // Erases all record trackers in the provided domain which match the + // provided applicability check. Returns the number of trackers erased. + int Erase(const DomainName& name, TrackerApplicableCheck check); + + // Updates all record trackers in the domain |record.name()| which match the + // provided applicability check using the provided record. Returns the + // number of records successfully updated. + int Update(const MdnsRecord& record, TrackerApplicableCheck check); + int Update(const MdnsRecord& record, + TrackerApplicableCheck check, + TrackerChangeCallback on_rdata_update); + + // Creates a record tracker of the given type associated with the provided + // record. + const MdnsRecordTracker& StartTracking(MdnsRecord record, DnsType type); + + size_t size() { return records_.size(); } + + private: + using LruList = std::list<MdnsRecordTracker>; + using RecordMap = std::multimap<DomainName, LruList::iterator>; + + void MoveToBeginning(RecordMap::iterator iterator); + void MoveToEnd(RecordMap::iterator iterator); + + MdnsQuerier* const querier_; + MdnsSender* const sender_; + MdnsRandom* const random_delay_; + TaskRunner* const task_runner_; + ClockNowFunctionPtr now_function_; + ReportingClient* reporting_client_; + const Config& config_; + + // List of RecordTracker instances used by this instance where the least + // recently updated element (or next to be deleted element) appears at the + // end of the list. + LruList lru_order_; + + // A collection of active known record trackers, each is identified by + // domain name, DNS record type, and DNS record class. Multimap key is + // domain name only to allow easy support for wildcard processing for DNS + // record type and class and allow storing shared records that differ only + // in RDATA. + // + // MdnsRecordTracker instances are stored as unique_ptr so they are not + // moved around in memory when the collection is modified. This allows + // passing a pointer to MdnsQuestionTracker to a task running on the + // TaskRunner. + RecordMap records_; + }; + + friend class MdnsQuerierTest; + + // MdnsReceiver::ResponseClient overrides. + void OnMessageReceived(const MdnsMessage& message) override; + + // Expires the record tracker provided. This callback is passed to owned + // MdnsRecordTracker instances in |records_|. + void OnRecordExpired(const MdnsRecordTracker* tracker, + const MdnsRecord& record); - // Callback passed to owned MdnsRecordTrackers - void OnRecordExpired(const MdnsRecord& record); + // Determines whether a record received by this querier should be processed + // or dropped. + bool ShouldAnswerRecordBeProcessed(const MdnsRecord& answer); - void ProcessRecords(const std::vector<MdnsRecord>& records); - void ProcessSharedRecord(const MdnsRecord& record); - void ProcessUniqueRecord(const MdnsRecord& record); + // Processes any record update, calling into the below methods as needed. + void ProcessRecord(const MdnsRecord& records); + + // Processes a shared record update as a record of type |type|. + void ProcessSharedRecord(const MdnsRecord& record, DnsType type); + + // Processes a unique record update as a record of type |type|. + void ProcessUniqueRecord(const MdnsRecord& record, DnsType type); + + // Called when exactly one tracker is associated with a provided key. + // Determines the type of update being executed by this update call, then + // fires the appropriate callback. + void ProcessSinglyTrackedUniqueRecord(const MdnsRecord& record, + const MdnsRecordTracker& tracker); + + // Called when multiple records are associated with the same key. Expire all + // record with non-matching RDATA. Update the record with the matching RDATA + // if it exists, otherwise insert a new record. + void ProcessMultiTrackedUniqueRecord(const MdnsRecord& record, + DnsType dns_type); + + // Calls all callbacks associated with the provided record. void ProcessCallbacks(const MdnsRecord& record, RecordChangedEvent event); + // Begins tracking the provided question. void AddQuestion(const MdnsQuestion& question); - void AddRecord(const MdnsRecord& record); + + // Begins tracking the provided record. + void AddRecord(const MdnsRecord& record, DnsType type); MdnsSender* const sender_; MdnsReceiver* const receiver_; TaskRunner* const task_runner_; const ClockNowFunctionPtr now_function_; MdnsRandom* const random_delay_; + ReportingClient* reporting_client_; + Config config_; // A collection of active question trackers, each is uniquely identified by // domain name, DNS record type, and DNS record class. Multimap key is domain @@ -88,14 +210,8 @@ class MdnsQuerier { // TaskRunner. std::multimap<DomainName, std::unique_ptr<MdnsQuestionTracker>> questions_; - // A collection of active known record trackers, each is identified by domain - // name, DNS record type, and DNS record class. Multimap key is domain name - // only to allow easy support for wildcard processing for DNS record type and - // class and allow storing shared records that differ only in RDATA. - // MdnsRecordTracker instances are stored as unique_ptr so they are not moved - // around in memory when the collection is modified. This allows passing a - // pointer to MdnsQuestionTracker to a task running on the TaskRunner. - std::multimap<DomainName, std::unique_ptr<MdnsRecordTracker>> records_; + // Set of records tracked by this querier. + RecordTrackerLruCache records_; // A collection of callbacks passed to StartQuery method. Each is identified // by domain name, DNS record type, and DNS record class, but there can be diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier_unittest.cc index a24078221b1..b48c900f771 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier_unittest.cc @@ -4,31 +4,31 @@ #include "discovery/mdns/mdns_querier.h" +#include <memory> + +#include "discovery/common/config.h" +#include "discovery/common/testing/mock_reporting_client.h" #include "discovery/mdns/mdns_random.h" #include "discovery/mdns/mdns_receiver.h" #include "discovery/mdns/mdns_record_changed_callback.h" #include "discovery/mdns/mdns_sender.h" +#include "discovery/mdns/mdns_trackers.h" #include "discovery/mdns/mdns_writer.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "platform/base/udp_packet.h" #include "platform/test/fake_clock.h" #include "platform/test/fake_task_runner.h" +#include "platform/test/mock_udp_socket.h" namespace openscreen { namespace discovery { -using openscreen::platform::Clock; -using openscreen::platform::FakeClock; -using openscreen::platform::FakeTaskRunner; -using openscreen::platform::NetworkInterfaceIndex; -using openscreen::platform::TaskRunner; -using openscreen::platform::UdpPacket; -using openscreen::platform::UdpSocket; using testing::_; using testing::Args; using testing::Invoke; using testing::Return; +using testing::StrictMock; using testing::WithArgs; // Only compare NAME, CLASS, TYPE and RDATA @@ -40,27 +40,6 @@ ACTION_P(PartialCompareRecords, expected) { EXPECT_TRUE(actual.rdata() == expected.rdata()); } -class MockUdpSocket : public UdpSocket { - public: - MOCK_METHOD(bool, IsIPv4, (), (const, override)); - MOCK_METHOD(bool, IsIPv6, (), (const, override)); - MOCK_METHOD(IPEndpoint, GetLocalEndpoint, (), (const, override)); - MOCK_METHOD(void, Bind, (), (override)); - MOCK_METHOD(void, - SetMulticastOutboundInterface, - (NetworkInterfaceIndex), - (override)); - MOCK_METHOD(void, - JoinMulticastGroup, - (const IPAddress&, NetworkInterfaceIndex), - (override)); - MOCK_METHOD(void, - SendMessage, - (const void*, size_t, const IPEndpoint&), - (override)); - MOCK_METHOD(void, SetDscp, (DscpMode), (override)); -}; - class MockRecordChangedCallback : public MdnsRecordChangedCallback { public: MOCK_METHOD(void, @@ -75,7 +54,6 @@ class MdnsQuerierTest : public testing::Test { : clock_(Clock::now()), task_runner_(&clock_), sender_(&socket_), - receiver_(&socket_), record0_created_(DomainName{"testing", "local"}, DnsType::kA, DnsClass::kIN, @@ -105,38 +83,89 @@ class MdnsQuerierTest : public testing::Test { DnsClass::kIN, RecordType::kShared, std::chrono::seconds(0), // a goodbye record - ARecordRdata(IPAddress{192, 168, 0, 1})) { + ARecordRdata(IPAddress{192, 168, 0, 1})), + record2_created_(DomainName{"testing", "local"}, + DnsType::kAAAA, + DnsClass::kIN, + RecordType::kUnique, + std::chrono::seconds(120), + AAAARecordRdata(IPAddress{1, 2, 3, 4, 5, 6, 7, 8})), + nsec_record_created_( + DomainName{"testing", "local"}, + DnsType::kNSEC, + DnsClass::kIN, + RecordType::kUnique, + std::chrono::seconds(120), + NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA)) { receiver_.Start(); } std::unique_ptr<MdnsQuerier> CreateQuerier() { return std::make_unique<MdnsQuerier>(&sender_, &receiver_, &task_runner_, - &FakeClock::now, &random_); + &FakeClock::now, &random_, + &reporting_client_, config_); } protected: - UdpPacket CreatePacketWithRecord(const MdnsRecord& record) { + UdpPacket CreatePacketWithRecords( + const std::vector<MdnsRecord::ConstRef>& records, + std::vector<MdnsRecord::ConstRef> additional_records) { MdnsMessage message(CreateMessageId(), MessageType::Response); - message.AddAnswer(record); + for (const MdnsRecord& record : records) { + message.AddAnswer(record); + } + for (const MdnsRecord& additional_record : additional_records) { + message.AddAdditionalRecord(additional_record); + } UdpPacket packet(message.MaxWireSize()); MdnsWriter writer(packet.data(), packet.size()); - writer.Write(message); + EXPECT_TRUE(writer.Write(message)); packet.resize(writer.offset()); return packet; } + UdpPacket CreatePacketWithRecords( + const std::vector<MdnsRecord::ConstRef>& records) { + return CreatePacketWithRecords(records, {}); + } + + UdpPacket CreatePacketWithRecord(const MdnsRecord& record) { + return CreatePacketWithRecords({MdnsRecord::ConstRef(record)}); + } + + // NSEC records are never exposed to outside callers, so the below methods are + // necessary to validate that they are functioning as expected. + bool ContainsRecord(MdnsQuerier* querier, + const MdnsRecord& record, + DnsType type = DnsType::kANY) { + auto record_trackers = + querier->records_.Find(record.name(), type, record.dns_class()); + + return std::find_if(record_trackers.begin(), record_trackers.end(), + [&record](const MdnsRecordTracker& tracker) { + return tracker.rdata() == record.rdata() && + tracker.ttl() == record.ttl(); + }) != record_trackers.end(); + } + + size_t RecordCount(MdnsQuerier* querier) { return querier->records_.size(); } + + Config config_; FakeClock clock_; FakeTaskRunner task_runner_; testing::NiceMock<MockUdpSocket> socket_; MdnsSender sender_; MdnsReceiver receiver_; MdnsRandom random_; + StrictMock<MockReportingClient> reporting_client_; MdnsRecord record0_created_; MdnsRecord record0_updated_; MdnsRecord record0_deleted_; MdnsRecord record1_created_; MdnsRecord record1_deleted_; + MdnsRecord record2_created_; + MdnsRecord nsec_record_created_; }; TEST_F(MdnsQuerierTest, UniqueRecordCreatedUpdatedDeleted) { @@ -330,5 +359,257 @@ TEST_F(MdnsQuerierTest, SameCallerDifferentQuestions) { receiver_.OnRead(&socket_, CreatePacketWithRecord(record1_created_)); } +TEST_F(MdnsQuerierTest, ReinitializeQueries) { + std::unique_ptr<MdnsQuerier> querier = CreateQuerier(); + MockRecordChangedCallback callback; + + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); + + EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated)) + .WillOnce(WithArgs<0>(PartialCompareRecords(record0_created_))); + + receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_)); + // Receiving the same record should only reset TTL, no callback + receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_)); + testing::Mock::VerifyAndClearExpectations(&receiver_); + + // Queries should still be ongoing but all received records should have been + // deleted. + querier->ReinitializeQueries(DomainName{"testing", "local"}); + EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated)) + .WillOnce(WithArgs<0>(PartialCompareRecords(record0_created_))); + receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_)); + testing::Mock::VerifyAndClearExpectations(&receiver_); + + // Reinitializing a different domain should not affect other queries. + querier->ReinitializeQueries(DomainName{"testing2", "local"}); + receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_)); +} + +TEST_F(MdnsQuerierTest, MessagesForUnknownQueriesDropped) { + std::unique_ptr<MdnsQuerier> querier = CreateQuerier(); + MockRecordChangedCallback callback; + + // Message for unknown query does not get processed. + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); + receiver_.OnRead(&socket_, CreatePacketWithRecord(record1_created_)); + querier->StartQuery(DomainName{"poking", "local"}, DnsType::kA, DnsClass::kIN, + &callback); + testing::Mock::VerifyAndClearExpectations(&callback); + + querier->StopQuery(DomainName{"poking", "local"}, DnsType::kA, DnsClass::kIN, + &callback); + + // Only known records from the message are processed. + EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated)) + .Times(1); + receiver_.OnRead( + &socket_, CreatePacketWithRecords({record0_created_, record1_created_})); + querier->StartQuery(DomainName{"poking", "local"}, DnsType::kA, DnsClass::kIN, + &callback); +} + +TEST_F(MdnsQuerierTest, MessagesForKnownRecordsAllowed) { + std::unique_ptr<MdnsQuerier> querier = CreateQuerier(); + MockRecordChangedCallback callback; + + // Store a message for a known query. + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); + receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_)); + testing::Mock::VerifyAndClearExpectations(&callback); + + // Stop the query and validate that record updates are still received. + querier->StopQuery(DomainName{"testing", "local"}, DnsType::kA, DnsClass::kIN, + &callback); + receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_updated_)); + testing::Mock::VerifyAndClearExpectations(&callback); + + querier->StopQuery(DomainName{"poking", "local"}, DnsType::kA, DnsClass::kIN, + &callback); + + // Only known records from the message are processed. + EXPECT_CALL(callback, + OnRecordChanged(record0_updated_, RecordChangedEvent::kCreated)) + .Times(1); + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); +} + +TEST_F(MdnsQuerierTest, MessagesForUnknownKnownRecordsAllowsAdditionalRecords) { + std::unique_ptr<MdnsQuerier> querier = CreateQuerier(); + MockRecordChangedCallback callback; + + // Store a message for a known query. + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); + EXPECT_CALL(callback, + OnRecordChanged(record0_created_, RecordChangedEvent::kCreated)) + .Times(1); + receiver_.OnRead(&socket_, CreatePacketWithRecords({record1_created_}, + {record0_created_})); + testing::Mock::VerifyAndClearExpectations(&callback); +} + +TEST_F(MdnsQuerierTest, CallbackNotCalledOnStartQueryForNsecRecords) { + std::unique_ptr<MdnsQuerier> querier = CreateQuerier(); + + // Set up so an NSEC record has been received + StrictMock<MockRecordChangedCallback> callback; + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); + auto packet = CreatePacketWithRecord(nsec_record_created_); + receiver_.OnRead(&socket_, std::move(packet)); + ASSERT_EQ(RecordCount(querier.get()), size_t{1}); + EXPECT_TRUE(ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA)); + + // Start new query + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); +} + +TEST_F(MdnsQuerierTest, ReceiveNsecRecordFansOutToEachType) { + std::unique_ptr<MdnsQuerier> querier = CreateQuerier(); + + StrictMock<MockRecordChangedCallback> callback; + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); + MdnsRecord multi_type_nsec = + MdnsRecord(nsec_record_created_.name(), nsec_record_created_.dns_type(), + nsec_record_created_.dns_class(), + nsec_record_created_.record_type(), nsec_record_created_.ttl(), + NsecRecordRdata(nsec_record_created_.name(), DnsType::kA, + DnsType::kSRV, DnsType::kAAAA)); + auto packet = CreatePacketWithRecord(multi_type_nsec); + receiver_.OnRead(&socket_, std::move(packet)); + ASSERT_EQ(RecordCount(querier.get()), size_t{3}); + EXPECT_TRUE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kA)); + EXPECT_TRUE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kAAAA)); + EXPECT_TRUE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kSRV)); +} + +TEST_F(MdnsQuerierTest, ReceiveNsecKAnyRecordFansOutToAllTypes) { + std::unique_ptr<MdnsQuerier> querier = CreateQuerier(); + + StrictMock<MockRecordChangedCallback> callback; + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); + MdnsRecord any_type_nsec = + MdnsRecord(nsec_record_created_.name(), nsec_record_created_.dns_type(), + nsec_record_created_.dns_class(), + nsec_record_created_.record_type(), nsec_record_created_.ttl(), + NsecRecordRdata(nsec_record_created_.name(), DnsType::kANY)); + auto packet = CreatePacketWithRecord(any_type_nsec); + receiver_.OnRead(&socket_, std::move(packet)); + ASSERT_EQ(RecordCount(querier.get()), size_t{5}); + EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kA)); + EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kAAAA)); + EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kSRV)); + EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kTXT)); + EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kPTR)); +} + +TEST_F(MdnsQuerierTest, CorrectCallbackCalledWhenNsecRecordReplacesNonNsec) { + std::unique_ptr<MdnsQuerier> querier = CreateQuerier(); + + // Set up so an A record has been received + StrictMock<MockRecordChangedCallback> callback; + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); + EXPECT_CALL(callback, + OnRecordChanged(record0_created_, RecordChangedEvent::kCreated)); + auto packet = CreatePacketWithRecord(record0_created_); + receiver_.OnRead(&socket_, std::move(packet)); + testing::Mock::VerifyAndClearExpectations(&callback); + ASSERT_TRUE(ContainsRecord(querier.get(), record0_created_, DnsType::kA)); + EXPECT_FALSE( + ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA)); + + EXPECT_CALL(callback, + OnRecordChanged(record0_created_, RecordChangedEvent::kExpired)); + packet = CreatePacketWithRecord(nsec_record_created_); + receiver_.OnRead(&socket_, std::move(packet)); + EXPECT_FALSE(ContainsRecord(querier.get(), record0_created_, DnsType::kA)); + EXPECT_TRUE(ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA)); +} + +TEST_F(MdnsQuerierTest, CorrectCallbackCalledWhenNonNsecRecordReplacesNsec) { + std::unique_ptr<MdnsQuerier> querier = CreateQuerier(); + + // Set up so an A record has been received + StrictMock<MockRecordChangedCallback> callback; + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); + auto packet = CreatePacketWithRecord(nsec_record_created_); + receiver_.OnRead(&socket_, std::move(packet)); + ASSERT_TRUE(ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA)); + EXPECT_FALSE(ContainsRecord(querier.get(), record0_created_, DnsType::kA)); + + EXPECT_CALL(callback, + OnRecordChanged(record0_created_, RecordChangedEvent::kCreated)); + packet = CreatePacketWithRecord(record0_created_); + receiver_.OnRead(&socket_, std::move(packet)); + EXPECT_FALSE( + ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA)); + EXPECT_TRUE(ContainsRecord(querier.get(), record0_created_, DnsType::kA)); +} + +TEST_F(MdnsQuerierTest, NoCallbackCalledWhenSecondNsecRecordReceived) { + std::unique_ptr<MdnsQuerier> querier = CreateQuerier(); + MdnsRecord multi_type_nsec = + MdnsRecord(nsec_record_created_.name(), nsec_record_created_.dns_type(), + nsec_record_created_.dns_class(), + nsec_record_created_.record_type(), nsec_record_created_.ttl(), + NsecRecordRdata(nsec_record_created_.name(), DnsType::kA, + DnsType::kSRV, DnsType::kAAAA)); + + // Set up so an A record has been received + StrictMock<MockRecordChangedCallback> callback; + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); + auto packet = CreatePacketWithRecord(nsec_record_created_); + receiver_.OnRead(&socket_, std::move(packet)); + ASSERT_TRUE(ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA)); + EXPECT_FALSE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kA)); + + packet = CreatePacketWithRecord(multi_type_nsec); + receiver_.OnRead(&socket_, std::move(packet)); + EXPECT_FALSE( + ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA)); + EXPECT_TRUE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kA)); +} + +TEST_F(MdnsQuerierTest, TestMaxRecordsRespected) { + config_.querier_max_records_cached = 1; + std::unique_ptr<MdnsQuerier> querier = CreateQuerier(); + + // Set up so an A record has been received + StrictMock<MockRecordChangedCallback> callback; + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kANY, + DnsClass::kIN, &callback); + querier->StartQuery(DomainName{"poking", "local"}, DnsType::kANY, + DnsClass::kIN, &callback); + auto packet = CreatePacketWithRecord(record0_created_); + EXPECT_CALL(callback, + OnRecordChanged(record0_created_, RecordChangedEvent::kCreated)); + receiver_.OnRead(&socket_, std::move(packet)); + ASSERT_EQ(RecordCount(querier.get()), size_t{1}); + EXPECT_TRUE(ContainsRecord(querier.get(), record0_created_, DnsType::kA)); + EXPECT_FALSE(ContainsRecord(querier.get(), record1_created_, DnsType::kA)); + testing::Mock::VerifyAndClearExpectations(&callback); + + EXPECT_CALL(callback, + OnRecordChanged(record0_created_, RecordChangedEvent::kExpired)); + EXPECT_CALL(callback, + OnRecordChanged(record1_created_, RecordChangedEvent::kCreated)); + packet = CreatePacketWithRecord(record1_created_); + receiver_.OnRead(&socket_, std::move(packet)); + ASSERT_EQ(RecordCount(querier.get()), size_t{1}); + EXPECT_FALSE(ContainsRecord(querier.get(), record0_created_, DnsType::kA)); + EXPECT_TRUE(ContainsRecord(querier.get(), record1_created_, DnsType::kA)); +} + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_random.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_random.h index b422f998683..bf33f1fdc80 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_random.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_random.h @@ -14,8 +14,6 @@ namespace discovery { class MdnsRandom { public: - using Clock = openscreen::platform::Clock; - // RFC 6762 Section 5.2 // https://tools.ietf.org/html/rfc6762#section-5.2 @@ -40,6 +38,11 @@ class MdnsRandom { truncated_query_response_delay_(random_engine_)}; } + Clock::duration GetInitialProbeDelay() { + return std::chrono::milliseconds{ + probe_initialization_delay_(random_engine_)}; + } + private: static constexpr int64_t kMinimumInitialQueryDelayMs = 20; static constexpr int64_t kMaximumInitialQueryDelayMs = 120; @@ -53,6 +56,9 @@ class MdnsRandom { static constexpr int64_t kMinimumTruncatedQueryResponseDelayMs = 400; static constexpr int64_t kMaximumTruncatedQueryResponseDelayMs = 500; + static constexpr int64_t kMinimumProbeInitializationDelayMs = 0; + static constexpr int64_t kMaximumProbeInitializationDelayMs = 250; + std::default_random_engine random_engine_{std::random_device{}()}; std::uniform_int_distribution<int64_t> initial_query_delay_{ kMinimumInitialQueryDelayMs, kMaximumInitialQueryDelayMs}; @@ -63,6 +69,8 @@ class MdnsRandom { std::uniform_int_distribution<int64_t> truncated_query_response_delay_{ kMinimumTruncatedQueryResponseDelayMs, kMaximumTruncatedQueryResponseDelayMs}; + std::uniform_int_distribution<int64_t> probe_initialization_delay_{ + kMinimumProbeInitializationDelayMs, kMaximumProbeInitializationDelayMs}; }; } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_random_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_random_unittest.cc index 1ba58c90bdf..e0ad261154d 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_random_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_random_unittest.cc @@ -10,8 +10,6 @@ namespace openscreen { namespace discovery { -using openscreen::platform::Clock; - namespace { constexpr int kIterationCount = 100; } diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.cc index 21174e742ef..e7de282769a 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.cc @@ -7,10 +7,25 @@ #include <algorithm> #include <utility> +#include "discovery/mdns/public/mdns_constants.h" #include "util/logging.h" namespace openscreen { namespace discovery { +namespace { + +bool TryParseDnsType(uint16_t to_parse, DnsType* type) { + auto it = std::find(kSupportedDnsTypes.begin(), kSupportedDnsTypes.end(), + static_cast<DnsType>(to_parse)); + if (it == kSupportedDnsTypes.end()) { + return false; + } + + *type = *it; + return true; +} + +} // namespace bool MdnsReader::Read(TxtRecordRdata::Entry* out) { Cursor cursor(this); @@ -52,7 +67,12 @@ bool MdnsReader::Read(DomainName* out) { bytes_processed <= length()) { const uint8_t label_type = ReadBigEndian<uint8_t>(position); if (IsTerminationLabel(label_type)) { - *out = DomainName(labels); + ErrorOr<DomainName> domain = + DomainName::TryCreate(labels.begin(), labels.end()); + if (domain.is_error()) { + return false; + } + *out = std::move(domain.value()); if (!bytes_consumed) { bytes_consumed = position + sizeof(uint8_t) - current(); } @@ -100,7 +120,12 @@ bool MdnsReader::Read(RawRecordRdata* out) { if (Read(&record_length)) { std::vector<uint8_t> buffer(record_length); if (Read(buffer.size(), buffer.data())) { - *out = RawRecordRdata(std::move(buffer)); + ErrorOr<RawRecordRdata> rdata = + RawRecordRdata::TryCreate(std::move(buffer)); + if (rdata.is_error()) { + return false; + } + *out = std::move(rdata.value()); cursor.Commit(); return true; } @@ -189,11 +214,47 @@ bool MdnsReader::Read(TxtRecordRdata* out) { if (cursor.delta() != sizeof(record_length) + record_length) { return false; } - *out = TxtRecordRdata(std::move(texts)); + ErrorOr<TxtRecordRdata> rdata = TxtRecordRdata::TryCreate(std::move(texts)); + if (rdata.is_error()) { + return false; + } + *out = std::move(rdata.value()); cursor.Commit(); return true; } +bool MdnsReader::Read(NsecRecordRdata* out) { + OSP_DCHECK(out); + Cursor cursor(this); + + const uint8_t* start_position = current(); + uint16_t record_length; + DomainName next_record_name; + if (!Read(&record_length) || !Read(&next_record_name)) { + return false; + } + + // Calculate the next record name length. This may not be equal to the length + // of |next_record_name| due to domain name compression. + const int encoded_next_name_length = + current() - start_position - sizeof(record_length); + const int remaining_length = record_length - encoded_next_name_length; + if (remaining_length <= 0) { + // This means either the length is invalid or the NSEC record has no + // associated types. + return false; + } + + std::vector<DnsType> types; + if (Read(&types, remaining_length)) { + *out = NsecRecordRdata(std::move(next_record_name), std::move(types)); + cursor.Commit(); + return true; + } + + return false; +} + bool MdnsReader::Read(MdnsRecord* out) { OSP_DCHECK(out); Cursor cursor(this); @@ -204,9 +265,14 @@ bool MdnsReader::Read(MdnsRecord* out) { Rdata rdata; if (Read(&name) && Read(&type) && Read(&rrclass) && Read(&ttl) && Read(static_cast<DnsType>(type), &rdata)) { - *out = MdnsRecord(std::move(name), static_cast<DnsType>(type), - GetDnsClass(rrclass), GetRecordType(rrclass), - std::chrono::seconds(ttl), std::move(rdata)); + ErrorOr<MdnsRecord> record = MdnsRecord::TryCreate( + std::move(name), static_cast<DnsType>(type), GetDnsClass(rrclass), + GetRecordType(rrclass), std::chrono::seconds(ttl), std::move(rdata)); + if (record.is_error()) { + return false; + } + *out = std::move(record.value()); + cursor.Commit(); return true; } @@ -220,8 +286,14 @@ bool MdnsReader::Read(MdnsQuestion* out) { uint16_t type; uint16_t rrclass; if (Read(&name) && Read(&type) && Read(&rrclass)) { - *out = MdnsQuestion(std::move(name), static_cast<DnsType>(type), - GetDnsClass(rrclass), GetResponseType(rrclass)); + ErrorOr<MdnsQuestion> question = + MdnsQuestion::TryCreate(std::move(name), static_cast<DnsType>(type), + GetDnsClass(rrclass), GetResponseType(rrclass)); + if (question.is_error()) { + return false; + } + *out = std::move(question.value()); + cursor.Commit(); return true; } @@ -244,8 +316,18 @@ bool MdnsReader::Read(MdnsMessage* out) { // One way to do this is to change the method signature to return // ErrorOr<MdnsMessage> and return different error codes for failure to read // and for messages that were read successfully but are non-conforming. - *out = MdnsMessage(header.id, GetMessageType(header.flags), questions, - answers, authority_records, additional_records); + ErrorOr<MdnsMessage> message = MdnsMessage::TryCreate( + header.id, GetMessageType(header.flags), questions, answers, + authority_records, additional_records); + if (message.is_error()) { + return false; + } + *out = std::move(message.value()); + + if (IsMessageTruncated(header.flags)) { + out->set_truncated(); + } + cursor.Commit(); return true; } @@ -278,7 +360,11 @@ bool MdnsReader::Read(DnsType type, Rdata* out) { return Read<PtrRecordRdata>(out); case DnsType::kTXT: return Read<TxtRecordRdata>(out); + case DnsType::kNSEC: + return Read<NsecRecordRdata>(out); default: + OSP_DCHECK(std::find(kSupportedDnsTypes.begin(), kSupportedDnsTypes.end(), + type) == kSupportedDnsTypes.end()); return Read<RawRecordRdata>(out); } } @@ -295,5 +381,68 @@ bool MdnsReader::Read(Header* out) { return false; } +bool MdnsReader::Read(std::vector<DnsType>* out, int remaining_size) { + OSP_DCHECK(out); + Cursor cursor(this); + + // Continue reading bitmaps until the entire input is read. If we have gone + // past the end of the record, it's malformed input so fail. + *out = std::vector<DnsType>(); + int processed_bytes = 0; + while (processed_bytes < remaining_size) { + NsecBitMapField bitmap; + if (!Read(&bitmap)) { + return false; + } + + processed_bytes += bitmap.bitmap_length + 2; + if (processed_bytes > remaining_size) { + return false; + } + + // The ith bit of the bitmap represents DnsType with value i, shifted + // a multiple of 0x100 according to the window. + for (int32_t i = 0; i < bitmap.bitmap_length * 8; i++) { + int current_byte = i / 8; + uint8_t bitmask = 0x80 >> i % 8; + + // If this bit flag represents a type we support, add it to the vector. + // Else, we won't be able to use it later on in the code anyway, so drop + // it. + DnsType type; + uint16_t type_index = i | (bitmap.window_block << 8); + if ((bitmap.bitmap[current_byte] & bitmask) && + TryParseDnsType(type_index, &type)) { + out->push_back(type); + } + } + } + + cursor.Commit(); + return true; +} + +bool MdnsReader::Read(NsecBitMapField* out) { + OSP_DCHECK(out); + Cursor cursor(this); + + // Read the window and bitmap length, then one byte for each byte called out + // by the length. + if (Read(&out->window_block) && Read(&out->bitmap_length)) { + if (out->bitmap_length == 0 || out->bitmap_length > 32) { + return false; + } + + out->bitmap = current(); + if (!Skip(out->bitmap_length)) { + return false; + } + cursor.Commit(); + return true; + } + + return false; +} + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.h index defcf33abc6..ecf2aafacad 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.h @@ -30,6 +30,7 @@ class MdnsReader : public BigEndianReader { bool Read(AAAARecordRdata* out); bool Read(PtrRecordRdata* out); bool Read(TxtRecordRdata* out); + bool Read(NsecRecordRdata* out); // Reads a DNS resource record with its RDATA. // The correct type of RDATA to be read is determined by the type // specified in the record. @@ -40,9 +41,17 @@ class MdnsReader : public BigEndianReader { bool Read(MdnsMessage* out); private: + struct NsecBitMapField { + uint8_t window_block; + uint8_t bitmap_length; + const uint8_t* bitmap; + }; + bool Read(IPAddress::Version version, IPAddress* out); bool Read(DnsType type, Rdata* out); bool Read(Header* out); + bool Read(std::vector<DnsType>* types, int remaining_length); + bool Read(NsecBitMapField* out); template <class ItemType> bool Read(uint16_t count, std::vector<ItemType>* out) { diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader_fuzztest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader_fuzztest.cc new file mode 100644 index 00000000000..d2e2eb72cfb --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader_fuzztest.cc @@ -0,0 +1,12 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "discovery/mdns/mdns_reader.h" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + openscreen::discovery::MdnsReader reader(data, size); + openscreen::discovery::MdnsMessage message; + reader.Read(&message); + return 0; +} diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader_unittest.cc index 8f4259df45c..a802162d3f2 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader_unittest.cc @@ -344,25 +344,64 @@ TEST(MdnsReaderTest, ReadTxtRecordRdata_EmptyEntries) { MakeTxtRecord({"foo=1", "bar=2"})); } -TEST(MdnsReaderTest, ReadTxtRecordRdata_TooShort) { +TEST(MdnsReaderTest, ReadNsecRecordRdata) { // clang-format off - constexpr uint8_t kTxtRecordRdata[] = { - 0x00, 0x0C, // RDLENGTH = 12 - 0x05, 'f', 'o', 'o', '=', '1', + constexpr uint8_t kExpectedRdata[] = { + 0x00, 0x20, // RDLENGTH = 32 + 0x08, 'm', 'y', 'd', 'e', 'v', 'i', 'c', 'e', + 0x07, 't', 'e', 's', 't', 'i', 'n', 'g', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + // It takes 8 bytes to encode the kA and kSRV records because: + // - Both record types have value less than 256, so they are both in window + // block 1. + // - The bitmap length for this block is always a single byte + // - DnsTypes have the following values: + // - kA = 1 (encoded in byte 1) + // kTXT = 16 (encoded in byte 3) + // - kSRV = 33 (encoded in byte 5) + // - kNSEC = 47 (encoded in 6 bytes) + // - The largest of these is 47, so 6 bytes are needed to encode this data. + // So the full encoded version is: + // 00000000 00000110 01000000 00000000 10000000 00000000 0100000 00000001 + // |window| | size | | 0-7 | | 8-15 | |16-23 | |24-31 | |32-39 | |40-47 | + 0x00, 0x06, 0x40, 0x00, 0x80, 0x00, 0x40, 0x01 }; // clang-format on - TestReadEntryFails<TxtRecordRdata>(kTxtRecordRdata, sizeof(kTxtRecordRdata)); + TestReadEntrySucceeds( + kExpectedRdata, sizeof(kExpectedRdata), + NsecRecordRdata(DomainName{"mydevice", "testing", "local"}, DnsType::kA, + DnsType::kTXT, DnsType::kSRV, DnsType::kNSEC)); } -TEST(MdnsReaderTest, ReadTxtRecordRdata_WrongLength) { +TEST(MdnsReaderTest, ReadNsecRecordRdata_TooShort) { // clang-format off - constexpr uint8_t kTxtRecordRdata[] = { - 0x00, 0x0F, // Wrong length specified - 0x05, 'f', 'o', 'o', '=', '1', - 0x05, 'b', 'a', 'r', '=', '2', + constexpr uint8_t kNsecRecordRdata[] = { + 0x00, 0x20, // RDLENGTH = 32 + 0x08, 'm', 'y', 'd', 'e', 'v', 'i', 'c', 'e', + 0x07, 't', 'e', 's', 't', 'i', 'n', 'g', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x06, 0x40, 0x00 + }; + // clang-format on + TestReadEntryFails<NsecRecordRdata>(kNsecRecordRdata, + sizeof(kNsecRecordRdata)); +} + +TEST(MdnsReaderTest, ReadNsecRecordRdata_WrongLength) { + // clang-format off + constexpr uint8_t kNsecRecordRdata[] = { + 0x00, 0x21, // RDLENGTH = 33 + 0x08, 'm', 'y', 'd', 'e', 'v', 'i', 'c', 'e', + 0x07, 't', 'e', 's', 't', 'i', 'n', 'g', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x06, 0x40, 0x00, 0x80, 0x00, 0x40, 0x01 }; // clang-format on - TestReadEntryFails<TxtRecordRdata>(kTxtRecordRdata, sizeof(kTxtRecordRdata)); + TestReadEntryFails<NsecRecordRdata>(kNsecRecordRdata, + sizeof(kNsecRecordRdata)); } TEST(MdnsReaderTest, ReadMdnsRecord_ARecordRdata) { diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.cc index 9245b8c10a3..8dc9fdd1fa8 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.cc @@ -7,19 +7,19 @@ #include "discovery/mdns/mdns_reader.h" #include "util/trace_logging.h" -using openscreen::platform::TraceCategory; - namespace openscreen { namespace discovery { -MdnsReceiver::MdnsReceiver(UdpSocket* socket) : socket_(socket) { - OSP_DCHECK(socket_); -} +MdnsReceiver::ResponseClient::~ResponseClient() = default; + +MdnsReceiver::MdnsReceiver() = default; MdnsReceiver::~MdnsReceiver() { if (state_ == State::kRunning) { Stop(); } + + OSP_DCHECK(response_clients_.empty()); } void MdnsReceiver::SetQueryCallback( @@ -30,13 +30,20 @@ void MdnsReceiver::SetQueryCallback( query_callback_ = callback; } -void MdnsReceiver::SetResponseCallback( - std::function<void(const MdnsMessage&)> callback) { - // This check verifies that either new or stored callback has a target. It - // will fail in case multiple objects try to set or clear the callback. - OSP_DCHECK(static_cast<bool>(response_callback_) != - static_cast<bool>(callback)); - response_callback_ = callback; +void MdnsReceiver::AddResponseCallback(ResponseClient* callback) { + auto it = + std::find(response_clients_.begin(), response_clients_.end(), callback); + OSP_DCHECK(it == response_clients_.end()); + + response_clients_.push_back(callback); +} + +void MdnsReceiver::RemoveResponseCallback(ResponseClient* callback) { + auto it = + std::find(response_clients_.begin(), response_clients_.end(), callback); + OSP_DCHECK(it != response_clients_.end()); + + response_clients_.erase(it); } void MdnsReceiver::Start() { @@ -55,7 +62,7 @@ void MdnsReceiver::OnRead(UdpSocket* socket, UdpPacket packet = std::move(packet_or_error.value()); - TRACE_SCOPED(TraceCategory::mDNS, "MdnsReceiver::OnRead"); + TRACE_SCOPED(TraceCategory::kMdns, "MdnsReceiver::OnRead"); MdnsReader reader(packet.data(), packet.size()); MdnsMessage message; if (!reader.Read(&message)) { @@ -63,25 +70,20 @@ void MdnsReceiver::OnRead(UdpSocket* socket, } if (message.type() == MessageType::Response) { - if (response_callback_) { - response_callback_(message); + for (ResponseClient* client : response_clients_) { + client->OnMessageReceived(message); + } + if (response_clients_.empty()) { + OSP_DVLOG << "Response message dropped. No response client registered..."; } } else { if (query_callback_) { query_callback_(message, packet.source()); + } else { + OSP_DVLOG << "Query message dropped. No query client registered..."; } } } -void MdnsReceiver::OnError(UdpSocket* socket, Error error) { - // This method should never be called for MdnsReciever. - OSP_UNIMPLEMENTED(); -} - -void MdnsReceiver::OnSendError(UdpSocket* socket, Error error) { - // This method should never be called for MdnsReciever. - OSP_UNIMPLEMENTED(); -} - } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.h index ca02543e35b..64e1a93da6a 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.h @@ -5,6 +5,8 @@ #ifndef DISCOVERY_MDNS_MDNS_RECEIVER_H_ #define DISCOVERY_MDNS_MDNS_RECEIVER_H_ +#include <functional> + #include "platform/api/udp_socket.h" #include "platform/base/error.h" #include "platform/base/udp_packet.h" @@ -14,24 +16,29 @@ namespace discovery { class MdnsMessage; -class MdnsReceiver : openscreen::platform::UdpSocket::Client { +class MdnsReceiver { public: - using UdpPacket = openscreen::platform::UdpPacket; - using UdpSocket = openscreen::platform::UdpSocket; + class ResponseClient { + public: + virtual ~ResponseClient(); + + virtual void OnMessageReceived(const MdnsMessage& message) = 0; + }; // MdnsReceiver does not own |socket| and |delegate| // and expects that the lifetime of these objects exceeds the lifetime of // MdnsReceiver. - explicit MdnsReceiver(UdpSocket* socket); + MdnsReceiver(); MdnsReceiver(const MdnsReceiver& other) = delete; MdnsReceiver(MdnsReceiver&& other) noexcept = delete; MdnsReceiver& operator=(const MdnsReceiver& other) = delete; MdnsReceiver& operator=(MdnsReceiver&& other) noexcept = delete; - ~MdnsReceiver() override; + ~MdnsReceiver(); void SetQueryCallback( std::function<void(const MdnsMessage&, const IPEndpoint& src)> callback); - void SetResponseCallback(std::function<void(const MdnsMessage&)> callback); + void AddResponseCallback(ResponseClient* callback); + void RemoveResponseCallback(ResponseClient* callback); // The receiver can be started and stopped multiple times. // Start and Stop are both synchronous calls. When MdnsReceiver has not yet @@ -40,10 +47,7 @@ class MdnsReceiver : openscreen::platform::UdpSocket::Client { void Start(); void Stop(); - // UdpSocket::Client overrides. - void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override; - void OnError(UdpSocket* socket, Error error) override; - void OnSendError(UdpSocket* socket, Error error) override; + void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet); private: enum class State { @@ -51,11 +55,11 @@ class MdnsReceiver : openscreen::platform::UdpSocket::Client { kRunning, }; - UdpSocket* const socket_; std::function<void(const MdnsMessage&, const IPEndpoint& src)> query_callback_; - std::function<void(const MdnsMessage&)> response_callback_; State state_ = State::kStopped; + + std::vector<ResponseClient*> response_clients_; }; } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver_unittest.cc index 5ef2162584b..df7fae9e80d 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver_unittest.cc @@ -4,6 +4,10 @@ #include "discovery/mdns/mdns_receiver.h" +#include <memory> +#include <utility> +#include <vector> + #include "discovery/mdns/mdns_records.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -13,13 +17,10 @@ namespace openscreen { namespace discovery { -using openscreen::platform::FakeUdpSocket; -using openscreen::platform::TaskRunner; -using openscreen::platform::UdpPacket; using testing::_; using testing::Return; -class MockMdnsReceiverDelegate { +class MockMdnsReceiverDelegate : public MdnsReceiver::ResponseClient { public: MOCK_METHOD(void, OnMessageReceived, (const MdnsMessage&)); }; @@ -42,10 +43,9 @@ TEST(MdnsReceiverTest, ReceiveQuery) { }; // clang-format on - std::unique_ptr<FakeUdpSocket> socket_info = - FakeUdpSocket::CreateDefault(IPAddress::Version::kV4); + FakeUdpSocket socket; MockMdnsReceiverDelegate delegate; - MdnsReceiver receiver(socket_info.get()); + MdnsReceiver receiver; receiver.SetQueryCallback( [&delegate](const MdnsMessage& message, const IPEndpoint& endpoint) { delegate.OnMessageReceived(message); @@ -66,7 +66,7 @@ TEST(MdnsReceiverTest, ReceiveQuery) { // Imitate a call to OnRead from NetworkRunner by calling it manually here EXPECT_CALL(delegate, OnMessageReceived(message)).Times(1); - receiver.OnRead(socket_info.get(), std::move(packet)); + receiver.OnRead(&socket, std::move(packet)); receiver.Stop(); } @@ -91,19 +91,16 @@ TEST(MdnsReceiverTest, ReceiveResponse) { 0xac, 0x00, 0x00, 0x01, // 172.0.0.1 }; - constexpr uint8_t kIPv6AddressBytes[] = { - 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x02, 0x02, 0xb3, 0xff, 0xfe, 0x1e, 0x83, 0x29, + constexpr uint16_t kIPv6AddressHextets[] = { + 0xfe80, 0x0000, 0x0000, 0x0000, + 0x0202, 0xb3ff, 0xfe1e, 0x8329, }; // clang-format on - std::unique_ptr<FakeUdpSocket> socket_info = - FakeUdpSocket::CreateDefault(IPAddress::Version::kV6); + FakeUdpSocket socket; MockMdnsReceiverDelegate delegate; - MdnsReceiver receiver(socket_info.get()); - receiver.SetResponseCallback([&delegate](const MdnsMessage& message) { - delegate.OnMessageReceived(message); - }); + MdnsReceiver receiver; + receiver.AddResponseCallback(&delegate); receiver.Start(); MdnsRecord record(DomainName{"testing", "local"}, DnsType::kA, DnsClass::kIN, @@ -116,16 +113,17 @@ TEST(MdnsReceiverTest, ReceiveResponse) { packet.assign(kResponseBytes.data(), kResponseBytes.data() + kResponseBytes.size()); packet.set_source( - IPEndpoint{.address = IPAddress(kIPv6AddressBytes), .port = 31337}); + IPEndpoint{.address = IPAddress(kIPv6AddressHextets), .port = 31337}); packet.set_destination( IPEndpoint{.address = IPAddress(kDefaultMulticastGroupIPv6), .port = kDefaultMulticastPort}); // Imitate a call to OnRead from NetworkRunner by calling it manually here EXPECT_CALL(delegate, OnMessageReceived(message)).Times(1); - receiver.OnRead(socket_info.get(), std::move(packet)); + receiver.OnRead(&socket, std::move(packet)); receiver.Stop(); + receiver.RemoveResponseCallback(&delegate); } } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_record_changed_callback.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_record_changed_callback.h index d5d340b9efa..c8c02d69797 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_record_changed_callback.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_record_changed_callback.h @@ -5,6 +5,8 @@ #ifndef DISCOVERY_MDNS_MDNS_RECORD_CHANGED_CALLBACK_H_ #define DISCOVERY_MDNS_MDNS_RECORD_CHANGED_CALLBACK_H_ +#include "util/logging.h" + namespace openscreen { namespace discovery { @@ -23,6 +25,21 @@ class MdnsRecordChangedCallback { RecordChangedEvent event) = 0; }; +inline std::ostream& operator<<(std::ostream& output, + RecordChangedEvent event) { + switch (event) { + case RecordChangedEvent::kCreated: + return output << "Create"; + case RecordChangedEvent::kUpdated: + return output << "Update"; + case RecordChangedEvent::kExpired: + return output << "Expiry"; + } + + OSP_NOTREACHED(); + return output; +} + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.cc index 69f76f24e8c..a04c2694165 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.cc @@ -4,18 +4,23 @@ #include "discovery/mdns/mdns_records.h" -#include <atomic> #include <cctype> #include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/strings/str_join.h" +#include "discovery/mdns/mdns_writer.h" namespace openscreen { namespace discovery { namespace { +constexpr size_t kMaxRawRecordSize = std::numeric_limits<uint16_t>::max(); + +constexpr size_t kMaxMessageFieldEntryCount = + std::numeric_limits<uint16_t>::max(); + inline int CompareIgnoreCase(const std::string& x, const std::string& y) { size_t i = 0; for (; i < x.size(); i++) { @@ -33,6 +38,54 @@ inline int CompareIgnoreCase(const std::string& x, const std::string& y) { return i == y.size() ? 0 : -1; } +template <typename RDataType> +bool IsGreaterThan(const Rdata& lhs, const Rdata& rhs) { + const RDataType& lhs_cast = absl::get<RDataType>(lhs); + const RDataType& rhs_cast = absl::get<RDataType>(rhs); + + size_t lhs_size = lhs_cast.MaxWireSize(); + size_t rhs_size = rhs_cast.MaxWireSize(); + size_t min_size = std::min(lhs_size, rhs_size); + + uint8_t lhs_bytes[lhs_size]; + uint8_t rhs_bytes[rhs_size]; + MdnsWriter lhs_writer(lhs_bytes, lhs_size); + MdnsWriter rhs_writer(rhs_bytes, rhs_size); + + lhs_writer.Write(lhs_cast); + rhs_writer.Write(rhs_cast); + for (size_t i = 0; i < min_size; i++) { + if (lhs_bytes[i] != rhs_bytes[i]) { + return lhs_bytes[i] > rhs_bytes[i]; + } + } + + if (lhs_size == rhs_size) { + return false; + } + + return lhs_size > rhs_size; +} + +bool IsGreaterThan(DnsType type, const Rdata& lhs, const Rdata& rhs) { + switch (type) { + case DnsType::kA: + return IsGreaterThan<ARecordRdata>(lhs, rhs); + case DnsType::kPTR: + return IsGreaterThan<PtrRecordRdata>(lhs, rhs); + case DnsType::kTXT: + return IsGreaterThan<TxtRecordRdata>(lhs, rhs); + case DnsType::kAAAA: + return IsGreaterThan<AAAARecordRdata>(lhs, rhs); + case DnsType::kSRV: + return IsGreaterThan<SrvRecordRdata>(lhs, rhs); + case DnsType::kNSEC: + return IsGreaterThan<NsecRecordRdata>(lhs, rhs); + default: + return IsGreaterThan<RawRecordRdata>(lhs, rhs); + } +} + } // namespace bool IsValidDomainLabel(absl::string_view label) { @@ -51,6 +104,9 @@ DomainName::DomainName(const std::vector<absl::string_view>& labels) DomainName::DomainName(std::initializer_list<absl::string_view> labels) : DomainName(labels.begin(), labels.end()) {} +DomainName::DomainName(std::vector<std::string> labels, size_t max_wire_size) + : max_wire_size_(max_wire_size), labels_(std::move(labels)) {} + DomainName::DomainName(const DomainName& other) = default; DomainName::DomainName(DomainName&& other) = default; @@ -112,12 +168,21 @@ size_t DomainName::MaxWireSize() const { return max_wire_size_; } +// static +ErrorOr<RawRecordRdata> RawRecordRdata::TryCreate(std::vector<uint8_t> rdata) { + if (rdata.size() > kMaxRawRecordSize) { + return Error::Code::kIndexOutOfBounds; + } else { + return RawRecordRdata(std::move(rdata)); + } +} + RawRecordRdata::RawRecordRdata() = default; RawRecordRdata::RawRecordRdata(std::vector<uint8_t> rdata) : rdata_(std::move(rdata)) { // Ensure RDATA length does not exceed the maximum allowed. - OSP_DCHECK(rdata_.size() <= std::numeric_limits<uint16_t>::max()); + OSP_DCHECK(rdata_.size() <= kMaxRawRecordSize); } RawRecordRdata::RawRecordRdata(const uint8_t* begin, size_t size) @@ -180,8 +245,10 @@ size_t SrvRecordRdata::MaxWireSize() const { ARecordRdata::ARecordRdata() = default; -ARecordRdata::ARecordRdata(IPAddress ipv4_address) - : ipv4_address_(std::move(ipv4_address)) { +ARecordRdata::ARecordRdata(IPAddress ipv4_address, + NetworkInterfaceIndex interface_index) + : ipv4_address_(std::move(ipv4_address)), + interface_index_(interface_index) { OSP_CHECK(ipv4_address_.IsV4()); } @@ -194,7 +261,8 @@ ARecordRdata& ARecordRdata::operator=(const ARecordRdata& rhs) = default; ARecordRdata& ARecordRdata::operator=(ARecordRdata&& rhs) = default; bool ARecordRdata::operator==(const ARecordRdata& rhs) const { - return ipv4_address_ == rhs.ipv4_address_; + return ipv4_address_ == rhs.ipv4_address_ && + interface_index_ == rhs.interface_index_; } bool ARecordRdata::operator!=(const ARecordRdata& rhs) const { @@ -208,8 +276,10 @@ size_t ARecordRdata::MaxWireSize() const { AAAARecordRdata::AAAARecordRdata() = default; -AAAARecordRdata::AAAARecordRdata(IPAddress ipv6_address) - : ipv6_address_(std::move(ipv6_address)) { +AAAARecordRdata::AAAARecordRdata(IPAddress ipv6_address, + NetworkInterfaceIndex interface_index) + : ipv6_address_(std::move(ipv6_address)), + interface_index_(interface_index) { OSP_CHECK(ipv6_address_.IsV6()); } @@ -223,7 +293,8 @@ AAAARecordRdata& AAAARecordRdata::operator=(const AAAARecordRdata& rhs) = AAAARecordRdata& AAAARecordRdata::operator=(AAAARecordRdata&& rhs) = default; bool AAAARecordRdata::operator==(const AAAARecordRdata& rhs) const { - return ipv6_address_ == rhs.ipv6_address_; + return ipv6_address_ == rhs.ipv6_address_ && + interface_index_ == rhs.interface_index_; } bool AAAARecordRdata::operator!=(const AAAARecordRdata& rhs) const { @@ -261,23 +332,39 @@ size_t PtrRecordRdata::MaxWireSize() const { return sizeof(uint16_t) + ptr_domain_.MaxWireSize(); } -TxtRecordRdata::TxtRecordRdata() = default; - -TxtRecordRdata::TxtRecordRdata(std::vector<Entry> texts) { +// static +ErrorOr<TxtRecordRdata> TxtRecordRdata::TryCreate(std::vector<Entry> texts) { + std::vector<std::string> str_texts; + size_t max_wire_size = 3; if (texts.size() > 0) { - texts_.reserve(texts.size()); + str_texts.reserve(texts.size()); // max_wire_size includes uint16_t record length field. - max_wire_size_ = sizeof(uint16_t); + max_wire_size = sizeof(uint16_t); for (const auto& text : texts) { - OSP_DCHECK(!text.empty()); - texts_.push_back( + if (text.empty()) { + return Error::Code::kParameterInvalid; + } + str_texts.push_back( std::string(reinterpret_cast<const char*>(text.data()), text.size())); // Include the length byte in the size calculation. - max_wire_size_ += text.size() + 1; + max_wire_size += text.size() + 1; } } + return TxtRecordRdata(std::move(str_texts), max_wire_size); } +TxtRecordRdata::TxtRecordRdata() = default; + +TxtRecordRdata::TxtRecordRdata(std::vector<Entry> texts) { + ErrorOr<TxtRecordRdata> rdata = TxtRecordRdata::TryCreate(std::move(texts)); + OSP_DCHECK(rdata.is_value()); + *this = std::move(rdata.value()); +} + +TxtRecordRdata::TxtRecordRdata(std::vector<std::string> texts, + size_t max_wire_size) + : max_wire_size_(max_wire_size), texts_(std::move(texts)) {} + TxtRecordRdata::TxtRecordRdata(const TxtRecordRdata& other) = default; TxtRecordRdata::TxtRecordRdata(TxtRecordRdata&& other) = default; @@ -298,6 +385,91 @@ size_t TxtRecordRdata::MaxWireSize() const { return max_wire_size_; } +NsecRecordRdata::NsecRecordRdata() = default; + +NsecRecordRdata::NsecRecordRdata(DomainName next_domain_name, + std::vector<DnsType> types) + : types_(std::move(types)), next_domain_name_(std::move(next_domain_name)) { + // Sort the types_ array for easier comparison later. + std::sort(types_.begin(), types_.end()); + + // Calculate the bitmaps as described in RFC 4034 Section 4.1.2. + std::vector<uint8_t> block_contents; + uint8_t current_block = 0; + for (auto type : types_) { + const uint16_t type_int = static_cast<uint16_t>(type); + const uint8_t block = static_cast<uint8_t>(type_int >> 8); + const uint8_t block_position = static_cast<uint8_t>(type_int & 0xFF); + const uint8_t byte_bit_is_at = block_position >> 3; // First 5 bits. + const uint8_t byte_mask = 0x80 >> (block_position & 0x07); // Last 3 bits. + + // If the block has changed, write the previous block's info and all of its + // contents to the |encoded_types_| vector. + if (block > current_block) { + if (!block_contents.empty()) { + encoded_types_.push_back(current_block); + encoded_types_.push_back(static_cast<uint8_t>(block_contents.size())); + encoded_types_.insert(encoded_types_.end(), block_contents.begin(), + block_contents.end()); + } + block_contents = std::vector<uint8_t>(); + current_block = block; + } + + // Make sure |block_contents| is large enough to hold the bit representing + // the new type , then set it. + if (block_contents.size() <= byte_bit_is_at) { + block_contents.insert(block_contents.end(), + byte_bit_is_at - block_contents.size() + 1, 0x00); + } + + block_contents[byte_bit_is_at] |= byte_mask; + } + + if (!block_contents.empty()) { + encoded_types_.push_back(current_block); + encoded_types_.push_back(static_cast<uint8_t>(block_contents.size())); + encoded_types_.insert(encoded_types_.end(), block_contents.begin(), + block_contents.end()); + } +} + +NsecRecordRdata::NsecRecordRdata(const NsecRecordRdata& other) = default; + +NsecRecordRdata::NsecRecordRdata(NsecRecordRdata&& other) = default; + +NsecRecordRdata& NsecRecordRdata::operator=(const NsecRecordRdata& rhs) = + default; + +NsecRecordRdata& NsecRecordRdata::operator=(NsecRecordRdata&& rhs) = default; + +bool NsecRecordRdata::operator==(const NsecRecordRdata& rhs) const { + return types_ == rhs.types_ && next_domain_name_ == rhs.next_domain_name_; +} + +bool NsecRecordRdata::operator!=(const NsecRecordRdata& rhs) const { + return !(*this == rhs); +} + +size_t NsecRecordRdata::MaxWireSize() const { + return next_domain_name_.MaxWireSize() + encoded_types_.size(); +} + +// static +ErrorOr<MdnsRecord> MdnsRecord::TryCreate(DomainName name, + DnsType dns_type, + DnsClass dns_class, + RecordType record_type, + std::chrono::seconds ttl, + Rdata rdata) { + if (!IsValidConfig(name, dns_type, ttl, rdata)) { + return Error::Code::kParameterInvalid; + } else { + return MdnsRecord(std::move(name), dns_type, dns_class, record_type, ttl, + std::move(rdata)); + } +} + MdnsRecord::MdnsRecord() = default; MdnsRecord::MdnsRecord(DomainName name, @@ -312,19 +484,7 @@ MdnsRecord::MdnsRecord(DomainName name, record_type_(record_type), ttl_(ttl), rdata_(std::move(rdata)) { - OSP_DCHECK(!name_.empty()); - OSP_DCHECK_LE(ttl_.count(), std::numeric_limits<uint32_t>::max()); - OSP_DCHECK((dns_type == DnsType::kSRV && - absl::holds_alternative<SrvRecordRdata>(rdata_)) || - (dns_type == DnsType::kA && - absl::holds_alternative<ARecordRdata>(rdata_)) || - (dns_type == DnsType::kAAAA && - absl::holds_alternative<AAAARecordRdata>(rdata_)) || - (dns_type == DnsType::kPTR && - absl::holds_alternative<PtrRecordRdata>(rdata_)) || - (dns_type == DnsType::kTXT && - absl::holds_alternative<TxtRecordRdata>(rdata_)) || - absl::holds_alternative<RawRecordRdata>(rdata_)); + OSP_DCHECK(IsValidConfig(name_, dns_type, ttl_, rdata_)); } MdnsRecord::MdnsRecord(const MdnsRecord& other) = default; @@ -335,6 +495,27 @@ MdnsRecord& MdnsRecord::operator=(const MdnsRecord& rhs) = default; MdnsRecord& MdnsRecord::operator=(MdnsRecord&& rhs) = default; +// static +bool MdnsRecord::IsValidConfig(const DomainName& name, + DnsType dns_type, + std::chrono::seconds ttl, + const Rdata& rdata) { + return !name.empty() && ttl.count() <= std::numeric_limits<uint32_t>::max() && + ((dns_type == DnsType::kSRV && + absl::holds_alternative<SrvRecordRdata>(rdata)) || + (dns_type == DnsType::kA && + absl::holds_alternative<ARecordRdata>(rdata)) || + (dns_type == DnsType::kAAAA && + absl::holds_alternative<AAAARecordRdata>(rdata)) || + (dns_type == DnsType::kPTR && + absl::holds_alternative<PtrRecordRdata>(rdata)) || + (dns_type == DnsType::kTXT && + absl::holds_alternative<TxtRecordRdata>(rdata)) || + (dns_type == DnsType::kNSEC && + absl::holds_alternative<NsecRecordRdata>(rdata)) || + absl::holds_alternative<RawRecordRdata>(rdata)); +} + bool MdnsRecord::operator==(const MdnsRecord& rhs) const { return dns_type_ == rhs.dns_type_ && dns_class_ == rhs.dns_class_ && record_type_ == rhs.record_type_ && ttl_ == rhs.ttl_ && @@ -345,12 +526,72 @@ bool MdnsRecord::operator!=(const MdnsRecord& rhs) const { return !(*this == rhs); } +bool MdnsRecord::operator>(const MdnsRecord& rhs) const { + // Returns the record which is lexicographically later. The determination of + // "lexicographically later" is performed by first comparing the record class, + // then the record type, then raw comparison of the binary content of the + // rdata without regard for meaning or structure. + if (dns_class() != rhs.dns_class()) { + return dns_class() > rhs.dns_class(); + } + + uint16_t this_type = static_cast<uint16_t>(dns_type()) & kClassMask; + uint16_t other_type = static_cast<uint16_t>(rhs.dns_type()) & kClassMask; + if (this_type != other_type) { + return this_type > other_type; + } + + return IsGreaterThan(dns_type(), rdata(), rhs.rdata()); +} + +bool MdnsRecord::operator<(const MdnsRecord& rhs) const { + return rhs > *this; +} + +bool MdnsRecord::operator<=(const MdnsRecord& rhs) const { + return !(*this > rhs); +} + +bool MdnsRecord::operator>=(const MdnsRecord& rhs) const { + return !(*this < rhs); +} + size_t MdnsRecord::MaxWireSize() const { auto wire_size_visitor = [](auto&& arg) { return arg.MaxWireSize(); }; // NAME size, 2-byte TYPE, 2-byte CLASS, 4-byte TTL, RDATA size return name_.MaxWireSize() + absl::visit(wire_size_visitor, rdata_) + 8; } +MdnsRecord CreateAddressRecord(DomainName name, const IPAddress& address) { + Rdata rdata; + DnsType type; + std::chrono::seconds ttl; + if (address.IsV4()) { + type = DnsType::kA; + rdata = ARecordRdata(address); + ttl = kARecordTtl; + } else { + type = DnsType::kAAAA; + rdata = AAAARecordRdata(address); + ttl = kAAAARecordTtl; + } + + return MdnsRecord(std::move(name), type, DnsClass::kIN, RecordType::kUnique, + ttl, std::move(rdata)); +} + +// static +ErrorOr<MdnsQuestion> MdnsQuestion::TryCreate(DomainName name, + DnsType dns_type, + DnsClass dns_class, + ResponseType response_type) { + if (name.empty()) { + return Error::Code::kParameterInvalid; + } + + return MdnsQuestion(std::move(name), dns_type, dns_class, response_type); +} + MdnsQuestion::MdnsQuestion(DomainName name, DnsType dns_type, DnsClass dns_class, @@ -376,6 +617,26 @@ size_t MdnsQuestion::MaxWireSize() const { return name_.MaxWireSize() + 4; } +// static +ErrorOr<MdnsMessage> MdnsMessage::TryCreate( + uint16_t id, + MessageType type, + std::vector<MdnsQuestion> questions, + std::vector<MdnsRecord> answers, + std::vector<MdnsRecord> authority_records, + std::vector<MdnsRecord> additional_records) { + if (questions.size() >= kMaxMessageFieldEntryCount || + answers.size() >= kMaxMessageFieldEntryCount || + authority_records.size() >= kMaxMessageFieldEntryCount || + additional_records.size() >= kMaxMessageFieldEntryCount) { + return Error::Code::kParameterInvalid; + } + + return MdnsMessage(id, type, std::move(questions), std::move(answers), + std::move(authority_records), + std::move(additional_records)); +} + MdnsMessage::MdnsMessage(uint16_t id, MessageType type) : id_(id), type_(type) {} @@ -391,10 +652,10 @@ MdnsMessage::MdnsMessage(uint16_t id, answers_(std::move(answers)), authority_records_(std::move(authority_records)), additional_records_(std::move(additional_records)) { - OSP_DCHECK(questions_.size() < std::numeric_limits<uint16_t>::max()); - OSP_DCHECK(answers_.size() < std::numeric_limits<uint16_t>::max()); - OSP_DCHECK(authority_records_.size() < std::numeric_limits<uint16_t>::max()); - OSP_DCHECK(additional_records_.size() < std::numeric_limits<uint16_t>::max()); + OSP_DCHECK(questions_.size() < kMaxMessageFieldEntryCount); + OSP_DCHECK(answers_.size() < kMaxMessageFieldEntryCount); + OSP_DCHECK(authority_records_.size() < kMaxMessageFieldEntryCount); + OSP_DCHECK(additional_records_.size() < kMaxMessageFieldEntryCount); for (const MdnsQuestion& question : questions_) { max_wire_size_ += question.MaxWireSize(); @@ -421,36 +682,62 @@ bool MdnsMessage::operator!=(const MdnsMessage& rhs) const { return !(*this == rhs); } +bool MdnsMessage::IsProbeQuery() const { + // A message is a probe query if it contains records in the authority section + // which answer the question being asked. + if (questions().empty() || authority_records().empty()) { + return false; + } + + for (const MdnsQuestion& question : questions_) { + for (const MdnsRecord& record : authority_records_) { + if (question.name() == record.name() && + ((question.dns_type() == record.dns_type()) || + (question.dns_type() == DnsType::kANY)) && + ((question.dns_class() == record.dns_class()) || + (question.dns_class() == DnsClass::kANY))) { + return true; + } + } + } + + return false; +} + size_t MdnsMessage::MaxWireSize() const { return max_wire_size_; } void MdnsMessage::AddQuestion(MdnsQuestion question) { - OSP_DCHECK(questions_.size() < std::numeric_limits<uint16_t>::max()); + OSP_DCHECK(questions_.size() < kMaxMessageFieldEntryCount); max_wire_size_ += question.MaxWireSize(); questions_.emplace_back(std::move(question)); } void MdnsMessage::AddAnswer(MdnsRecord record) { - OSP_DCHECK(answers_.size() < std::numeric_limits<uint16_t>::max()); + OSP_DCHECK(answers_.size() < kMaxMessageFieldEntryCount); max_wire_size_ += record.MaxWireSize(); answers_.emplace_back(std::move(record)); } void MdnsMessage::AddAuthorityRecord(MdnsRecord record) { - OSP_DCHECK(authority_records_.size() < std::numeric_limits<uint16_t>::max()); + OSP_DCHECK(authority_records_.size() < kMaxMessageFieldEntryCount); max_wire_size_ += record.MaxWireSize(); authority_records_.emplace_back(std::move(record)); } void MdnsMessage::AddAdditionalRecord(MdnsRecord record) { - OSP_DCHECK(additional_records_.size() < std::numeric_limits<uint16_t>::max()); + OSP_DCHECK(additional_records_.size() < kMaxMessageFieldEntryCount); max_wire_size_ += record.MaxWireSize(); additional_records_.emplace_back(std::move(record)); } +bool MdnsMessage::CanAddRecord(const MdnsRecord& record) { + return (max_wire_size_ + record.MaxWireSize()) < kMaxMulticastMessageSize; +} + uint16_t CreateMessageId() { - static std::atomic<uint16_t> id(0); + static uint16_t id(0); return id++; } diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.h index d81d073e754..9cacace6cb9 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.h @@ -5,15 +5,19 @@ #ifndef DISCOVERY_MDNS_MDNS_RECORDS_H_ #define DISCOVERY_MDNS_MDNS_RECORDS_H_ -#include <chrono> +#include <algorithm> +#include <chrono> // NOLINT #include <functional> #include <initializer_list> #include <string> +#include <utility> #include <vector> #include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "discovery/mdns/public/mdns_constants.h" +#include "platform/base/error.h" +#include "platform/base/interface_info.h" #include "platform/base/ip_address.h" #include "util/logging.h" @@ -29,15 +33,31 @@ class DomainName { DomainName(); template <typename IteratorType> - DomainName(IteratorType first, IteratorType last) { - labels_.reserve(std::distance(first, last)); + static ErrorOr<DomainName> TryCreate(IteratorType first, IteratorType last) { + std::vector<std::string> labels; + size_t max_wire_size = 1; + labels.reserve(std::distance(first, last)); for (IteratorType entry = first; entry != last; ++entry) { - OSP_DCHECK(IsValidDomainLabel(*entry)); - labels_.emplace_back(*entry); + if (!IsValidDomainLabel(*entry)) { + return Error::Code::kParameterInvalid; + } + labels.emplace_back(*entry); // Include the length byte in the size calculation. - max_wire_size_ += entry->size() + 1; + max_wire_size += entry->size() + 1; + } + + if (max_wire_size > kMaxDomainNameLength) { + return Error::Code::kIndexOutOfBounds; + } else { + return DomainName(std::move(labels), max_wire_size); } - OSP_DCHECK(max_wire_size_ <= kMaxDomainNameLength); + } + + template <typename IteratorType> + DomainName(IteratorType first, IteratorType last) { + ErrorOr<DomainName> domain = TryCreate(first, last); + OSP_DCHECK(domain.is_value()); + *this = std::move(domain.value()); } explicit DomainName(std::vector<std::string> labels); explicit DomainName(const std::vector<absl::string_view>& labels); @@ -70,6 +90,8 @@ class DomainName { } private: + DomainName(std::vector<std::string> labels, size_t max_wire_size); + // max_wire_size_ starts at 1 for the terminating character length. size_t max_wire_size_ = 1; std::vector<std::string> labels_; @@ -80,6 +102,8 @@ class DomainName { // distinguish a raw record type that we do not know the identity of. class RawRecordRdata { public: + static ErrorOr<RawRecordRdata> TryCreate(std::vector<uint8_t> rdata); + RawRecordRdata(); explicit RawRecordRdata(std::vector<uint8_t> rdata); RawRecordRdata(const uint8_t* begin, size_t size); @@ -148,7 +172,8 @@ class SrvRecordRdata { class ARecordRdata { public: ARecordRdata(); - explicit ARecordRdata(IPAddress ipv4_address); + explicit ARecordRdata(IPAddress ipv4_address, + NetworkInterfaceIndex interface_index = 0); ARecordRdata(const ARecordRdata& other); ARecordRdata(ARecordRdata&& other); @@ -159,6 +184,7 @@ class ARecordRdata { size_t MaxWireSize() const; const IPAddress& ipv4_address() const { return ipv4_address_; } + NetworkInterfaceIndex interface_index() const { return interface_index_; } template <typename H> friend H AbslHashValue(H h, const ARecordRdata& rdata) { @@ -167,6 +193,7 @@ class ARecordRdata { private: IPAddress ipv4_address_{0, 0, 0, 0}; + NetworkInterfaceIndex interface_index_; }; // AAAA Record format (http://www.ietf.org/rfc/rfc1035.txt): @@ -174,7 +201,8 @@ class ARecordRdata { class AAAARecordRdata { public: AAAARecordRdata(); - explicit AAAARecordRdata(IPAddress ipv6_address); + explicit AAAARecordRdata(IPAddress ipv6_address, + NetworkInterfaceIndex interface_index = 0); AAAARecordRdata(const AAAARecordRdata& other); AAAARecordRdata(AAAARecordRdata&& other); @@ -185,6 +213,7 @@ class AAAARecordRdata { size_t MaxWireSize() const; const IPAddress& ipv6_address() const { return ipv6_address_; } + NetworkInterfaceIndex interface_index() const { return interface_index_; } template <typename H> friend H AbslHashValue(H h, const AAAARecordRdata& rdata) { @@ -192,7 +221,9 @@ class AAAARecordRdata { } private: - IPAddress ipv6_address_{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + IPAddress ipv6_address_{0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000}; + NetworkInterfaceIndex interface_index_; }; // PTR record format (http://www.ietf.org/rfc/rfc1035.txt): @@ -230,6 +261,9 @@ class PtrRecordRdata { class TxtRecordRdata { public: using Entry = std::vector<uint8_t>; + + static ErrorOr<TxtRecordRdata> TryCreate(std::vector<Entry> texts); + TxtRecordRdata(); explicit TxtRecordRdata(std::vector<Entry> texts); TxtRecordRdata(const TxtRecordRdata& other); @@ -250,6 +284,8 @@ class TxtRecordRdata { } private: + TxtRecordRdata(std::vector<std::string> texts, size_t max_wire_size); + // max_wire_size_ is at least 3, uint16_t record length and at the // minimum a NULL byte character string is present. size_t max_wire_size_ = 3; @@ -258,12 +294,71 @@ class TxtRecordRdata { std::vector<std::string> texts_; }; +// NSEC record format (https://tools.ietf.org/html/rfc4034#section-4). +// In mDNS, this record type is used for representing negative responses to +// queries. +// +// next_domain_name: The next domain to process. In mDNS, this value is expected +// to match the record-level domain name in a negative response. +// +// An example of how the |types_| vector is serialized is as follows: +// When encoding the following DNS types: +// - A (value 1) +// - MX (value 15) +// - RRSIG (value 46) +// - NSEC (value 47) +// - TYPE1234 (value 1234) +// The result would be: +// 0x00 0x06 0x40 0x01 0x00 0x00 0x00 0x03 +// 0x04 0x1b 0x00 0x00 0x00 0x00 0x00 0x00 +// 0x00 0x00 0x00 0x00 0x00 0x00 0x00 0x00 +// 0x00 0x00 0x00 0x00 0x00 0x00 0x00 0x00 +// 0x00 0x00 0x00 0x00 0x20 +class NsecRecordRdata { + public: + NsecRecordRdata(); + + // Constructor that takes an arbitrary number of DnsType parameters. + // NOTE: If `types...` provide a valid set of parameters for an + // std::vector<DnsType> ctor call, this will compile. Do not use this ctor + // except to provide multiple DnsType parameters. + template <typename... Types> + NsecRecordRdata(DomainName next_domain_name, Types... types) + : NsecRecordRdata(std::move(next_domain_name), + std::vector<DnsType>{types...}) {} + NsecRecordRdata(DomainName next_domain_name, std::vector<DnsType> types); + NsecRecordRdata(const NsecRecordRdata& other); + NsecRecordRdata(NsecRecordRdata&& other); + + NsecRecordRdata& operator=(const NsecRecordRdata& rhs); + NsecRecordRdata& operator=(NsecRecordRdata&& rhs); + bool operator==(const NsecRecordRdata& rhs) const; + bool operator!=(const NsecRecordRdata& rhs) const; + + size_t MaxWireSize() const; + + const DomainName& next_domain_name() const { return next_domain_name_; } + const std::vector<DnsType>& types() const { return types_; } + const std::vector<uint8_t>& encoded_types() const { return encoded_types_; } + + template <typename H> + friend H AbslHashValue(H h, const NsecRecordRdata& rdata) { + return H::combine(std::move(h), rdata.types_, rdata.next_domain_name_); + } + + private: + std::vector<uint8_t> encoded_types_; + std::vector<DnsType> types_; + DomainName next_domain_name_; +}; + using Rdata = absl::variant<RawRecordRdata, SrvRecordRdata, ARecordRdata, AAAARecordRdata, PtrRecordRdata, - TxtRecordRdata>; + TxtRecordRdata, + NsecRecordRdata>; // Resource record top level format (http://www.ietf.org/rfc/rfc1035.txt): // name: the name of the node to which this resource record pertains. @@ -276,6 +371,13 @@ class MdnsRecord { public: using ConstRef = std::reference_wrapper<const MdnsRecord>; + static ErrorOr<MdnsRecord> TryCreate(DomainName name, + DnsType dns_type, + DnsClass dns_class, + RecordType record_type, + std::chrono::seconds ttl, + Rdata rdata); + MdnsRecord(); MdnsRecord(DomainName name, DnsType dns_type, @@ -290,6 +392,10 @@ class MdnsRecord { MdnsRecord& operator=(MdnsRecord&& rhs); bool operator==(const MdnsRecord& other) const; bool operator!=(const MdnsRecord& other) const; + bool operator<(const MdnsRecord& other) const; + bool operator>(const MdnsRecord& other) const; + bool operator<=(const MdnsRecord& other) const; + bool operator>=(const MdnsRecord& other) const; size_t MaxWireSize() const; const DomainName& name() const { return name_; } @@ -307,6 +413,11 @@ class MdnsRecord { } private: + static bool IsValidConfig(const DomainName& name, + DnsType dns_type, + std::chrono::seconds ttl, + const Rdata& rdata); + DomainName name_; DnsType dns_type_ = static_cast<DnsType>(0); DnsClass dns_class_ = static_cast<DnsClass>(0); @@ -317,12 +428,20 @@ class MdnsRecord { Rdata rdata_; }; +// Creates an A or AAAA record as appropriate for the provided parameters. +MdnsRecord CreateAddressRecord(DomainName name, const IPAddress& address); + // Question top level format (http://www.ietf.org/rfc/rfc1035.txt): // name: a domain name which identifies the target resource set. // type: 2 bytes network-order RR TYPE code. // class: 2 bytes network-order RR CLASS code. class MdnsQuestion { public: + static ErrorOr<MdnsQuestion> TryCreate(DomainName name, + DnsType dns_type, + DnsClass dns_class, + ResponseType response_type); + MdnsQuestion() = default; MdnsQuestion(DomainName name, DnsType dns_type, @@ -365,6 +484,14 @@ class MdnsQuestion { // query class MdnsMessage { public: + static ErrorOr<MdnsMessage> TryCreate( + uint16_t id, + MessageType type, + std::vector<MdnsQuestion> questions, + std::vector<MdnsRecord> answers, + std::vector<MdnsRecord> authority_records, + std::vector<MdnsRecord> additional_records); + MdnsMessage() = default; // Constructs a message with ID, flags and empty question, answer, authority // and additional record collections. @@ -384,9 +511,23 @@ class MdnsMessage { void AddAuthorityRecord(MdnsRecord record); void AddAdditionalRecord(MdnsRecord record); + // Returns false if adding a new record would push the size of this message + // beyond kMaxMulticastMessageSize, and true otherwise. + bool CanAddRecord(const MdnsRecord& record); + + // Sets the truncated bit (TC), as specified in RFC 1035 Section 4.1.1. + void set_truncated() { is_truncated_ = true; } + + // Returns true if the provided message is an mDNS probe query as described in + // RFC 6762 section 8.1. Specifically, it examines whether any question in + // the 'questions' section is a query for which answers are present in the + // 'authority records' section of the same message. + bool IsProbeQuery() const; + size_t MaxWireSize() const; uint16_t id() const { return id_; } MessageType type() const { return type_; } + bool is_truncated() const { return is_truncated_; } const std::vector<MdnsQuestion>& questions() const { return questions_; } const std::vector<MdnsRecord>& answers() const { return answers_; } const std::vector<MdnsRecord>& authority_records() const { @@ -407,6 +548,7 @@ class MdnsMessage { // The mDNS header is 12 bytes long size_t max_wire_size_ = sizeof(Header); uint16_t id_ = 0; + bool is_truncated_ = false; MessageType type_ = MessageType::Query; std::vector<MdnsQuestion> questions_; std::vector<MdnsRecord> answers_; diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records_unittest.cc index 1c8680165ef..0fab37ded24 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records_unittest.cc @@ -215,38 +215,34 @@ TEST(MdnsARecordRdataTest, CopyAndMove) { } TEST(MdnsAAAARecordRdataTest, Construct) { - constexpr uint8_t kIPv6AddressBytes1[] = { - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + constexpr uint16_t kIPv6AddressHextets1[] = { + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, }; - constexpr uint8_t kIPv6AddressBytes2[] = { - 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x02, 0x02, 0xb3, 0xff, 0xfe, 0x1e, 0x83, 0x29, + constexpr uint16_t kIPv6AddressHextets2[] = { + 0xfe80, 0x0000, 0x0000, 0x0000, 0x0202, 0xb3ff, 0xfe1e, 0x8329, }; - IPAddress address1(kIPv6AddressBytes1); + IPAddress address1(kIPv6AddressHextets1); AAAARecordRdata rdata1; EXPECT_EQ(rdata1.MaxWireSize(), UINT64_C(18)); EXPECT_EQ(rdata1.ipv6_address(), address1); - IPAddress address2(kIPv6AddressBytes2); + IPAddress address2(kIPv6AddressHextets2); AAAARecordRdata rdata2(address2); EXPECT_EQ(rdata2.MaxWireSize(), UINT64_C(18)); EXPECT_EQ(rdata2.ipv6_address(), address2); } TEST(MdnsAAAARecordRdataTest, Compare) { - constexpr uint8_t kIPv6AddressBytes1[] = { - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, + constexpr uint16_t kIPv6AddressHextets1[] = { + 0x0001, 0x0203, 0x0405, 0x0607, 0x0809, 0x0A0B, 0x0C0D, 0x0E0F, }; - constexpr uint8_t kIPv6AddressBytes2[] = { - 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x02, 0x02, 0xb3, 0xff, 0xfe, 0x1e, 0x83, 0x29, + constexpr uint16_t kIPv6AddressHextets2[] = { + 0xfe80, 0x0000, 0x0000, 0x0000, 0x0202, 0xb3ff, 0xfe1e, 0x8329, }; - IPAddress address1(kIPv6AddressBytes1); - IPAddress address2(kIPv6AddressBytes2); + IPAddress address1(kIPv6AddressHextets1); + IPAddress address2(kIPv6AddressHextets2); AAAARecordRdata rdata1(address1); AAAARecordRdata rdata2(address1); AAAARecordRdata rdata3(address2); @@ -256,11 +252,10 @@ TEST(MdnsAAAARecordRdataTest, Compare) { } TEST(MdnsAAAARecordRdataTest, CopyAndMove) { - constexpr uint8_t kIPv6AddressBytes[] = { - 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x02, 0x02, 0xb3, 0xff, 0xfe, 0x1e, 0x83, 0x29, + constexpr uint16_t kIPv6AddressHextets[] = { + 0xfe80, 0x0000, 0x0000, 0x0000, 0x0202, 0xb3ff, 0xfe1e, 0x8329, }; - TestCopyAndMove(AAAARecordRdata(IPAddress(kIPv6AddressBytes))); + TestCopyAndMove(AAAARecordRdata(IPAddress(kIPv6AddressHextets))); } TEST(MdnsPtrRecordRdataTest, Construct) { @@ -311,6 +306,135 @@ TEST(MdnsTxtRecordRdataTest, CopyAndMove) { TestCopyAndMove(MakeTxtRecord({"foo=1", "bar=2"})); } +TEST(MdnsNsecRecordRdataTest, Construct) { + const DomainName domain{"testing", "local"}; + NsecRecordRdata rdata(domain); + EXPECT_EQ(rdata.MaxWireSize(), domain.MaxWireSize()); + EXPECT_EQ(rdata.next_domain_name(), domain); + + rdata = NsecRecordRdata(domain, DnsType::kA); + // It takes 3 bytes to encode the kA and kSRV records because: + // - Both record types have value less than 256, so they are both in window + // block 1. + // - The bitmap length for this block is always a single byte + // - DnsType kA = 1 (encoded in byte 1) + // So the full encoded version is: + // 00000000 00000001 01000000 + // |window| | size | | 0-7 | + // For a total of 3 bytes. + EXPECT_EQ(rdata.encoded_types(), (std::vector<uint8_t>{0x00, 0x01, 0x40})); + EXPECT_EQ(rdata.MaxWireSize(), domain.MaxWireSize() + 3); + EXPECT_EQ(rdata.next_domain_name(), domain); + + // It takes 8 bytes to encode the kA and kSRV records because: + // - Both record types have value less than 256, so they are both in window + // block 1. + // - The bitmap length for this block is always a single byte + // - DnsTypes kTXT = 16 (encoded in byte 3) + // So the full encoded version is: + // 00000000 00000011 00000000 00000000 10000000 + // |window| | size | | 0-7 | | 8-15 | |16-23 | + // For a total of 5 bytes. + rdata = NsecRecordRdata(domain, DnsType::kTXT); + EXPECT_EQ(rdata.encoded_types(), + (std::vector<uint8_t>{0x00, 0x03, 0x00, 0x00, 0x80})); + EXPECT_EQ(rdata.MaxWireSize(), domain.MaxWireSize() + 5); + EXPECT_EQ(rdata.next_domain_name(), domain); + + // It takes 8 bytes to encode the kA and kSRV records because: + // - Both record types have value less than 256, so they are both in window + // block 1. + // - The bitmap length for this block is always a single byte + // - DnsTypes kSRV = 33 (encoded in byte 5) + // So the full encoded version is: + // 00000000 00000101 00000000 00000000 00000000 00000000 01000000 + // |window| | size | | 0-7 | | 8-15 | |16-23 | |24-31 | |32-39 | + // For a total of 7 bytes. + rdata = NsecRecordRdata(domain, DnsType::kSRV); + EXPECT_EQ(rdata.encoded_types(), + (std::vector<uint8_t>{0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x40})); + EXPECT_EQ(rdata.MaxWireSize(), domain.MaxWireSize() + 7); + EXPECT_EQ(rdata.next_domain_name(), domain); + + // It takes 8 bytes to encode the kA and kSRV records because: + // - Both record types have value less than 256, so they are both in window + // block 1. + // - The bitmap length for this block is always a single byte + // - DnsTypes kNSEC = 47 + // So the full encoded version is: + // 00000000 00000110 00000000 00000000 00000000 00000000 0000000 00000001 + // |window| | size | | 0-7 | | 8-15 | |16-23 | |24-31 | |32-39 | |40-47 | + // For a total of 8 bytes. + rdata = NsecRecordRdata(domain, DnsType::kNSEC); + EXPECT_EQ( + rdata.encoded_types(), + (std::vector<uint8_t>{0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})); + EXPECT_EQ(rdata.MaxWireSize(), domain.MaxWireSize() + 8); + EXPECT_EQ(rdata.next_domain_name(), domain); + + // It takes 8 bytes to encode the kA and kSRV records because: + // - Both record types have value less than 256, so they are both in window + // block 1. + // - The bitmap length for this block is always a single byte + // - DnsTypes kNSEC = 255 + // So 32 bits are required for the bitfield, for a total of 34 bits. + rdata = NsecRecordRdata(domain, DnsType::kANY); + std::vector<uint8_t> results{0x00, 32}; + for (int i = 1; i < 32; i++) { + results.push_back(0x00); + } + results.push_back(0x01); + EXPECT_EQ(rdata.encoded_types(), results); + EXPECT_EQ(rdata.MaxWireSize(), domain.MaxWireSize() + 34); + EXPECT_EQ(rdata.next_domain_name(), domain); + + // It takes 8 bytes to encode the kA and kSRV records because: + // - Both record types have value less than 256, so they are both in window + // block 1. + // - The bitmap length for this block is always a single byte + // - DnsTypes have the following values: + // - kA = 1 (encoded in byte 1) + // kTXT = 16 (encoded in byte 3) + // - kSRV = 33 (encoded in byte 5) + // - kNSEC = 47 (encoded in 6 bytes) + // - The largest of these is 47, so 6 bytes are needed to encode this data. + // So the full encoded version is: + // 00000000 00000110 01000000 00000000 10000000 00000000 0100000 00000001 + // |window| | size | | 0-7 | | 8-15 | |16-23 | |24-31 | |32-39 | |40-47 | + // For a total of 8 bytes. + rdata = NsecRecordRdata(domain, DnsType::kA, DnsType::kTXT, DnsType::kSRV, + DnsType::kNSEC); + EXPECT_EQ( + rdata.encoded_types(), + (std::vector<uint8_t>{0x00, 0x06, 0x40, 0x00, 0x80, 0x00, 0x40, 0x01})); + EXPECT_EQ(rdata.MaxWireSize(), domain.MaxWireSize() + 8); + EXPECT_EQ(rdata.next_domain_name(), domain); +} + +TEST(MdnsNsecRecordRdataTest, Compare) { + const DomainName domain{"testing", "local"}; + const NsecRecordRdata rdata1(domain, DnsType::kA, DnsType::kSRV); + const NsecRecordRdata rdata2(domain, DnsType::kSRV, DnsType::kA); + const NsecRecordRdata rdata3(domain, DnsType::kSRV, DnsType::kA, + DnsType::kAAAA); + const NsecRecordRdata rdata4(domain, DnsType::kSRV, DnsType::kAAAA); + + // Ensure equal Rdata values are evaluated as equal. + EXPECT_EQ(rdata1, rdata1); + EXPECT_EQ(rdata1, rdata2); + EXPECT_EQ(rdata2, rdata1); + + // Ensure different Rdata values are not. + EXPECT_NE(rdata1, rdata3); + EXPECT_NE(rdata1, rdata4); + EXPECT_NE(rdata3, rdata4); +} + +TEST(MdnsNsecRecordRdataTest, CopyAndMove) { + TestCopyAndMove(NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA, + DnsType::kSRV)); +} + TEST(MdnsRecordTest, Construct) { MdnsRecord record1; EXPECT_EQ(record1.MaxWireSize(), UINT64_C(11)); diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.cc index ce3dd4be510..969010aa64c 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.cc @@ -4,6 +4,9 @@ #include "discovery/mdns/mdns_responder.h" +#include <utility> + +#include "discovery/mdns/mdns_probe_manager.h" #include "discovery/mdns/mdns_publisher.h" #include "discovery/mdns/mdns_querier.h" #include "discovery/mdns/mdns_random.h" @@ -13,27 +16,280 @@ namespace openscreen { namespace discovery { +namespace { + +const std::array<std::string, 3> kServiceEnumerationDomainLabels{ + "_services", "_dns-sd", "_udp"}; + +enum AddResult { kNonePresent = 0, kAdded, kAlreadyKnown }; + +std::chrono::seconds GetTtlForRecordType(DnsType type) { + switch (type) { + case DnsType::kA: + return kARecordTtl; + case DnsType::kAAAA: + return kAAAARecordTtl; + case DnsType::kPTR: + return kPtrRecordTtl; + case DnsType::kSRV: + return kSrvRecordTtl; + case DnsType::kTXT: + return kTXTRecordTtl; + case DnsType::kANY: + // If no records are present, re-querying should happen at the minimum + // of any record that might be retrieved at that time. + return kSrvRecordTtl; + default: + OSP_NOTREACHED(); + return std::chrono::seconds{0}; + } +} + +MdnsRecord CreateNsecRecord(DomainName target_name, + DnsType target_type, + DnsClass target_class) { + auto rdata = NsecRecordRdata(target_name, target_type); + std::chrono::seconds ttl = GetTtlForRecordType(target_type); + return MdnsRecord(std::move(target_name), DnsType::kNSEC, target_class, + RecordType::kUnique, ttl, std::move(rdata)); +} + +inline bool IsValidAdditionalRecordType(DnsType type) { + return type == DnsType::kSRV || type == DnsType::kTXT || + type == DnsType::kA || type == DnsType::kAAAA; +} + +AddResult AddRecords(std::function<void(MdnsRecord record)> add_func, + MdnsResponder::RecordHandler* record_handler, + const DomainName& domain, + const std::vector<MdnsRecord>& known_answers, + DnsType type, + DnsClass clazz, + bool add_negative_on_unknown) { + auto records = record_handler->GetRecords(domain, type, clazz); + if (records.empty()) { + if (add_negative_on_unknown) { + // TODO(rwkeane): Aggregate all NSEC records together into a single NSEC + // record to reduce traffic. + add_func(CreateNsecRecord(domain, type, clazz)); + } + return AddResult::kNonePresent; + } else { + bool added_any_records = false; + for (auto it = records.begin(); it != records.end(); it++) { + if (std::find(known_answers.begin(), known_answers.end(), *it) == + known_answers.end()) { + added_any_records = true; + add_func(std::move(*it)); + } + } + return added_any_records ? AddResult::kAdded : AddResult::kAlreadyKnown; + } +} + +inline AddResult AddAdditionalRecords( + MdnsMessage* message, + MdnsResponder::RecordHandler* record_handler, + const DomainName& domain, + const std::vector<MdnsRecord>& known_answers, + DnsType type, + DnsClass clazz, + bool add_negative_on_unknown) { + OSP_DCHECK(IsValidAdditionalRecordType(type)); + + auto add_func = [message](MdnsRecord record) { + message->AddAdditionalRecord(std::move(record)); + }; + return AddRecords(std::move(add_func), record_handler, domain, known_answers, + type, clazz, add_negative_on_unknown); +} + +inline AddResult AddResponseRecords( + MdnsMessage* message, + MdnsResponder::RecordHandler* record_handler, + const DomainName& domain, + const std::vector<MdnsRecord>& known_answers, + DnsType type, + DnsClass clazz, + bool add_negative_on_unknown) { + auto add_func = [message](MdnsRecord record) { + message->AddAnswer(std::move(record)); + }; + return AddRecords(std::move(add_func), record_handler, domain, known_answers, + type, clazz, add_negative_on_unknown); +} + +void ApplyQueryResults(MdnsMessage* message, + MdnsResponder::RecordHandler* record_handler, + const DomainName& domain, + const std::vector<MdnsRecord>& known_answers, + DnsType type, + DnsClass clazz, + bool is_exclusive_owner) { + OSP_DCHECK(type != DnsType::kNSEC); + + // All records matching the provided query which have been published by this + // host should be added to the response message per RFC 6762 section 6. If + // this host is the exclusive owner of the queried domain name, then a + // negative response NSEC record should be added in the case where the queried + // record does not exist, per RFC 6762 section 6.1. + if (AddResponseRecords(message, record_handler, domain, known_answers, type, + clazz, is_exclusive_owner) != AddResult::kAdded) { + return; + } + + // Per RFC 6763 section 12.1, when querying for a PTR record, all SRV records + // and TXT records named in the PTR record's rdata should be added to the + // messages additional records, as well as the address records of types A and + // AAAA associated with the added SRV records. Per RFC 6762 section 6.1, + // records with names matching those of reverse address mappings for PTR + // records may be added as negative response NSEC records if they do not + // exist. + if (type == DnsType::kPTR) { + // Add all SRV and TXT records to the additional records section. + for (const MdnsRecord& record : message->answers()) { + OSP_DCHECK(record.dns_type() == DnsType::kPTR); + + const DomainName& target = + absl::get<PtrRecordRdata>(record.rdata()).ptr_domain(); + AddAdditionalRecords(message, record_handler, target, known_answers, + DnsType::kSRV, clazz, true); + AddAdditionalRecords(message, record_handler, target, known_answers, + DnsType::kTXT, clazz, true); + } + + // Add A and AAAA records associated with an added SRV record to the + // additional records section. + const int max = message->additional_records().size(); + for (int i = 0; i < max; i++) { + if (message->additional_records()[i].dns_type() != DnsType::kSRV) { + continue; + } + + { + const MdnsRecord& srv_record = message->additional_records()[i]; + const DomainName& target = + absl::get<SrvRecordRdata>(srv_record.rdata()).target(); + AddAdditionalRecords(message, record_handler, target, known_answers, + DnsType::kA, clazz, target == domain); + } + + // Must re-calculate the |srv_record|, |target| refs in case a resize of + // the additional_records() vector has invalidated them. + { + const MdnsRecord& srv_record = message->additional_records()[i]; + const DomainName& target = + absl::get<SrvRecordRdata>(srv_record.rdata()).target(); + AddAdditionalRecords(message, record_handler, target, known_answers, + DnsType::kAAAA, clazz, target == domain); + } + } + } + + // Per RFC 6763 section 12.2, when querying for an SRV record, all address + // records of type A and AAAA should be added to the additional records + // section. Per RFC 6762 section 6.1, if these records are not present and + // their name and class match that which is being queried for, a negative + // response NSEC record may be added to show their non-existence. + else if (type == DnsType::kSRV) { + for (const auto& srv_record : message->answers()) { + OSP_DCHECK(srv_record.dns_type() == DnsType::kSRV); + + const DomainName& target = + absl::get<SrvRecordRdata>(srv_record.rdata()).target(); + AddAdditionalRecords(message, record_handler, target, known_answers, + DnsType::kA, clazz, target == domain); + AddAdditionalRecords(message, record_handler, target, known_answers, + DnsType::kAAAA, clazz, target == domain); + } + } + + // Per RFC 6762 section 6.2, when querying for an address record of type A or + // AAAA, the record of the opposite type should be added to the additional + // records section if present. Else, a negative response NSEC record should be + // added to show its non-existence. + else if (type == DnsType::kA) { + AddAdditionalRecords(message, record_handler, domain, known_answers, + DnsType::kAAAA, clazz, true); + } else if (type == DnsType::kAAAA) { + AddAdditionalRecords(message, record_handler, domain, known_answers, + DnsType::kA, clazz, true); + } + + // The remaining supported records types are TXT, NSEC, and ANY. RFCs 6762 and + // 6763 do not recommend sending any records in the additional records section + // for queries of types TXT or ANY, and NSEC records are not supported for + // queries. +} + +// Determines if the provided query is a type enumeration query as described in +// RFC 6763 section 9. +bool IsServiceTypeEnumerationQuery(const MdnsQuestion& question) { + if (question.dns_type() != DnsType::kPTR) { + return false; + } + + if (question.name().labels().size() < + kServiceEnumerationDomainLabels.size()) { + return false; + } + + const auto question_it = question.name().labels().begin(); + return std::equal(question_it, + question_it + kServiceEnumerationDomainLabels.size(), + kServiceEnumerationDomainLabels.begin(), + kServiceEnumerationDomainLabels.end()); +} + +// Creates the expected response to a type enumeration query as described in RFC +// 6763 section 9. +void ApplyServiceTypeEnumerationResults( + MdnsMessage* message, + MdnsResponder::RecordHandler* record_handler, + const DomainName& name, + DnsClass clazz) { + if (name.labels().size() < kServiceEnumerationDomainLabels.size()) { + return; + } + + std::vector<MdnsRecord::ConstRef> records = + record_handler->GetPtrRecords(clazz); + + // skip "_services._dns-sd._udp." which was already checked for in above + // method and just use the domain. + const auto domain_it = + name.labels().begin() + kServiceEnumerationDomainLabels.size(); + for (const MdnsRecord& record : records) { + // Skip the 2 label service name in the PTR record's name. + const auto record_it = record.name().labels().begin() + 2; + if (std::equal(domain_it, name.labels().end(), record_it, + record.name().labels().end())) { + message->AddAnswer(MdnsRecord(name, DnsType::kPTR, record.dns_class(), + RecordType::kShared, record.ttl(), + PtrRecordRdata(record.name()))); + } + } +} + +} // namespace MdnsResponder::MdnsResponder(RecordHandler* record_handler, + MdnsProbeManager* ownership_handler, MdnsSender* sender, MdnsReceiver* receiver, - MdnsQuerier* querier, - platform::TaskRunner* task_runner, - platform::ClockNowFunctionPtr now_function, + TaskRunner* task_runner, MdnsRandom* random_delay) : record_handler_(record_handler), + ownership_handler_(ownership_handler), sender_(sender), receiver_(receiver), - querier_(querier), task_runner_(task_runner), - now_function_(now_function), random_delay_(random_delay) { OSP_DCHECK(record_handler_); + OSP_DCHECK(ownership_handler_); OSP_DCHECK(sender_); OSP_DCHECK(receiver_); - OSP_DCHECK(querier_); OSP_DCHECK(task_runner_); - OSP_DCHECK(now_function_); OSP_DCHECK(random_delay_); auto func = [this](const MdnsMessage& message, const IPEndpoint& src) { @@ -46,12 +302,117 @@ MdnsResponder::~MdnsResponder() { receiver_->SetQueryCallback(nullptr); } +MdnsResponder::RecordHandler::~RecordHandler() = default; + void MdnsResponder::OnMessageReceived(const MdnsMessage& message, const IPEndpoint& src) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DCHECK(message.type() == MessageType::Query); - // TODO(rwkeane): implement responding to the query + if (message.questions().empty()) { + // TODO(rwkeane): Support multi-packet known answer suppression. + return; + } + + if (message.IsProbeQuery()) { + ownership_handler_->RespondToProbeQuery(message, src); + return; + } + + OSP_DVLOG << "Received mDNS Query with " << message.questions().size() + << " questions. Processing..."; + + const std::vector<MdnsRecord>& known_answers = message.answers(); + + for (const auto& question : message.questions()) { + OSP_DVLOG << "\tProcessing mDNS Query for domain: '" + << question.name().ToString() << "', type: '" + << question.dns_type() << "'"; + + // NSEC records should not be queried for. + if (question.dns_type() == DnsType::kNSEC) { + continue; + } + + // Only respond to queries for which one of the following is true: + // - This host is the sole owner of that domain. + // - A record corresponding to this question has been published. + // - The query is a service enumeration query. + const bool is_service_enumeration = IsServiceTypeEnumerationQuery(question); + const bool is_exclusive_owner = + ownership_handler_->IsDomainClaimed(question.name()); + if (!is_service_enumeration && !is_exclusive_owner && + !record_handler_->HasRecords(question.name(), question.dns_type(), + question.dns_class())) { + OSP_DVLOG << "\tmDNS Query processed and no relevant records found!"; + continue; + } else if (is_service_enumeration) { + OSP_DVLOG << "\tmDNS Query is for service type enumeration!"; + } + + // Relevant records are published, so send them out using the response type + // dictated in the question. + std::function<void(const MdnsMessage&)> send_response; + if (question.response_type() == ResponseType::kMulticast) { + send_response = [this](const MdnsMessage& message) { + sender_->SendMulticast(message); + }; + } else { + OSP_DCHECK(question.response_type() == ResponseType::kUnicast); + send_response = [this, src](const MdnsMessage& message) { + sender_->SendMessage(message, src); + }; + } + + // If this host is the exclusive owner, respond immediately. Else, there may + // be network contention if all hosts respond simultaneously, so delay the + // response as dictated by RFC 6762. + if (is_exclusive_owner) { + SendResponse(question, known_answers, send_response, is_exclusive_owner); + } else { + const auto delay = random_delay_->GetSharedRecordResponseDelay(); + std::function<void()> response = [this, question, known_answers, + send_response, is_exclusive_owner]() { + SendResponse(question, known_answers, send_response, + is_exclusive_owner); + }; + task_runner_->PostTaskWithDelay(response, delay); + } + } +} + +void MdnsResponder::SendResponse( + const MdnsQuestion& question, + const std::vector<MdnsRecord>& known_answers, + std::function<void(const MdnsMessage&)> send_response, + bool is_exclusive_owner) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + MdnsMessage message(CreateMessageId(), MessageType::Response); + + if (IsServiceTypeEnumerationQuery(question)) { + // This is a special case defined in RFC 6763 section 9, so handle it + // separately. + ApplyServiceTypeEnumerationResults(&message, record_handler_, + question.name(), question.dns_class()); + } else { + // NOTE: The exclusive ownership of this record cannot change before this + // method is called. Exclusive ownership cannot be gained for a record which + // has previously been published, and if this host is the exclusive owner + // then this method will have been called without any delay on the task + // runner + ApplyQueryResults(&message, record_handler_, question.name(), known_answers, + question.dns_type(), question.dns_class(), + is_exclusive_owner); + } + + // Send the response only if it contains answers to the query. + if (!message.answers().empty()) { + OSP_DVLOG << "\tmDNS Query processed and response sent!"; + send_response(message); + } else { + OSP_DVLOG << "\tmDNS Query processed and no response sent!"; + } } } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.h index cde4518e3dc..8ac7d1f3d4c 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.h @@ -12,27 +12,35 @@ #include "platform/base/macros.h" namespace openscreen { -struct IPEndpoint; -namespace platform { +struct IPEndpoint; class TaskRunner; -} // namespace platform namespace discovery { class MdnsMessage; +class MdnsProbeManager; class MdnsRandom; class MdnsReceiver; class MdnsRecordChangedCallback; class MdnsSender; class MdnsQuerier; +// This class is responsible for responding to any incoming mDNS Queries +// received via the OnMessageReceived() method. When responding, the generated +// MdnsMessage will contain the requested record(s) in the answers section, or +// an NSEC record to specify that the requested record was not found in the case +// of a query with DnsType aside from ANY. In the case where records are found, +// the additional records field may be populated with additional records, as +// specified in RFCs 6762 and 6763. +// TODO(rwkeane): Handle known answers, and waiting when the truncated (TC) bit +// is set. class MdnsResponder { public: // Class to handle querying for existing records. class RecordHandler { - // Returns whether the provided name is exclusively owned by this endpoint. - virtual bool IsExclusiveOwner(const DomainName& name) = 0; + public: + virtual ~RecordHandler(); // Returns whether this service has one or more records matching the // provided name, type, and class. @@ -45,17 +53,18 @@ class MdnsResponder { virtual std::vector<MdnsRecord::ConstRef> GetRecords(const DomainName& name, DnsType type, DnsClass clazz) = 0; + + // Enumerates all PTR records owned by this service. + virtual std::vector<MdnsRecord::ConstRef> GetPtrRecords(DnsClass clazz) = 0; }; - // |record_handler|, |sender|, |receiver|, |querier|, |task_runner|, and - // |random_delay| are expected to persist for the duration of this instance's - // lifetime. + // |record_handler|, |sender|, |receiver|, |task_runner|, and |random_delay| + // are expected to persist for the duration of this instance's lifetime. MdnsResponder(RecordHandler* record_handler, + MdnsProbeManager* ownership_handler, MdnsSender* sender, MdnsReceiver* receiver, - MdnsQuerier* querier, - platform::TaskRunner* task_runner, - platform::ClockNowFunctionPtr now_function, + TaskRunner* task_runner, MdnsRandom* random_delay); ~MdnsResponder(); @@ -64,13 +73,19 @@ class MdnsResponder { private: void OnMessageReceived(const MdnsMessage& message, const IPEndpoint& src); + void SendResponse(const MdnsQuestion& question, + const std::vector<MdnsRecord>& known_answers, + std::function<void(const MdnsMessage&)> send_response, + bool is_exclusive_owner); + RecordHandler* const record_handler_; + MdnsProbeManager* const ownership_handler_; MdnsSender* const sender_; MdnsReceiver* const receiver_; - MdnsQuerier* const querier_; - platform::TaskRunner* const task_runner_; - const platform::ClockNowFunctionPtr now_function_; + TaskRunner* const task_runner_; MdnsRandom* const random_delay_; + + friend class MdnsResponderTest; }; } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder_unittest.cc new file mode 100644 index 00000000000..037a71fec94 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder_unittest.cc @@ -0,0 +1,762 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "discovery/mdns/mdns_responder.h" + +#include <utility> + +#include "discovery/mdns/mdns_probe_manager.h" +#include "discovery/mdns/mdns_random.h" +#include "discovery/mdns/mdns_receiver.h" +#include "discovery/mdns/mdns_records.h" +#include "discovery/mdns/mdns_sender.h" +#include "platform/test/fake_clock.h" +#include "platform/test/fake_task_runner.h" +#include "platform/test/fake_udp_socket.h" + +namespace openscreen { +namespace discovery { +namespace { + +constexpr Clock::duration kMaximumSharedRecordResponseDelayMs(120 * 1000); + +bool ContainsRecordType(const std::vector<MdnsRecord>& records, DnsType type) { + return std::find_if(records.begin(), records.end(), + [type](const MdnsRecord& record) { + return record.dns_type() == type; + }) != records.end(); +} + +void CheckSingleNsecRecordType(const MdnsMessage& message, DnsType type) { + ASSERT_EQ(message.answers().size(), size_t{1}); + const MdnsRecord record = message.answers()[0]; + + ASSERT_EQ(record.dns_type(), DnsType::kNSEC); + const NsecRecordRdata& rdata = absl::get<NsecRecordRdata>(record.rdata()); + + ASSERT_EQ(rdata.types().size(), size_t{1}); + EXPECT_EQ(rdata.types()[0], type); +} + +void CheckPtrDomain(const MdnsRecord& record, const DomainName& domain) { + ASSERT_EQ(record.dns_type(), DnsType::kPTR); + const PtrRecordRdata& rdata = absl::get<PtrRecordRdata>(record.rdata()); + + EXPECT_EQ(rdata.ptr_domain(), domain); +} + +void ExpectContainsNsecRecordType(const std::vector<MdnsRecord>& records, + DnsType type) { + auto it = std::find_if( + records.begin(), records.end(), [type](const MdnsRecord& record) { + if (record.dns_type() != DnsType::kNSEC) { + return false; + } + + const NsecRecordRdata& rdata = + absl::get<NsecRecordRdata>(record.rdata()); + return rdata.types().size() == 1 && rdata.types()[0] == type; + }); + EXPECT_TRUE(it != records.end()); +} + +} // namespace + +using testing::_; +using testing::Args; +using testing::Invoke; +using testing::Return; +using testing::StrictMock; + +class MockRecordHandler : public MdnsResponder::RecordHandler { + public: + void AddRecord(MdnsRecord record) { records_.push_back(record); } + + MOCK_METHOD3(HasRecords, bool(const DomainName&, DnsType, DnsClass)); + + std::vector<MdnsRecord::ConstRef> GetRecords(const DomainName& name, + DnsType type, + DnsClass clazz) override { + std::vector<MdnsRecord::ConstRef> records; + for (const auto& record : records_) { + if (type == DnsType::kANY || record.dns_type() == type) { + records.push_back(record); + } + } + + return records; + } + + std::vector<MdnsRecord::ConstRef> GetPtrRecords(DnsClass clazz) override { + std::vector<MdnsRecord::ConstRef> records; + for (const auto& record : records_) { + if (record.dns_type() == DnsType::kPTR) { + records.push_back(record); + } + } + + return records; + } + + private: + std::vector<MdnsRecord> records_; +}; + +class MockMdnsSender : public MdnsSender { + public: + explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {} + + MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message)); + MOCK_METHOD2(SendMessage, + Error(const MdnsMessage& message, const IPEndpoint& endpoint)); +}; + +class MockProbeManager : public MdnsProbeManager { + public: + MOCK_CONST_METHOD1(IsDomainClaimed, bool(const DomainName&)); + MOCK_METHOD2(RespondToProbeQuery, + void(const MdnsMessage&, const IPEndpoint&)); +}; + +class MdnsResponderTest : public testing::Test { + public: + MdnsResponderTest() + : clock_(Clock::now()), + task_runner_(&clock_), + socket_(&task_runner_), + sender_(&socket_), + responder_(&record_handler_, + &probe_manager_, + &sender_, + &receiver_, + &task_runner_, + &random_) {} + + protected: + MdnsRecord GetFakePtrRecord(const DomainName& target) { + DomainName name(++target.labels().begin(), target.labels().end()); + PtrRecordRdata rdata(target); + return MdnsRecord(std::move(name), DnsType::kPTR, DnsClass::kIN, + RecordType::kUnique, std::chrono::seconds(0), rdata); + } + + MdnsRecord GetFakeSrvRecord(const DomainName& name) { + SrvRecordRdata rdata(0, 0, 80, name); + return MdnsRecord(name, DnsType::kSRV, DnsClass::kIN, RecordType::kUnique, + std::chrono::seconds(0), rdata); + } + + MdnsRecord GetFakeTxtRecord(const DomainName& name) { + TxtRecordRdata rdata; + return MdnsRecord(name, DnsType::kTXT, DnsClass::kIN, RecordType::kUnique, + std::chrono::seconds(0), rdata); + } + + MdnsRecord GetFakeARecord(const DomainName& name) { + ARecordRdata rdata(IPAddress(192, 168, 0, 0)); + return MdnsRecord(name, DnsType::kA, DnsClass::kIN, RecordType::kUnique, + std::chrono::seconds(0), rdata); + } + + MdnsRecord GetFakeAAAARecord(const DomainName& name) { + AAAARecordRdata rdata(IPAddress(1, 2, 3, 4, 5, 6, 7, 8)); + return MdnsRecord(name, DnsType::kAAAA, DnsClass::kIN, RecordType::kUnique, + std::chrono::seconds(0), rdata); + } + + void OnMessageReceived(const MdnsMessage& message, const IPEndpoint& src) { + responder_.OnMessageReceived(message, src); + } + + void QueryForRecordTypeWhenNonePresent(DnsType type) { + MdnsQuestion question(domain_, type, DnsClass::kANY, + ResponseType::kMulticast); + MdnsMessage message(0, MessageType::Query); + message.AddQuestion(question); + + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([type](const MdnsMessage& msg) -> Error { + CheckSingleNsecRecordType(msg, type); + return Error::None(); + }); + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + OnMessageReceived(message, endpoint_); + } + + MdnsMessage CreateMulticastMdnsQuery(DnsType type) { + MdnsQuestion question(domain_, type, DnsClass::kANY, + ResponseType::kMulticast); + MdnsMessage message(0, MessageType::Query); + message.AddQuestion(std::move(question)); + + return message; + } + + MdnsMessage CreateTypeEnumerationQuery() { + MdnsQuestion question(type_enumeration_domain_, DnsType::kPTR, + DnsClass::kANY, ResponseType::kMulticast); + MdnsMessage message(0, MessageType::Query); + message.AddQuestion(std::move(question)); + + return message; + } + + FakeClock clock_; + FakeTaskRunner task_runner_; + FakeUdpSocket socket_; + StrictMock<MockMdnsSender> sender_; + StrictMock<MockRecordHandler> record_handler_; + StrictMock<MockProbeManager> probe_manager_; + MdnsReceiver receiver_; + MdnsRandom random_; + MdnsResponder responder_; + + DomainName domain_{"instance", "_googlecast", "_tcp", "local"}; + DomainName type_enumeration_domain_{"_services", "_dns-sd", "_udp", "local"}; + IPEndpoint endpoint_{IPAddress(192, 168, 0, 0), 80}; +}; + +// Validate that when records may be sent from multiple receivers, the broadcast +// is delayed and it is not delayed otherwise. +TEST_F(MdnsResponderTest, OwnedRecordsSentImmediately) { + MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + EXPECT_CALL(sender_, SendMulticast(_)).Times(1); + OnMessageReceived(message, endpoint_); + testing::Mock::VerifyAndClearExpectations(&sender_); + testing::Mock::VerifyAndClearExpectations(&record_handler_); + testing::Mock::VerifyAndClearExpectations(&probe_manager_); + + EXPECT_CALL(sender_, SendMulticast(_)).Times(0); + clock_.Advance(Clock::duration(kMaximumSharedRecordResponseDelayMs)); +} + +TEST_F(MdnsResponderTest, NonOwnedRecordsDelayed) { + MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + EXPECT_CALL(sender_, SendMulticast(_)).Times(0); + OnMessageReceived(message, endpoint_); + testing::Mock::VerifyAndClearExpectations(&sender_); + testing::Mock::VerifyAndClearExpectations(&record_handler_); + testing::Mock::VerifyAndClearExpectations(&probe_manager_); + + EXPECT_CALL(sender_, SendMulticast(_)).Times(1); + clock_.Advance(Clock::duration(kMaximumSharedRecordResponseDelayMs)); +} + +TEST_F(MdnsResponderTest, MultipleQuestionsProcessed) { + MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY); + MdnsQuestion question2(domain_, DnsType::kANY, DnsClass::kANY, + ResponseType::kMulticast); + message.AddQuestion(std::move(question2)); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)) + .WillOnce(Return(true)) + .WillOnce(Return(false)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + EXPECT_CALL(sender_, SendMulticast(_)).Times(1); + OnMessageReceived(message, endpoint_); + testing::Mock::VerifyAndClearExpectations(&sender_); + testing::Mock::VerifyAndClearExpectations(&record_handler_); + testing::Mock::VerifyAndClearExpectations(&probe_manager_); + + EXPECT_CALL(sender_, SendMulticast(_)).Times(1); + clock_.Advance(Clock::duration(kMaximumSharedRecordResponseDelayMs)); +} + +// Validate that the correct messaging scheme (unicast vs multicast) is used. +TEST_F(MdnsResponderTest, UnicastMessageSentOverUnicast) { + MdnsQuestion question(domain_, DnsType::kANY, DnsClass::kANY, + ResponseType::kUnicast); + MdnsMessage message(0, MessageType::Query); + message.AddQuestion(question); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + EXPECT_CALL(sender_, SendMessage(_, endpoint_)).Times(1); + OnMessageReceived(message, endpoint_); +} + +TEST_F(MdnsResponderTest, MulticastMessageSentOverMulticast) { + MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + EXPECT_CALL(sender_, SendMulticast(_)).Times(1); + OnMessageReceived(message, endpoint_); +} + +// Validate that records are added as expected based on the query type, and that +// additional records are populated as specified in RFC 6762 and 6763. +TEST_F(MdnsResponderTest, AnyQueryResultsAllApplied) { + MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + record_handler_.AddRecord(GetFakeTxtRecord(domain_)); + record_handler_.AddRecord(GetFakeARecord(domain_)); + record_handler_.AddRecord(GetFakeAAAARecord(domain_)); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([](const MdnsMessage& message) -> Error { + EXPECT_EQ(message.questions().size(), size_t{0}); + EXPECT_EQ(message.authority_records().size(), size_t{0}); + EXPECT_EQ(message.additional_records().size(), size_t{0}); + + EXPECT_EQ(message.answers().size(), size_t{4}); + EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kSRV)); + EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kTXT)); + EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kA)); + EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kAAAA)); + EXPECT_FALSE(ContainsRecordType(message.answers(), DnsType::kPTR)); + return Error::None(); + }); + + OnMessageReceived(message, endpoint_); +} + +TEST_F(MdnsResponderTest, PtrQueryResultsApplied) { + DomainName ptr_domain{"_googlecast", "_tcp", "local"}; + MdnsQuestion question(ptr_domain, DnsType::kPTR, DnsClass::kANY, + ResponseType::kMulticast); + MdnsMessage message(0, MessageType::Query); + message.AddQuestion(question); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakePtrRecord(domain_)); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + record_handler_.AddRecord(GetFakeTxtRecord(domain_)); + record_handler_.AddRecord(GetFakeARecord(domain_)); + record_handler_.AddRecord(GetFakeAAAARecord(domain_)); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([](const MdnsMessage& message) -> Error { + EXPECT_EQ(message.questions().size(), size_t{0}); + EXPECT_EQ(message.authority_records().size(), size_t{0}); + EXPECT_EQ(message.additional_records().size(), size_t{4}); + + EXPECT_EQ(message.answers().size(), size_t{1}); + EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kPTR)); + + const auto& records = message.additional_records(); + EXPECT_EQ(records.size(), size_t{4}); + EXPECT_TRUE(ContainsRecordType(records, DnsType::kSRV)); + EXPECT_TRUE(ContainsRecordType(records, DnsType::kTXT)); + EXPECT_TRUE(ContainsRecordType(records, DnsType::kA)); + EXPECT_TRUE(ContainsRecordType(records, DnsType::kAAAA)); + EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR)); + + return Error::None(); + }); + + OnMessageReceived(message, endpoint_); +} + +TEST_F(MdnsResponderTest, SrvQueryResultsApplied) { + MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kSRV); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakePtrRecord(domain_)); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + record_handler_.AddRecord(GetFakeTxtRecord(domain_)); + record_handler_.AddRecord(GetFakeARecord(domain_)); + record_handler_.AddRecord(GetFakeAAAARecord(domain_)); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([](const MdnsMessage& message) -> Error { + EXPECT_EQ(message.questions().size(), size_t{0}); + EXPECT_EQ(message.authority_records().size(), size_t{0}); + EXPECT_EQ(message.additional_records().size(), size_t{2}); + + EXPECT_EQ(message.answers().size(), size_t{1}); + EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kSRV)); + + const auto& records = message.additional_records(); + EXPECT_EQ(records.size(), size_t{2}); + EXPECT_TRUE(ContainsRecordType(records, DnsType::kA)); + EXPECT_TRUE(ContainsRecordType(records, DnsType::kAAAA)); + EXPECT_FALSE(ContainsRecordType(records, DnsType::kSRV)); + EXPECT_FALSE(ContainsRecordType(records, DnsType::kTXT)); + EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR)); + + return Error::None(); + }); + + OnMessageReceived(message, endpoint_); +} + +TEST_F(MdnsResponderTest, AQueryResultsApplied) { + MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kA); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakePtrRecord(domain_)); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + record_handler_.AddRecord(GetFakeTxtRecord(domain_)); + record_handler_.AddRecord(GetFakeARecord(domain_)); + record_handler_.AddRecord(GetFakeAAAARecord(domain_)); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([](const MdnsMessage& message) -> Error { + EXPECT_EQ(message.questions().size(), size_t{0}); + EXPECT_EQ(message.authority_records().size(), size_t{0}); + EXPECT_EQ(message.additional_records().size(), size_t{1}); + + EXPECT_EQ(message.answers().size(), size_t{1}); + EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kA)); + + const auto& records = message.additional_records(); + EXPECT_EQ(records.size(), size_t{1}); + EXPECT_TRUE(ContainsRecordType(records, DnsType::kAAAA)); + EXPECT_FALSE(ContainsRecordType(records, DnsType::kA)); + EXPECT_FALSE(ContainsRecordType(records, DnsType::kSRV)); + EXPECT_FALSE(ContainsRecordType(records, DnsType::kTXT)); + EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR)); + + return Error::None(); + }); + + OnMessageReceived(message, endpoint_); +} + +TEST_F(MdnsResponderTest, AAAAQueryResultsApplied) { + MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kAAAA); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakePtrRecord(domain_)); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + record_handler_.AddRecord(GetFakeTxtRecord(domain_)); + record_handler_.AddRecord(GetFakeARecord(domain_)); + record_handler_.AddRecord(GetFakeAAAARecord(domain_)); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([](const MdnsMessage& message) -> Error { + EXPECT_EQ(message.questions().size(), size_t{0}); + EXPECT_EQ(message.authority_records().size(), size_t{0}); + EXPECT_EQ(message.additional_records().size(), size_t{1}); + + EXPECT_EQ(message.answers().size(), size_t{1}); + EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kAAAA)); + + const auto& records = message.additional_records(); + EXPECT_EQ(records.size(), size_t{1}); + EXPECT_TRUE(ContainsRecordType(records, DnsType::kA)); + EXPECT_FALSE(ContainsRecordType(records, DnsType::kAAAA)); + EXPECT_FALSE(ContainsRecordType(records, DnsType::kSRV)); + EXPECT_FALSE(ContainsRecordType(records, DnsType::kTXT)); + EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR)); + + return Error::None(); + }); + + OnMessageReceived(message, endpoint_); +} + +TEST_F(MdnsResponderTest, MessageOnlySentIfAnswerNotKnown) { + MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kAAAA); + MdnsRecord aaaa_record = GetFakeAAAARecord(domain_); + message.AddAnswer(aaaa_record); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakePtrRecord(domain_)); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + record_handler_.AddRecord(GetFakeTxtRecord(domain_)); + record_handler_.AddRecord(GetFakeARecord(domain_)); + record_handler_.AddRecord(aaaa_record); + + OnMessageReceived(message, endpoint_); +} + +TEST_F(MdnsResponderTest, RecordOnlySentIfNotKnown) { + MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY); + MdnsRecord aaaa_record = GetFakeAAAARecord(domain_); + message.AddAnswer(aaaa_record); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakeARecord(domain_)); + record_handler_.AddRecord(aaaa_record); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([](const MdnsMessage& message) -> Error { + EXPECT_EQ(message.questions().size(), size_t{0}); + EXPECT_EQ(message.authority_records().size(), size_t{0}); + EXPECT_EQ(message.additional_records().size(), size_t{0}); + + EXPECT_EQ(message.answers().size(), size_t{1}); + EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kA)); + return Error::None(); + }); + + OnMessageReceived(message, endpoint_); +} + +// Validate NSEC records are used correctly. +TEST_F(MdnsResponderTest, QueryForRecordTypesWhenNonePresent) { + QueryForRecordTypeWhenNonePresent(DnsType::kANY); + QueryForRecordTypeWhenNonePresent(DnsType::kSRV); + QueryForRecordTypeWhenNonePresent(DnsType::kTXT); + QueryForRecordTypeWhenNonePresent(DnsType::kA); + QueryForRecordTypeWhenNonePresent(DnsType::kAAAA); +} + +TEST_F(MdnsResponderTest, AAAAQueryGiveANsec) { + MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kAAAA); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakePtrRecord(domain_)); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + record_handler_.AddRecord(GetFakeTxtRecord(domain_)); + record_handler_.AddRecord(GetFakeAAAARecord(domain_)); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([](const MdnsMessage& message) -> Error { + EXPECT_EQ(message.questions().size(), size_t{0}); + EXPECT_EQ(message.authority_records().size(), size_t{0}); + + EXPECT_EQ(message.answers().size(), size_t{1}); + EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kAAAA)); + + EXPECT_EQ(message.additional_records().size(), size_t{1}); + ExpectContainsNsecRecordType(message.additional_records(), DnsType::kA); + + return Error::None(); + }); + + OnMessageReceived(message, endpoint_); +} + +TEST_F(MdnsResponderTest, AQueryGiveAAAANsec) { + MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kA); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakePtrRecord(domain_)); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + record_handler_.AddRecord(GetFakeTxtRecord(domain_)); + record_handler_.AddRecord(GetFakeARecord(domain_)); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([](const MdnsMessage& message) -> Error { + EXPECT_EQ(message.questions().size(), size_t{0}); + EXPECT_EQ(message.authority_records().size(), size_t{0}); + + EXPECT_EQ(message.answers().size(), size_t{1}); + EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kA)); + + EXPECT_EQ(message.additional_records().size(), size_t{1}); + ExpectContainsNsecRecordType(message.additional_records(), + DnsType::kAAAA); + + return Error::None(); + }); + + OnMessageReceived(message, endpoint_); +} + +TEST_F(MdnsResponderTest, SrvQueryGiveCorrectNsecForNoAOrAAAA) { + MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kSRV); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakePtrRecord(domain_)); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + record_handler_.AddRecord(GetFakeTxtRecord(domain_)); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([](const MdnsMessage& message) -> Error { + EXPECT_EQ(message.questions().size(), size_t{0}); + EXPECT_EQ(message.authority_records().size(), size_t{0}); + + EXPECT_EQ(message.answers().size(), size_t{1}); + EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kSRV)); + + EXPECT_EQ(message.additional_records().size(), size_t{2}); + ExpectContainsNsecRecordType(message.additional_records(), DnsType::kA); + ExpectContainsNsecRecordType(message.additional_records(), + DnsType::kAAAA); + + return Error::None(); + }); + OnMessageReceived(message, endpoint_); +} + +TEST_F(MdnsResponderTest, SrvQueryGiveCorrectNsec) { + MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kSRV); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakePtrRecord(domain_)); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + record_handler_.AddRecord(GetFakeTxtRecord(domain_)); + record_handler_.AddRecord(GetFakeARecord(domain_)); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([](const MdnsMessage& message) -> Error { + EXPECT_EQ(message.questions().size(), size_t{0}); + EXPECT_EQ(message.authority_records().size(), size_t{0}); + + EXPECT_EQ(message.answers().size(), size_t{1}); + EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kSRV)); + + EXPECT_EQ(message.additional_records().size(), size_t{2}); + EXPECT_TRUE( + ContainsRecordType(message.additional_records(), DnsType::kA)); + ExpectContainsNsecRecordType(message.additional_records(), + DnsType::kAAAA); + + return Error::None(); + }); + OnMessageReceived(message, endpoint_); +} + +TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForNoPtrOrSrv) { + MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kPTR); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakePtrRecord(domain_)); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([](const MdnsMessage& message) -> Error { + EXPECT_EQ(message.questions().size(), size_t{0}); + EXPECT_EQ(message.authority_records().size(), size_t{0}); + + EXPECT_EQ(message.answers().size(), size_t{1}); + EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kPTR)); + + EXPECT_EQ(message.additional_records().size(), size_t{2}); + ExpectContainsNsecRecordType(message.additional_records(), + DnsType::kTXT); + ExpectContainsNsecRecordType(message.additional_records(), + DnsType::kSRV); + + return Error::None(); + }); + OnMessageReceived(message, endpoint_); +} + +TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForOnlyPtr) { + MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kPTR); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakePtrRecord(domain_)); + record_handler_.AddRecord(GetFakeTxtRecord(domain_)); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([](const MdnsMessage& message) -> Error { + EXPECT_EQ(message.questions().size(), size_t{0}); + EXPECT_EQ(message.authority_records().size(), size_t{0}); + + EXPECT_EQ(message.answers().size(), size_t{1}); + EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kPTR)); + + EXPECT_EQ(message.additional_records().size(), size_t{2}); + EXPECT_TRUE( + ContainsRecordType(message.additional_records(), DnsType::kTXT)); + ExpectContainsNsecRecordType(message.additional_records(), + DnsType::kSRV); + + return Error::None(); + }); + OnMessageReceived(message, endpoint_); +} + +TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForOnlySrv) { + MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kPTR); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + record_handler_.AddRecord(GetFakePtrRecord(domain_)); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([](const MdnsMessage& message) -> Error { + EXPECT_EQ(message.questions().size(), size_t{0}); + EXPECT_EQ(message.authority_records().size(), size_t{0}); + + EXPECT_EQ(message.answers().size(), size_t{1}); + EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kPTR)); + + EXPECT_EQ(message.additional_records().size(), size_t{4}); + EXPECT_TRUE( + ContainsRecordType(message.additional_records(), DnsType::kSRV)); + ExpectContainsNsecRecordType(message.additional_records(), + DnsType::kTXT); + ExpectContainsNsecRecordType(message.additional_records(), DnsType::kA); + ExpectContainsNsecRecordType(message.additional_records(), + DnsType::kAAAA); + + return Error::None(); + }); + OnMessageReceived(message, endpoint_); +} + +TEST_F(MdnsResponderTest, EnumerateAllQuery) { + MdnsMessage message = CreateTypeEnumerationQuery(); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + const auto ptr = GetFakePtrRecord(domain_); + record_handler_.AddRecord(ptr); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + record_handler_.AddRecord(GetFakeTxtRecord(domain_)); + record_handler_.AddRecord(GetFakeARecord(domain_)); + OnMessageReceived(message, endpoint_); + + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce([this, &ptr](const MdnsMessage& message) -> Error { + EXPECT_EQ(message.questions().size(), size_t{0}); + EXPECT_EQ(message.authority_records().size(), size_t{0}); + + EXPECT_EQ(message.answers().size(), size_t{1}); + EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kPTR)); + EXPECT_EQ(message.answers()[0].name(), type_enumeration_domain_); + CheckPtrDomain(message.answers()[0], ptr.name()); + return Error::None(); + }); + clock_.Advance(Clock::duration(kMaximumSharedRecordResponseDelayMs)); +} + +TEST_F(MdnsResponderTest, EnumerateAllQueryNoResults) { + MdnsMessage message = CreateTypeEnumerationQuery(); + + EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false)); + EXPECT_CALL(record_handler_, HasRecords(_, _, _)) + .WillRepeatedly(Return(true)); + const auto ptr = GetFakePtrRecord(domain_); + record_handler_.AddRecord(GetFakeSrvRecord(domain_)); + record_handler_.AddRecord(GetFakeTxtRecord(domain_)); + record_handler_.AddRecord(GetFakeARecord(domain_)); + OnMessageReceived(message, endpoint_); + clock_.Advance(Clock::duration(kMaximumSharedRecordResponseDelayMs)); +} + +} // namespace discovery +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.cc index ac6305feba4..607c4227a37 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.cc @@ -4,39 +4,28 @@ #include "discovery/mdns/mdns_sender.h" +#include <iostream> + #include "discovery/mdns/mdns_writer.h" +#include "platform/api/udp_socket.h" namespace openscreen { namespace discovery { -namespace { - -const IPEndpoint& GetIPv6MdnsMulticastEndpoint() { - static IPEndpoint endpoint{.address = IPAddress(kDefaultMulticastGroupIPv6), - .port = kDefaultMulticastPort}; - return endpoint; -} - -const IPEndpoint& GetIPv4MdnsMulticastEndpoint() { - static IPEndpoint endpoint{.address = IPAddress(kDefaultMulticastGroupIPv4), - .port = kDefaultMulticastPort}; - return endpoint; -} - -} // namespace - MdnsSender::MdnsSender(UdpSocket* socket) : socket_(socket) { OSP_DCHECK(socket_ != nullptr); } +MdnsSender::~MdnsSender() = default; + Error MdnsSender::SendMulticast(const MdnsMessage& message) { const IPEndpoint& endpoint = socket_->IsIPv6() - ? GetIPv6MdnsMulticastEndpoint() - : GetIPv4MdnsMulticastEndpoint(); - return SendUnicast(message, endpoint); + ? kDefaultMulticastGroupIPv6Endpoint + : kDefaultMulticastGroupIPv4Endpoint; + return SendMessage(message, endpoint); } -Error MdnsSender::SendUnicast(const MdnsMessage& message, +Error MdnsSender::SendMessage(const MdnsMessage& message, const IPEndpoint& endpoint) { // Always try to write the message into the buffer even if MaxWireSize is // greater than maximum message size. Domain name compression might reduce the @@ -52,5 +41,9 @@ Error MdnsSender::SendUnicast(const MdnsMessage& message, return Error::Code::kNone; } +void MdnsSender::OnSendError(UdpSocket* socket, Error error) { + OSP_LOG_ERROR << "Error sending packet"; +} + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.h index b51eb12228e..97356ced617 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.h @@ -16,19 +16,21 @@ class MdnsMessage; class MdnsSender { public: - using UdpSocket = openscreen::platform::UdpSocket; - // MdnsSender does not own |socket| and expects that its lifetime exceeds the // lifetime of MdnsSender. explicit MdnsSender(UdpSocket* socket); MdnsSender(const MdnsSender& other) = delete; MdnsSender(MdnsSender&& other) noexcept = delete; + virtual ~MdnsSender(); + MdnsSender& operator=(const MdnsSender& other) = delete; MdnsSender& operator=(MdnsSender&& other) noexcept = delete; - ~MdnsSender() = default; - Error SendMulticast(const MdnsMessage& message); - Error SendUnicast(const MdnsMessage& message, const IPEndpoint& endpoint); + virtual Error SendMulticast(const MdnsMessage& message); + virtual Error SendMessage(const MdnsMessage& message, + const IPEndpoint& endpoint); + + void OnSendError(UdpSocket* socket, Error error); private: UdpSocket* const socket_; diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender_unittest.cc index da67ef02bea..b125b46fd5a 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender_unittest.cc @@ -4,32 +4,31 @@ #include "discovery/mdns/mdns_sender.h" +#include <memory> +#include <vector> + #include "discovery/mdns/mdns_records.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "platform/test/fake_udp_socket.h" +#include "platform/test/mock_udp_socket.h" namespace openscreen { namespace discovery { -using openscreen::platform::FakeUdpSocket; using testing::_; using testing::Args; using testing::Return; +using testing::StrictMock; +using testing::WithArgs; namespace { -MATCHER_P( - VoidPointerMatchesBytes, - expected_data, - "Matches data at the pointer against the provided C-style byte array.") { - const uint8_t* actual_data = static_cast<const uint8_t*>(arg); +ACTION_P(VoidPointerMatchesBytes, expected_data) { + const uint8_t* actual_data = static_cast<const uint8_t*>(arg0); for (size_t i = 0; i < expected_data.size(); ++i) { - if (actual_data[i] != expected_data[i]) { - return false; - } + EXPECT_EQ(actual_data[i], expected_data[i]); } - return true; } } // namespace @@ -103,54 +102,39 @@ class MdnsSenderTest : public testing::Test { IPEndpoint ipv6_multicast_endpoint_; }; -TEST_F(MdnsSenderTest, SendMulticastIPv4) { - std::unique_ptr<FakeUdpSocket> socket_info = - FakeUdpSocket::CreateDefault(IPAddress::Version::kV4); - MdnsSender sender(socket_info.get()); - socket_info->EnqueueSendResult(Error::Code::kNone); - EXPECT_CALL(*socket_info->client_mock(), OnSendError(_, _)).Times(0); - EXPECT_EQ(sender.SendMulticast(query_message_), Error::Code::kNone); - EXPECT_EQ(socket_info->send_queue_size(), size_t{0}); -} - -TEST_F(MdnsSenderTest, SendMulticastIPv6) { - std::unique_ptr<FakeUdpSocket> socket_info = - FakeUdpSocket::CreateDefault(IPAddress::Version::kV6); - MdnsSender sender(socket_info.get()); - socket_info->EnqueueSendResult(Error::Code::kNone); - EXPECT_CALL(*socket_info->client_mock(), OnSendError(_, _)).Times(0); +TEST_F(MdnsSenderTest, SendMulticast) { + StrictMock<MockUdpSocket> socket; + EXPECT_CALL(socket, IsIPv4()).WillRepeatedly(Return(true)); + EXPECT_CALL(socket, IsIPv6()).WillRepeatedly(Return(true)); + MdnsSender sender(&socket); + EXPECT_CALL(socket, SendMessage(_, kQueryBytes.size(), _)) + .WillOnce(WithArgs<0>(VoidPointerMatchesBytes(kQueryBytes))); EXPECT_EQ(sender.SendMulticast(query_message_), Error::Code::kNone); - EXPECT_EQ(socket_info->send_queue_size(), size_t{0}); } TEST_F(MdnsSenderTest, SendUnicastIPv4) { IPEndpoint endpoint{.address = IPAddress{192, 168, 1, 1}, .port = 31337}; - std::unique_ptr<FakeUdpSocket> socket_info = - FakeUdpSocket::CreateDefault(IPAddress::Version::kV4); - MdnsSender sender(socket_info.get()); - socket_info->EnqueueSendResult(Error::Code::kNone); - EXPECT_CALL(*socket_info->client_mock(), OnSendError(_, _)).Times(0); - EXPECT_EQ(sender.SendUnicast(response_message_, endpoint), + StrictMock<MockUdpSocket> socket; + MdnsSender sender(&socket); + EXPECT_CALL(socket, SendMessage(_, kResponseBytes.size(), _)) + .WillOnce(WithArgs<0>(VoidPointerMatchesBytes(kResponseBytes))); + EXPECT_EQ(sender.SendMessage(response_message_, endpoint), Error::Code::kNone); - EXPECT_EQ(socket_info->send_queue_size(), size_t{0}); } TEST_F(MdnsSenderTest, SendUnicastIPv6) { - constexpr uint8_t kIPv6AddressBytes[] = { - 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x02, 0x02, 0xb3, 0xff, 0xfe, 0x1e, 0x83, 0x29, + constexpr uint16_t kIPv6AddressHextets[] = { + 0xfe80, 0x0000, 0x0000, 0x0000, 0x0202, 0xb3ff, 0xfe1e, 0x8329, }; - IPEndpoint endpoint{.address = IPAddress(kIPv6AddressBytes), .port = 31337}; - - std::unique_ptr<FakeUdpSocket> socket_info = - FakeUdpSocket::CreateDefault(IPAddress::Version::kV6); - MdnsSender sender(socket_info.get()); - socket_info->EnqueueSendResult(Error::Code::kNone); - EXPECT_CALL(*socket_info->client_mock(), OnSendError(_, _)).Times(0); - EXPECT_EQ(sender.SendUnicast(response_message_, endpoint), + IPEndpoint endpoint{.address = IPAddress(kIPv6AddressHextets), .port = 31337}; + + StrictMock<MockUdpSocket> socket; + MdnsSender sender(&socket); + EXPECT_CALL(socket, SendMessage(_, kResponseBytes.size(), _)) + .WillOnce(WithArgs<0>(VoidPointerMatchesBytes(kResponseBytes))); + EXPECT_EQ(sender.SendMessage(response_message_, endpoint), Error::Code::kNone); - EXPECT_EQ(socket_info->send_queue_size(), size_t{0}); } TEST_F(MdnsSenderTest, MessageTooBig) { @@ -159,25 +143,24 @@ TEST_F(MdnsSenderTest, MessageTooBig) { big_message_.AddQuestion(a_question_); big_message_.AddAnswer(a_record_); } - std::unique_ptr<FakeUdpSocket> socket_info = - FakeUdpSocket::CreateDefault(IPAddress::Version::kV4); - MdnsSender sender(socket_info.get()); - socket_info->EnqueueSendResult(Error::Code::kNone); - EXPECT_CALL(*socket_info->client_mock(), OnSendError(_, _)).Times(0); + + StrictMock<MockUdpSocket> socket; + EXPECT_CALL(socket, IsIPv4()).WillRepeatedly(Return(true)); + EXPECT_CALL(socket, IsIPv6()).WillRepeatedly(Return(true)); + MdnsSender sender(&socket); EXPECT_EQ(sender.SendMulticast(big_message_), Error::Code::kInsufficientBuffer); - EXPECT_EQ(socket_info->send_queue_size(), size_t{1}); } TEST_F(MdnsSenderTest, ReturnsErrorOnSocketFailure) { - std::unique_ptr<FakeUdpSocket> socket_info = - FakeUdpSocket::CreateDefault(IPAddress::Version::kV4); - MdnsSender sender(socket_info.get()); + FakeUdpSocket::MockClient socket_client; + FakeUdpSocket socket(nullptr, &socket_client); + MdnsSender sender(&socket); Error error = Error(Error::Code::kConnectionFailed, "error message"); - socket_info->EnqueueSendResult(error); - EXPECT_CALL(*socket_info->client_mock(), OnSendError(_, error)).Times(1); + socket.EnqueueSendResult(error); + EXPECT_CALL(socket_client, OnSendError(_, error)).Times(1); EXPECT_EQ(sender.SendMulticast(query_message_), Error::Code::kNone); - EXPECT_EQ(socket_info->send_queue_size(), size_t{0}); + EXPECT_EQ(socket.send_queue_size(), size_t{0}); } } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.cc index df0fdb9f977..61762510dbb 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.cc @@ -4,34 +4,160 @@ #include "discovery/mdns/mdns_service_impl.h" +#include <memory> + +#include "discovery/common/config.h" +#include "discovery/common/reporting_client.h" +#include "discovery/mdns/mdns_records.h" +#include "discovery/mdns/public/mdns_constants.h" + namespace openscreen { namespace discovery { // static -std::unique_ptr<MdnsService> MdnsService::Create(TaskRunner* task_runner) { - return std::make_unique<MdnsServiceImpl>(); +std::unique_ptr<MdnsService> MdnsService::Create( + TaskRunner* task_runner, + ReportingClient* reporting_client, + const Config& config, + NetworkInterfaceIndex network_interface, + SupportedNetworkAddressFamily supported_address_types) { + return std::make_unique<MdnsServiceImpl>( + task_runner, Clock::now, reporting_client, config, network_interface, + supported_address_types); +} + +MdnsServiceImpl::MdnsServiceImpl( + TaskRunner* task_runner, + ClockNowFunctionPtr now_function, + ReportingClient* reporting_client, + const Config& config, + NetworkInterfaceIndex network_interface, + SupportedNetworkAddressFamily supported_address_types) + : task_runner_(task_runner), + now_function_(now_function), + reporting_client_(reporting_client) { + OSP_DCHECK(task_runner_); + OSP_DCHECK(reporting_client_); + OSP_DCHECK(supported_address_types); + + // Create all UDP sockets needed for this object. They should not yet be bound + // so that they do not send or receive data until the objects on which their + // callback depends is initialized. + if (supported_address_types & kUseIpV4Multicast) { + ErrorOr<std::unique_ptr<UdpSocket>> socket = UdpSocket::Create( + task_runner, this, kDefaultMulticastGroupIPv4Endpoint); + OSP_DCHECK(!socket.is_error()); + OSP_DCHECK(socket.value().get()); + OSP_DCHECK(socket.value()->IsIPv4()); + + socket_v4_ = std::move(socket.value()); + socket_v4_->SetMulticastOutboundInterface(network_interface); + socket_v4_->JoinMulticastGroup(kDefaultMulticastGroupIPv4, + network_interface); + socket_v4_->JoinMulticastGroup(kDefaultSiteLocalGroupIPv4, + network_interface); + } + + if (supported_address_types & kUseIpV6Multicast) { + ErrorOr<std::unique_ptr<UdpSocket>> socket = UdpSocket::Create( + task_runner, this, kDefaultMulticastGroupIPv6Endpoint); + OSP_DCHECK(!socket.is_error()); + OSP_DCHECK(socket.value().get()); + OSP_DCHECK(socket.value()->IsIPv6()); + + socket_v6_ = std::move(socket.value()); + socket_v6_->SetMulticastOutboundInterface(network_interface); + socket_v6_->JoinMulticastGroup(kDefaultMulticastGroupIPv6, + network_interface); + socket_v6_->JoinMulticastGroup(kDefaultSiteLocalGroupIPv6, + network_interface); + } + + // Initialize objects which depend on the above sockets. + UdpSocket* socket_ptr = + socket_v4_.get() ? socket_v4_.get() : socket_v6_.get(); + OSP_DCHECK(socket_ptr); + sender_ = std::make_unique<MdnsSender>(socket_ptr); + if (config.enable_querying) { + querier_ = std::make_unique<MdnsQuerier>( + sender_.get(), &receiver_, task_runner_, now_function_, &random_delay_, + reporting_client_, config); + } + if (config.enable_publication) { + probe_manager_ = std::make_unique<MdnsProbeManagerImpl>( + sender_.get(), &receiver_, &random_delay_, task_runner_, now_function_); + publisher_ = + std::make_unique<MdnsPublisher>(sender_.get(), probe_manager_.get(), + task_runner_, now_function_, config); + responder_ = std::make_unique<MdnsResponder>( + publisher_.get(), probe_manager_.get(), sender_.get(), &receiver_, + task_runner_, &random_delay_); + } + + receiver_.Start(); + + // Initialize all sockets to start sending/receiving data. Now that the above + // objects have all been created, it they should be able to safely do so. + // NOTE: Although only one of these sockets is used for sending, both will be + // used for reading on the mDNS v4 and v6 addresses and ports. + if (socket_v4_.get()) { + socket_v4_->Bind(); + } + if (socket_v6_.get()) { + socket_v6_->Bind(); + } } +MdnsServiceImpl::~MdnsServiceImpl() = default; + void MdnsServiceImpl::StartQuery(const DomainName& name, DnsType dns_type, DnsClass dns_class, MdnsRecordChangedCallback* callback) { - // TODO(yakimakha): Implement this method + return querier_->StartQuery(name, dns_type, dns_class, callback); } void MdnsServiceImpl::StopQuery(const DomainName& name, DnsType dns_type, DnsClass dns_class, MdnsRecordChangedCallback* callback) { - // TODO(yakimakha): Implement this method + return querier_->StopQuery(name, dns_type, dns_class, callback); +} + +void MdnsServiceImpl::ReinitializeQueries(const DomainName& name) { + querier_->ReinitializeQueries(name); +} + +Error MdnsServiceImpl::StartProbe(MdnsDomainConfirmedProvider* callback, + DomainName requested_name, + IPAddress address) { + return probe_manager_->StartProbe(callback, std::move(requested_name), + std::move(address)); +} + +Error MdnsServiceImpl::RegisterRecord(const MdnsRecord& record) { + return publisher_->RegisterRecord(record); +} + +Error MdnsServiceImpl::UpdateRegisteredRecord(const MdnsRecord& old_record, + const MdnsRecord& new_record) { + return publisher_->UpdateRegisteredRecord(old_record, new_record); +} + +Error MdnsServiceImpl::UnregisterRecord(const MdnsRecord& record) { + return publisher_->UnregisterRecord(record); +} + +void MdnsServiceImpl::OnError(UdpSocket* socket, Error error) { + reporting_client_->OnFatalError(error); } -void MdnsServiceImpl::RegisterRecord(const MdnsRecord& record) { - // TODO(yakimakha): Implement this method +void MdnsServiceImpl::OnSendError(UdpSocket* socket, Error error) { + sender_->OnSendError(socket, error); } -void MdnsServiceImpl::DeregisterRecord(const MdnsRecord& record) { - // TODO(yakimakha): Implement this method +void MdnsServiceImpl::OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) { + receiver_.OnRead(socket, std::move(packet)); } } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.h index 072add12e23..7da1c1a9c24 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.h @@ -5,26 +5,86 @@ #ifndef DISCOVERY_MDNS_MDNS_SERVICE_IMPL_H_ #define DISCOVERY_MDNS_MDNS_SERVICE_IMPL_H_ +#include "discovery/mdns/mdns_domain_confirmed_provider.h" +#include "discovery/mdns/mdns_probe_manager.h" +#include "discovery/mdns/mdns_publisher.h" +#include "discovery/mdns/mdns_querier.h" +#include "discovery/mdns/mdns_random.h" +#include "discovery/mdns/mdns_reader.h" +#include "discovery/mdns/mdns_receiver.h" +#include "discovery/mdns/mdns_records.h" +#include "discovery/mdns/mdns_responder.h" +#include "discovery/mdns/mdns_sender.h" +#include "discovery/mdns/mdns_writer.h" +#include "discovery/mdns/public/mdns_constants.h" #include "discovery/mdns/public/mdns_service.h" +#include "platform/api/udp_socket.h" namespace openscreen { + +class TaskRunner; + namespace discovery { -class MdnsServiceImpl : public MdnsService { +struct Config; +class NetworkConfig; +class ReportingClient; + +class MdnsServiceImpl : public MdnsService, public UdpSocket::Client { public: + // |task_runner|, |reporting_client|, and |config| must exist for the duration + // of this instance's life. + MdnsServiceImpl(TaskRunner* task_runner, + ClockNowFunctionPtr now_function, + ReportingClient* reporting_client, + const Config& config, + NetworkInterfaceIndex network_interface, + SupportedNetworkAddressFamily supported_address_types); + ~MdnsServiceImpl() override; + + // MdnsService Overrides. void StartQuery(const DomainName& name, DnsType dns_type, DnsClass dns_class, MdnsRecordChangedCallback* callback) override; - void StopQuery(const DomainName& name, DnsType dns_type, DnsClass dns_class, MdnsRecordChangedCallback* callback) override; + void ReinitializeQueries(const DomainName& name) override; + Error StartProbe(MdnsDomainConfirmedProvider* callback, + DomainName requested_name, + IPAddress address) override; + + Error RegisterRecord(const MdnsRecord& record) override; + Error UpdateRegisteredRecord(const MdnsRecord& old_record, + const MdnsRecord& new_record) override; + Error UnregisterRecord(const MdnsRecord& record) override; + + // UdpSocket::Client overrides. + void OnError(UdpSocket* socket, Error error) override; + void OnSendError(UdpSocket* socket, Error error) override; + void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override; + + private: + TaskRunner* const task_runner_; + ClockNowFunctionPtr now_function_; + ReportingClient* const reporting_client_; + + MdnsRandom random_delay_; + MdnsReceiver receiver_; - void RegisterRecord(const MdnsRecord& record) override; + // Sockets to send and receive mDNS Data according to RFC 6762. + std::unique_ptr<UdpSocket> socket_v4_; + std::unique_ptr<UdpSocket> socket_v6_; - void DeregisterRecord(const MdnsRecord& record) override; + // unique_ptrs are used for the below objects so that they can be initialized + // in the body of the ctor, after send_socket is initialized. + std::unique_ptr<MdnsSender> sender_; + std::unique_ptr<MdnsQuerier> querier_; + std::unique_ptr<MdnsProbeManagerImpl> probe_manager_; + std::unique_ptr<MdnsPublisher> publisher_; + std::unique_ptr<MdnsResponder> responder_; }; } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.cc index 28b05b7fa73..e250582197b 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.cc @@ -5,7 +5,9 @@ #include "discovery/mdns/mdns_trackers.h" #include <array> +#include <limits> +#include "discovery/common/config.h" #include "discovery/mdns/mdns_random.h" #include "discovery/mdns/mdns_record_changed_callback.h" #include "discovery/mdns/mdns_sender.h" @@ -39,6 +41,18 @@ bool IsGoodbyeRecord(const MdnsRecord& record) { return record.ttl() == std::chrono::seconds{0}; } +bool IsNegativeResponseForType(const MdnsRecord& record, DnsType dns_type) { + if (record.dns_type() != DnsType::kNSEC) { + return false; + } + + const auto& nsec_types = absl::get<NsecRecordRdata>(record.rdata()).types(); + return std::find_if(nsec_types.begin(), nsec_types.end(), + [dns_type](DnsType type) { + return type == dns_type || type == DnsType::kANY; + }) != nsec_types.end(); +} + // RFC 6762 Section 10.1 // https://tools.ietf.org/html/rfc6762#section-10.1 // In case of a goodbye record, the querier should set TTL to 1 second @@ -49,46 +63,131 @@ constexpr std::chrono::seconds kGoodbyeRecordTtl{1}; MdnsTracker::MdnsTracker(MdnsSender* sender, TaskRunner* task_runner, ClockNowFunctionPtr now_function, - MdnsRandom* random_delay) + MdnsRandom* random_delay, + TrackerType tracker_type) : sender_(sender), task_runner_(task_runner), now_function_(now_function), send_alarm_(now_function, task_runner), - random_delay_(random_delay) { - OSP_DCHECK(task_runner); - OSP_DCHECK(now_function); - OSP_DCHECK(random_delay); - OSP_DCHECK(sender); + random_delay_(random_delay), + tracker_type_(tracker_type) { + OSP_DCHECK(task_runner_); + OSP_DCHECK(now_function_); + OSP_DCHECK(random_delay_); + OSP_DCHECK(sender_); +} + +MdnsTracker::~MdnsTracker() { + send_alarm_.Cancel(); + + for (const MdnsTracker* node : adjacent_nodes_) { + node->RemovedReverseAdjacency(this); + } +} + +bool MdnsTracker::AddAdjacentNode(const MdnsTracker* node) const { + OSP_DCHECK(node); + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + auto it = std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node); + if (it != adjacent_nodes_.end()) { + return false; + } + + adjacent_nodes_.push_back(node); + node->AddReverseAdjacency(this); + return true; +} + +bool MdnsTracker::RemoveAdjacentNode(const MdnsTracker* node) const { + OSP_DCHECK(node); + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + + auto it = std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node); + if (it == adjacent_nodes_.end()) { + return false; + } + + adjacent_nodes_.erase(it); + node->RemovedReverseAdjacency(this); + return true; +} + +void MdnsTracker::AddReverseAdjacency(const MdnsTracker* node) const { + OSP_DCHECK(std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node) == + adjacent_nodes_.end()); + + adjacent_nodes_.push_back(node); +} + +void MdnsTracker::RemovedReverseAdjacency(const MdnsTracker* node) const { + auto it = std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node); + OSP_DCHECK(it != adjacent_nodes_.end()); + + adjacent_nodes_.erase(it); } MdnsRecordTracker::MdnsRecordTracker( MdnsRecord record, + DnsType dns_type, MdnsSender* sender, TaskRunner* task_runner, ClockNowFunctionPtr now_function, MdnsRandom* random_delay, - std::function<void(const MdnsRecord&)> record_expired_callback) - : MdnsTracker(sender, task_runner, now_function, random_delay), + RecordExpiredCallback record_expired_callback) + : MdnsTracker(sender, + task_runner, + now_function, + random_delay, + TrackerType::kRecordTracker), record_(std::move(record)), + dns_type_(dns_type), start_time_(now_function_()), - record_expired_callback_(record_expired_callback) { - OSP_DCHECK(record_expired_callback); + record_expired_callback_(std::move(record_expired_callback)) { + OSP_DCHECK(record_expired_callback_); + + // RecordTrackers cannot be created for tracking NSEC types or ANY types. + OSP_DCHECK(dns_type_ != DnsType::kNSEC); + OSP_DCHECK(dns_type_ != DnsType::kANY); - send_alarm_.Schedule([this] { SendQuery(); }, GetNextSendTime()); + // Validate that, if the provided |record| is an NSEC record, then it provides + // a negative response for |dns_type|. + OSP_DCHECK(record_.dns_type() != DnsType::kNSEC || + IsNegativeResponseForType(record_, dns_type_)); + + ScheduleFollowUpQuery(); } +MdnsRecordTracker::~MdnsRecordTracker() = default; + ErrorOr<MdnsRecordTracker::UpdateType> MdnsRecordTracker::Update( const MdnsRecord& new_record) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); - bool has_same_rdata = record_.rdata() == new_record.rdata(); + const bool has_same_rdata = record_.dns_type() == new_record.dns_type() && + record_.rdata() == new_record.rdata(); + const bool new_is_negative_response = new_record.dns_type() == DnsType::kNSEC; + const bool current_is_negative_response = + record_.dns_type() == DnsType::kNSEC; + + if ((record_.dns_class() != new_record.dns_class()) || + (record_.name() != new_record.name())) { + // The new record has been passed to a wrong tracker. + return Error::Code::kParameterInvalid; + } + + // New response record must correspond to the correct type. + if ((!new_is_negative_response && new_record.dns_type() != dns_type_) || + (new_is_negative_response && + !IsNegativeResponseForType(new_record, dns_type_))) { + // The new record has been passed to a wrong tracker. + return Error::Code::kParameterInvalid; + } // Goodbye records must have the same RDATA but TTL of 0. - // RFC 6762 Section 10.1 + // RFC 6762 Section 10.1. // https://tools.ietf.org/html/rfc6762#section-10.1 - if ((record_.dns_type() != new_record.dns_type()) || - (record_.dns_class() != new_record.dns_class()) || - (record_.name() != new_record.name()) || - (IsGoodbyeRecord(new_record) && !has_same_rdata)) { + if (!new_is_negative_response && !current_is_negative_response && + IsGoodbyeRecord(new_record) && !has_same_rdata) { // The new record has been passed to a wrong tracker. return Error::Code::kParameterInvalid; } @@ -99,9 +198,9 @@ ErrorOr<MdnsRecordTracker::UpdateType> MdnsRecordTracker::Update( new_record.dns_class(), new_record.record_type(), kGoodbyeRecordTtl, new_record.rdata()); - // Goodbye records do not need to be requeried, set the attempt count to the - // last item, which is 100% of TTL, i.e. record expiration. - attempt_count_ = openscreen::countof(kTtlFractions) - 1; + // Goodbye records do not need to be re-queried, set the attempt count to + // the last item, which is 100% of TTL, i.e. record expiration. + attempt_count_ = countof(kTtlFractions) - 1; } else { record_ = new_record; attempt_count_ = 0; @@ -109,11 +208,21 @@ ErrorOr<MdnsRecordTracker::UpdateType> MdnsRecordTracker::Update( } start_time_ = now_function_(); - send_alarm_.Schedule([this] { SendQuery(); }, GetNextSendTime()); + ScheduleFollowUpQuery(); return result; } +bool MdnsRecordTracker::AddAssociatedQuery( + const MdnsQuestionTracker* question_tracker) const { + return AddAdjacentNode(question_tracker); +} + +bool MdnsRecordTracker::RemoveAssociatedQuery( + const MdnsQuestionTracker* question_tracker) const { + return RemoveAdjacentNode(question_tracker); +} + void MdnsRecordTracker::ExpireSoon() { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); @@ -122,35 +231,55 @@ void MdnsRecordTracker::ExpireSoon() { record_.record_type(), kGoodbyeRecordTtl, record_.rdata()); // Set the attempt count to the last item, which is 100% of TTL, i.e. record - // expiration, to prevent any requeries - attempt_count_ = openscreen::countof(kTtlFractions) - 1; + // expiration, to prevent any re-queries + attempt_count_ = countof(kTtlFractions) - 1; start_time_ = now_function_(); - send_alarm_.Schedule([this] { SendQuery(); }, GetNextSendTime()); + ScheduleFollowUpQuery(); +} + +void MdnsRecordTracker::ExpireNow() { + record_expired_callback_(this, record_); +} + +bool MdnsRecordTracker::IsNearingExpiry() const { + return (now_function_() - start_time_) > record_.ttl() / 2; } -void MdnsRecordTracker::SendQuery() { +bool MdnsRecordTracker::SendQuery() const { const Clock::time_point expiration_time = start_time_ + record_.ttl(); - const bool is_expired = (now_function_() >= expiration_time); + bool is_expired = (now_function_() >= expiration_time); if (!is_expired) { - MdnsQuestion question(record_.name(), record_.dns_type(), - record_.dns_class(), ResponseType::kMulticast); - MdnsMessage message(CreateMessageId(), MessageType::Query); - message.AddQuestion(std::move(question)); - sender_->SendMulticast(message); - send_alarm_.Schedule([this] { MdnsRecordTracker::SendQuery(); }, - GetNextSendTime()); + for (const MdnsTracker* tracker : adjacent_nodes()) { + tracker->SendQuery(); + } } else { - record_expired_callback_(record_); + record_expired_callback_(this, record_); } + + return !is_expired; +} + +void MdnsRecordTracker::ScheduleFollowUpQuery() { + send_alarm_.Schedule( + [this] { + if (SendQuery()) { + ScheduleFollowUpQuery(); + } + }, + GetNextSendTime()); +} + +std::vector<MdnsRecord> MdnsRecordTracker::GetRecords() const { + return {record_}; } -openscreen::platform::Clock::time_point MdnsRecordTracker::GetNextSendTime() { - OSP_DCHECK(attempt_count_ < openscreen::countof(kTtlFractions)); +Clock::time_point MdnsRecordTracker::GetNextSendTime() { + OSP_DCHECK(attempt_count_ < countof(kTtlFractions)); double ttl_fraction = kTtlFractions[attempt_count_++]; // Do not add random variation to the expiration time (last fraction of TTL) - if (attempt_count_ != openscreen::countof(kTtlFractions)) { + if (attempt_count_ != countof(kTtlFractions)) { ttl_fraction += random_delay_->GetRecordTtlVariation(); } @@ -163,25 +292,134 @@ MdnsQuestionTracker::MdnsQuestionTracker(MdnsQuestion question, MdnsSender* sender, TaskRunner* task_runner, ClockNowFunctionPtr now_function, - MdnsRandom* random_delay) - : MdnsTracker(sender, task_runner, now_function, random_delay), + MdnsRandom* random_delay, + const Config& config, + QueryType query_type) + : MdnsTracker(sender, + task_runner, + now_function, + random_delay, + TrackerType::kQuestionTracker), question_(std::move(question)), - send_delay_(kMinimumQueryInterval) { + send_delay_(kMinimumQueryInterval), + query_type_(query_type), + maximum_announcement_count_(config.new_query_announcement_count < 0 + ? INT_MAX + : config.new_query_announcement_count) { + // Initialize the last send time to time_point::min() so that the next call to + // SendQuery() is guaranteed to query the network. + last_send_time_ = TrivialClockTraits::time_point::min(); + // The initial query has to be sent after a random delay of 20-120 // milliseconds. - const Clock::duration delay = random_delay_->GetInitialQueryDelay(); - send_alarm_.Schedule([this] { MdnsQuestionTracker::SendQuery(); }, - now_function_() + delay); + if (announcements_so_far_ < maximum_announcement_count_) { + announcements_so_far_++; + + if (query_type_ == QueryType::kOneShot) { + task_runner_->PostTask([this] { MdnsQuestionTracker::SendQuery(); }); + } else { + OSP_DCHECK(query_type_ == QueryType::kContinuous); + send_alarm_.ScheduleFromNow( + [this]() { + MdnsQuestionTracker::SendQuery(); + ScheduleFollowUpQuery(); + }, + random_delay_->GetInitialQueryDelay()); + } + } +} + +MdnsQuestionTracker::~MdnsQuestionTracker() = default; + +bool MdnsQuestionTracker::AddAssociatedRecord( + const MdnsRecordTracker* record_tracker) const { + return AddAdjacentNode(record_tracker); } -void MdnsQuestionTracker::SendQuery() { +bool MdnsQuestionTracker::RemoveAssociatedRecord( + const MdnsRecordTracker* record_tracker) const { + return RemoveAdjacentNode(record_tracker); +} + +std::vector<MdnsRecord> MdnsQuestionTracker::GetRecords() const { + std::vector<MdnsRecord> records; + for (const MdnsTracker* tracker : adjacent_nodes()) { + OSP_DCHECK(tracker->tracker_type() == TrackerType::kRecordTracker); + + // This call cannot result in an infinite loop because MdnsRecordTracker + // instances only return a single record from this call. + std::vector<MdnsRecord> node_records = tracker->GetRecords(); + OSP_DCHECK(node_records.size() == 1); + + records.push_back(std::move(node_records[0])); + } + + return records; +} + +bool MdnsQuestionTracker::SendQuery() const { + // NOTE: The RFC does not specify the minimum interval between queries for + // multiple records of the same query when initiated for different reasons + // (such as for different record refreshes or for one record refresh and the + // periodic re-querying for a continuous query). For this reason, a constant + // outside of scope of the RFC has been chosen. + TrivialClockTraits::time_point now = now_function_(); + if (now < last_send_time_ + kMinimumQueryInterval) { + return true; + } + last_send_time_ = now; + MdnsMessage message(CreateMessageId(), MessageType::Query); message.AddQuestion(question_); - // TODO(yakimakha): Implement known-answer suppression by adding known - // answers to the question + + // Send the message and additional known answer packets as needed. + for (auto it = adjacent_nodes().begin(); it != adjacent_nodes().end();) { + OSP_DCHECK((*it)->tracker_type() == TrackerType::kRecordTracker); + + const MdnsRecordTracker* record_tracker = + static_cast<const MdnsRecordTracker*>(*it); + if (record_tracker->IsNearingExpiry()) { + it++; + continue; + } + + // A record tracker should only contain one record. + std::vector<MdnsRecord> node_records = (*it)->GetRecords(); + OSP_DCHECK(node_records.size() == 1); + MdnsRecord node_record = std::move(node_records[0]); + + if (message.CanAddRecord(node_record)) { + message.AddAnswer(std::move(node_record)); + it++; + } else if (message.questions().empty() && message.answers().empty()) { + // This case should never happen, because it means a record is too large + // to fit into its own message. + OSP_LOG << "Encountered unreasonably large message in cache. Skipping " + << "known answer in suppressions..."; + it++; + } else { + message.set_truncated(); + sender_->SendMulticast(message); + message = MdnsMessage(CreateMessageId(), MessageType::Query); + } + } sender_->SendMulticast(message); - send_alarm_.Schedule([this] { MdnsQuestionTracker::SendQuery(); }, - now_function_() + send_delay_); + return true; +} + +void MdnsQuestionTracker::ScheduleFollowUpQuery() { + if (announcements_so_far_ >= maximum_announcement_count_) { + return; + } + announcements_so_far_++; + + send_alarm_.ScheduleFromNow( + [this] { + if (SendQuery()) { + ScheduleFollowUpQuery(); + } + }, + send_delay_); send_delay_ = send_delay_ * kIntervalIncreaseFactor; if (send_delay_ > kMaximumQueryInterval) { send_delay_ = kMaximumQueryInterval; diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.h index 75556f2c886..6cb863a94a1 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.h @@ -6,27 +6,41 @@ #define DISCOVERY_MDNS_MDNS_TRACKERS_H_ #include <unordered_map> +#include <vector> #include "absl/hash/hash.h" #include "discovery/mdns/mdns_records.h" #include "platform/api/task_runner.h" #include "platform/base/error.h" +#include "platform/base/trivial_clock_traits.h" #include "util/alarm.h" namespace openscreen { namespace discovery { +struct Config; class MdnsRandom; class MdnsRecord; class MdnsRecordChangedCallback; class MdnsSender; // MdnsTracker is a base class for MdnsRecordTracker and MdnsQuestionTracker for -// the purposes of common code sharing only +// the purposes of common code sharing only. +// +// Instances of this class represent nodes of a bidirectional graph, such that +// if node A is adjacent to node B, B is also adjacent to A. In this class, the +// adjacent nodes are stored in adjacency list |associated_tracker_|, and +// exposed methods to add and remove nodes from this list also modify the added +// or removed node to remove this instance from its adjacency list. +// +// Because MdnsQuestionTracker::AddAssocaitedRecord() can only called on +// MdnsRecordTracker objects and MdnsRecordTracker::AddAssociatedQuery() is +// only called on MdnsQuestionTracker objects, this created graph is bipartite. +// This means that MdnsRecordTracker objects are only adjacent to +// MdnsQuestionTracker objects and the opposite. class MdnsTracker { public: - using ClockNowFunctionPtr = openscreen::platform::ClockNowFunctionPtr; - using TaskRunner = openscreen::platform::TaskRunner; + enum class TrackerType { kRecordTracker, kQuestionTracker }; // MdnsTracker does not own |sender|, |task_runner| and |random_delay| // and expects that the lifetime of these objects exceeds the lifetime of @@ -34,34 +48,73 @@ class MdnsTracker { MdnsTracker(MdnsSender* sender, TaskRunner* task_runner, ClockNowFunctionPtr now_function, - MdnsRandom* random_delay); + MdnsRandom* random_delay, + TrackerType tracker_type); MdnsTracker(const MdnsTracker& other) = delete; MdnsTracker(MdnsTracker&& other) noexcept = delete; MdnsTracker& operator=(const MdnsTracker& other) = delete; MdnsTracker& operator=(MdnsTracker&& other) noexcept = delete; - ~MdnsTracker() = default; + virtual ~MdnsTracker(); + + // Returns the record type represented by this tracker. + TrackerType tracker_type() const { return tracker_type_; } + + // Sends a query message via MdnsSender. Returns false if a follow up query + // should NOT be scheduled and true otherwise. + virtual bool SendQuery() const = 0; + + // Returns the records currently associated with this tracker. + virtual std::vector<MdnsRecord> GetRecords() const = 0; protected: + // Schedules a repeat query to be sent out. + virtual void ScheduleFollowUpQuery() = 0; + + // These methods create a bidirectional adjacency with another node in the + // graph. + bool AddAdjacentNode(const MdnsTracker* tracker) const; + bool RemoveAdjacentNode(const MdnsTracker* tracker) const; + + const std::vector<const MdnsTracker*>& adjacent_nodes() const { + return adjacent_nodes_; + } + MdnsSender* const sender_; TaskRunner* const task_runner_; const ClockNowFunctionPtr now_function_; Alarm send_alarm_; // TODO(yakimakha): Use cancelable task when available MdnsRandom* const random_delay_; + TrackerType tracker_type_; + + private: + // These methods are used to ensure the bidirectional-ness of this graph. + void AddReverseAdjacency(const MdnsTracker* tracker) const; + void RemovedReverseAdjacency(const MdnsTracker* tracker) const; + + // Adjacency list for this graph node. + mutable std::vector<const MdnsTracker*> adjacent_nodes_; }; +class MdnsQuestionTracker; + // MdnsRecordTracker manages automatic resending of mDNS queries for // refreshing records as they reach their expiration time. class MdnsRecordTracker : public MdnsTracker { public: - using Clock = openscreen::platform::Clock; + using RecordExpiredCallback = + std::function<void(const MdnsRecordTracker*, const MdnsRecord&)>; + + // NOTE: In the case that |record| is of type NSEC, |dns_type| is expected to + // differ from |record|'s type. + MdnsRecordTracker(MdnsRecord record, + DnsType dns_type, + MdnsSender* sender, + TaskRunner* task_runner, + ClockNowFunctionPtr now_function, + MdnsRandom* random_delay, + RecordExpiredCallback record_expired_callback); - MdnsRecordTracker( - MdnsRecord record, - MdnsSender* sender, - TaskRunner* task_runner, - ClockNowFunctionPtr now_function, - MdnsRandom* random_delay, - std::function<void(const MdnsRecord&)> record_expired_callback); + ~MdnsRecordTracker() override; // Possible outcomes from updating a tracked record. enum class UpdateType { @@ -77,45 +130,108 @@ class MdnsRecordTracker : public MdnsTracker { // for the current tracked record. ErrorOr<UpdateType> Update(const MdnsRecord& new_record); + // Adds or removed a question which this record answers. + bool AddAssociatedQuery(const MdnsQuestionTracker* question_tracker) const; + bool RemoveAssociatedQuery(const MdnsQuestionTracker* question_tracker) const; + // Sets record to expire after 1 seconds as per RFC 6762 void ExpireSoon(); - // Returns a reference to the tracked record. - const MdnsRecord& record() const { return record_; } + // Expires the record now + void ExpireNow(); + + // Returns true if half of the record's TTL has passed, and false otherwise. + // Half is used due to specifications in RFC 6762 section 7.1. + bool IsNearingExpiry() const; + + // Returns information about the stored record. + // + // NOTE: These methods are NOT all pass-through methods to |record_|. + // specifically, dns_type() returns the DNS Type associated with this record + // tracker, which may be different from the record type if |record_| is of + // type NSEC. To avoid this case, direct access to the underlying |record_| + // instance is not provided. + // + // In this case, creating an MdnsRecord with the below data will result in a + // runtime error due to DCHECKS and that Rdata's associated type will not + // match DnsType when |record_| is of type NSEC. Therefore, creating such + // records should be guarded by is_negative_response() checks. + const DomainName& name() const { return record_.name(); } + DnsType dns_type() const { return dns_type_; } + DnsClass dns_class() const { return record_.dns_class(); } + RecordType record_type() const { return record_.record_type(); } + std::chrono::seconds ttl() const { return record_.ttl(); } + const Rdata& rdata() const { return record_.rdata(); } + + bool is_negative_response() const { + return record_.dns_type() == DnsType::kNSEC; + } private: - void SendQuery(); + using MdnsTracker::tracker_type; + + // Needed to provide the test class access to the record stored in this + // tracker. + friend class MdnsTrackerTest; + Clock::time_point GetNextSendTime(); + // MdnsTracker overrides. + bool SendQuery() const override; + void ScheduleFollowUpQuery() override; + std::vector<MdnsRecord> GetRecords() const override; + // Stores MdnsRecord provided to Start method call. MdnsRecord record_; + + // DnsType this record tracker represents. This may not match the type of + // |record_| if it is an NSEC record. + const DnsType dns_type_; + // A point in time when the record was received and the tracking has started. Clock::time_point start_time_; + // Number of times record refresh has been attempted. size_t attempt_count_ = 0; - std::function<void(const MdnsRecord&)> record_expired_callback_; + RecordExpiredCallback record_expired_callback_; }; // MdnsQuestionTracker manages automatic resending of mDNS queries for -// continuous monitoring with exponential back-off as described in RFC 6762 +// continuous monitoring with exponential back-off as described in RFC 6762. class MdnsQuestionTracker : public MdnsTracker { public: - using Clock = openscreen::platform::Clock; + // Supported query types, per RFC 6762 section 5. + enum class QueryType { kOneShot, kContinuous }; MdnsQuestionTracker(MdnsQuestion question, MdnsSender* sender, TaskRunner* task_runner, ClockNowFunctionPtr now_function, - MdnsRandom* random_delay); + MdnsRandom* random_delay, + const Config& config, + QueryType query_type = QueryType::kContinuous); + + ~MdnsQuestionTracker() override; + + // Adds or removed an answer to a the question posed by this tracker. + bool AddAssociatedRecord(const MdnsRecordTracker* record_tracker) const; + bool RemoveAssociatedRecord(const MdnsRecordTracker* record_tracker) const; // Returns a reference to the tracked question. const MdnsQuestion& question() const { return question_; } private: + using MdnsTracker::tracker_type; + using RecordKey = std::tuple<DomainName, DnsType, DnsClass>; - // Sends a query message via MdnsSender and schedules the next resend. - void SendQuery(); + // Determines if all answers to this query have been received. + bool HasReceivedAllResponses(); + + // MdnsTracker overrides. + bool SendQuery() const override; + void ScheduleFollowUpQuery() override; + std::vector<MdnsRecord> GetRecords() const override; // Stores MdnsQuestion provided to Start method call. MdnsQuestion question_; @@ -123,15 +239,18 @@ class MdnsQuestionTracker : public MdnsTracker { // A delay between the currently scheduled and the next queries. Clock::duration send_delay_; - // Active record trackers, uniquely identified by domain name, DNS record type - // and DNS record class. MdnsRecordTracker instances are stored as unique_ptr - // so they are not moved around in memory when the collection is modified. - // This allows passing a pointer to MdnsRecordTracker to a task running on the - // TaskRunner. - std::unordered_map<RecordKey, - std::unique_ptr<MdnsRecordTracker>, - absl::Hash<RecordKey>> - record_trackers_; + // Last time that this tracker's question was asked. + mutable TrivialClockTraits::time_point last_send_time_; + + // Specifies whether this query is intended to be a one-shot query, as defined + // in RFC 6762 section 5.1. + const QueryType query_type_; + + // Signifies the maximum number of times a record should be announced. + int maximum_announcement_count_; + + // Number of times this query has been announced. + int announcements_so_far_ = 0; }; } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers_unittest.cc index 91b33678b91..399e0d46145 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers_unittest.cc @@ -4,6 +4,10 @@ #include "discovery/mdns/mdns_trackers.h" +#include <memory> +#include <utility> + +#include "discovery/common/config.h" #include "discovery/mdns/mdns_random.h" #include "discovery/mdns/mdns_record_changed_callback.h" #include "discovery/mdns/mdns_sender.h" @@ -11,51 +15,49 @@ #include "gtest/gtest.h" #include "platform/test/fake_clock.h" #include "platform/test/fake_task_runner.h" +#include "platform/test/fake_udp_socket.h" namespace openscreen { namespace discovery { +namespace { + +constexpr Clock::duration kOneSecond = + std::chrono::duration_cast<Clock::duration>(std::chrono::seconds(1)); + +} -using openscreen::platform::Clock; -using openscreen::platform::FakeClock; -using openscreen::platform::FakeTaskRunner; -using openscreen::platform::NetworkInterfaceIndex; -using openscreen::platform::TaskRunner; -using openscreen::platform::UdpSocket; using testing::_; using testing::Args; +using testing::DoAll; using testing::Invoke; using testing::Return; +using testing::StrictMock; using testing::WithArgs; ACTION_P2(VerifyMessageBytesWithoutId, expected_data, expected_size) { const uint8_t* actual_data = reinterpret_cast<const uint8_t*>(arg0); const size_t actual_size = arg1; - EXPECT_EQ(actual_size, expected_size); + ASSERT_EQ(actual_size, expected_size); // Start at bytes[2] to skip a generated message ID. for (size_t i = 2; i < actual_size; ++i) { EXPECT_EQ(actual_data[i], expected_data[i]); } } -class MockUdpSocket : public UdpSocket { +ACTION_P(VerifyTruncated, is_truncated) { + EXPECT_EQ(arg0.is_truncated(), is_truncated); +} + +ACTION_P(VerifyRecordCount, record_count) { + EXPECT_EQ(arg0.answers().size(), static_cast<size_t>(record_count)); +} + +class MockMdnsSender : public MdnsSender { public: - MOCK_METHOD(bool, IsIPv4, (), (const, override)); - MOCK_METHOD(bool, IsIPv6, (), (const, override)); - MOCK_METHOD(IPEndpoint, GetLocalEndpoint, (), (const, override)); - MOCK_METHOD(void, Bind, (), (override)); - MOCK_METHOD(void, - SetMulticastOutboundInterface, - (NetworkInterfaceIndex), - (override)); - MOCK_METHOD(void, - JoinMulticastGroup, - (const IPAddress&, NetworkInterfaceIndex), - (override)); - MOCK_METHOD(void, - SendMessage, - (const void*, size_t, const IPEndpoint&), - (override)); - MOCK_METHOD(void, SetDscp, (DscpMode), (override)); + explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {} + + MOCK_METHOD1(SendMulticast, Error(const MdnsMessage&)); + MOCK_METHOD2(SendMessage, Error(const MdnsMessage&, const IPEndpoint&)); }; class MockRecordChangedCallback : public MdnsRecordChangedCallback { @@ -71,6 +73,7 @@ class MdnsTrackerTest : public testing::Test { MdnsTrackerTest() : clock_(Clock::now()), task_runner_(&clock_), + socket_(&task_runner_), sender_(&socket_), a_question_(DomainName{"testing", "local"}, DnsType::kANY, @@ -81,31 +84,64 @@ class MdnsTrackerTest : public testing::Test { DnsClass::kIN, RecordType::kShared, std::chrono::seconds(120), - ARecordRdata(IPAddress{172, 0, 0, 1})) {} + ARecordRdata(IPAddress{172, 0, 0, 1})), + nsec_record_( + DomainName{"testing", "local"}, + DnsType::kNSEC, + DnsClass::kIN, + RecordType::kShared, + std::chrono::seconds(120), + NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA)) {} template <class TrackerType> void TrackerNoQueryAfterDestruction(TrackerType tracker) { tracker.reset(); - EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0); // Advance fake clock by a long time interval to make sure if there's a // scheduled task, it will run. clock_.Advance(std::chrono::hours(1)); } std::unique_ptr<MdnsRecordTracker> CreateRecordTracker( - const MdnsRecord& record) { + const MdnsRecord& record, + DnsType type) { return std::make_unique<MdnsRecordTracker>( - record, &sender_, &task_runner_, &FakeClock::now, &random_, - [this](const MdnsRecord& record) { expiration_called_ = true; }); + record, type, &sender_, &task_runner_, &FakeClock::now, &random_, + [this](const MdnsRecordTracker* tracker, const MdnsRecord& record) { + expiration_called_ = true; + }); + } + + std::unique_ptr<MdnsRecordTracker> CreateRecordTracker( + const MdnsRecord& record) { + return CreateRecordTracker(record, record.dns_type()); } std::unique_ptr<MdnsQuestionTracker> CreateQuestionTracker( - const MdnsQuestion& question) { - return std::make_unique<MdnsQuestionTracker>( - question, &sender_, &task_runner_, &FakeClock::now, &random_); + const MdnsQuestion& question, + MdnsQuestionTracker::QueryType query_type = + MdnsQuestionTracker::QueryType::kContinuous) { + return std::make_unique<MdnsQuestionTracker>(question, &sender_, + &task_runner_, &FakeClock::now, + &random_, config_, query_type); } protected: + void AdvanceThroughAllTtlFractions(std::chrono::seconds ttl) { + constexpr double kTtlFractions[] = {0.83, 0.88, 0.93, 0.98, 1.00}; + Clock::duration time_passed{0}; + for (double fraction : kTtlFractions) { + Clock::duration time_till_refresh = + std::chrono::duration_cast<Clock::duration>(ttl * fraction); + Clock::duration delta = time_till_refresh - time_passed; + time_passed = time_till_refresh; + clock_.Advance(delta); + } + } + + const MdnsRecord& GetRecord(MdnsRecordTracker* tracker) { + return tracker->record_; + } + // clang-format off const std::vector<uint8_t> kQuestionQueryBytes = { 0x00, 0x00, // ID = 0 @@ -138,14 +174,16 @@ class MdnsTrackerTest : public testing::Test { }; // clang-format on + Config config_; FakeClock clock_; FakeTaskRunner task_runner_; - MockUdpSocket socket_; - MdnsSender sender_; + FakeUdpSocket socket_; + StrictMock<MockMdnsSender> sender_; MdnsRandom random_; MdnsQuestion a_question_; MdnsRecord a_record_; + MdnsRecord nsec_record_; bool expiration_called_ = false; }; @@ -160,32 +198,58 @@ class MdnsTrackerTest : public testing::Test { TEST_F(MdnsTrackerTest, RecordTrackerRecordAccessor) { std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_); - EXPECT_EQ(tracker->record(), a_record_); + EXPECT_EQ(GetRecord(tracker.get()), a_record_); } -TEST_F(MdnsTrackerTest, RecordTrackerQueryAfterDelay) { +TEST_F(MdnsTrackerTest, RecordTrackerQueryAfterDelayPerQuestionTracker) { + std::unique_ptr<MdnsQuestionTracker> question = CreateQuestionTracker( + a_question_, MdnsQuestionTracker::QueryType::kOneShot); + std::unique_ptr<MdnsQuestionTracker> question2 = CreateQuestionTracker( + a_question_, MdnsQuestionTracker::QueryType::kOneShot); + EXPECT_CALL(sender_, SendMulticast(_)).Times(2); + clock_.Advance(kOneSecond); + clock_.Advance(kOneSecond); + testing::Mock::VerifyAndClearExpectations(&sender_); + std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_); - // Only expect 4 queries being sent, when record reaches it's TTL it's - // considered expired and another query is not sent - constexpr double kTtlFractions[] = {0.83, 0.88, 0.93, 0.98, 1.00}; - EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(4); - - Clock::duration time_passed{0}; - for (double fraction : kTtlFractions) { - Clock::duration time_till_refresh = - std::chrono::duration_cast<Clock::duration>(a_record_.ttl() * fraction); - Clock::duration delta = time_till_refresh - time_passed; - time_passed = time_till_refresh; - clock_.Advance(delta); - } + + // No queries without an associated tracker. + AdvanceThroughAllTtlFractions(a_record_.ttl()); + testing::Mock::VerifyAndClearExpectations(&sender_); + + // 4 queries with one associated tracker. + tracker = CreateRecordTracker(a_record_); + tracker->AddAssociatedQuery(question.get()); + EXPECT_CALL(sender_, SendMulticast(_)).Times(4); + AdvanceThroughAllTtlFractions(a_record_.ttl()); + testing::Mock::VerifyAndClearExpectations(&sender_); + + // 8 queries with two associated trackers. + tracker = CreateRecordTracker(a_record_); + tracker->AddAssociatedQuery(question.get()); + tracker->AddAssociatedQuery(question2.get()); + EXPECT_CALL(sender_, SendMulticast(_)).Times(8); + AdvanceThroughAllTtlFractions(a_record_.ttl()); } TEST_F(MdnsTrackerTest, RecordTrackerSendsMessage) { + std::unique_ptr<MdnsQuestionTracker> question = CreateQuestionTracker( + a_question_, MdnsQuestionTracker::QueryType::kOneShot); + EXPECT_CALL(sender_, SendMulticast(_)).Times(1); + clock_.Advance(kOneSecond); + clock_.Advance(kOneSecond); + testing::Mock::VerifyAndClearExpectations(&sender_); + std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_); + tracker->AddAssociatedQuery(question.get()); - EXPECT_CALL(socket_, SendMessage(_, _, _)) - .WillOnce(WithArgs<0, 1>(VerifyMessageBytesWithoutId( - kRecordQueryBytes.data(), kRecordQueryBytes.size()))); + EXPECT_CALL(sender_, SendMulticast(_)) + .Times(1) + .WillRepeatedly([this](const MdnsMessage& message) -> Error { + EXPECT_EQ(message.questions().size(), size_t{1}); + EXPECT_EQ(message.questions()[0], a_question_); + return Error::None(); + }); clock_.Advance( std::chrono::duration_cast<Clock::duration>(a_record_.ttl() * 0.83)); @@ -202,7 +266,6 @@ TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterLateTask) { // no query and instead the record will expire. // Check lower bound for task being late (TTL) and an arbitrarily long time // interval to ensure the query is not sent a later time. - EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0); clock_.Advance(a_record_.ttl()); clock_.Advance(std::chrono::hours(1)); } @@ -232,6 +295,16 @@ TEST_F(MdnsTrackerTest, RecordTrackerForceExpiration) { EXPECT_TRUE(expiration_called_); } +TEST_F(MdnsTrackerTest, NsecRecordTrackerForceExpiration) { + expiration_called_ = false; + std::unique_ptr<MdnsRecordTracker> tracker = + CreateRecordTracker(nsec_record_, DnsType::kA); + tracker->ExpireSoon(); + // Expire schedules expiration after 1 second. + clock_.Advance(std::chrono::seconds(1)); + EXPECT_TRUE(expiration_called_); +} + TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallback) { expiration_called_ = false; std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_); @@ -250,9 +323,6 @@ TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallbackAfterGoodbye) { EXPECT_EQ(tracker->Update(goodbye_record).value(), MdnsRecordTracker::UpdateType::kGoodbye); - // No refresh queries are sent after goodbye record is received. - EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0); - // Advance clock to just before the expiration time of 1 second. clock_.Advance(std::chrono::microseconds{999999}); EXPECT_FALSE(expiration_called_); @@ -261,7 +331,7 @@ TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallbackAfterGoodbye) { EXPECT_TRUE(expiration_called_); } -TEST_F(MdnsTrackerTest, RecordTrackerInvalidUpdate) { +TEST_F(MdnsTrackerTest, RecordTrackerInvalidPositiveRecordUpdate) { std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_); MdnsRecord invalid_name(DomainName{"invalid"}, a_record_.dns_type(), @@ -292,6 +362,71 @@ TEST_F(MdnsTrackerTest, RecordTrackerInvalidUpdate) { Error::Code::kParameterInvalid); } +TEST_F(MdnsTrackerTest, RecordTrackerUpdatePositiveResponseWithNegative) { + // Check valid update. + std::unique_ptr<MdnsRecordTracker> tracker = + CreateRecordTracker(a_record_, DnsType::kA); + auto result = tracker->Update(nsec_record_); + ASSERT_TRUE(result.is_value()); + EXPECT_EQ(result.value(), MdnsRecordTracker::UpdateType::kRdata); + EXPECT_EQ(GetRecord(tracker.get()), nsec_record_); + + // Check invalid update. + MdnsRecord non_a_nsec_record( + nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(), + nsec_record_.record_type(), nsec_record_.ttl(), + NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kAAAA)); + tracker = CreateRecordTracker(a_record_, DnsType::kA); + auto response = tracker->Update(non_a_nsec_record); + ASSERT_TRUE(response.is_error()); + EXPECT_EQ(GetRecord(tracker.get()), a_record_); +} + +TEST_F(MdnsTrackerTest, RecordTrackerUpdateNegativeResponseWithNegative) { + // Check valid update. + std::unique_ptr<MdnsRecordTracker> tracker = + CreateRecordTracker(nsec_record_, DnsType::kA); + MdnsRecord multiple_nsec_record( + nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(), + nsec_record_.record_type(), nsec_record_.ttl(), + NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA, + DnsType::kAAAA)); + auto result = tracker->Update(multiple_nsec_record); + ASSERT_TRUE(result.is_value()); + EXPECT_EQ(result.value(), MdnsRecordTracker::UpdateType::kRdata); + EXPECT_EQ(GetRecord(tracker.get()), multiple_nsec_record); + + // Check invalid update. + tracker = CreateRecordTracker(nsec_record_, DnsType::kA); + MdnsRecord non_a_nsec_record( + nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(), + nsec_record_.record_type(), nsec_record_.ttl(), + NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kAAAA)); + auto response = tracker->Update(non_a_nsec_record); + EXPECT_TRUE(response.is_error()); + EXPECT_EQ(GetRecord(tracker.get()), nsec_record_); +} + +TEST_F(MdnsTrackerTest, RecordTrackerUpdateNegativeResponseWithPositive) { + // Check valid update. + std::unique_ptr<MdnsRecordTracker> tracker = + CreateRecordTracker(nsec_record_, DnsType::kA); + auto result = tracker->Update(a_record_); + ASSERT_TRUE(result.is_value()); + EXPECT_EQ(result.value(), MdnsRecordTracker::UpdateType::kRdata); + EXPECT_EQ(GetRecord(tracker.get()), a_record_); + + // Check invalid update. + tracker = CreateRecordTracker(nsec_record_, DnsType::kA); + MdnsRecord aaaa_record(a_record_.name(), DnsType::kAAAA, + a_record_.dns_class(), a_record_.record_type(), + std::chrono::seconds{0}, + AAAARecordRdata(IPAddress{0, 0, 0, 0, 0, 0, 0, 1})); + result = tracker->Update(aaaa_record); + EXPECT_TRUE(result.is_error()); + EXPECT_EQ(GetRecord(tracker.get()), nsec_record_); +} + TEST_F(MdnsTrackerTest, RecordTrackerNoExpirationCallbackAfterDestruction) { expiration_called_ = false; std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_); @@ -316,12 +451,16 @@ TEST_F(MdnsTrackerTest, QuestionTrackerQueryAfterDelay) { std::unique_ptr<MdnsQuestionTracker> tracker = CreateQuestionTracker(a_question_); - EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(1); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce( + DoAll(WithArgs<0>(VerifyTruncated(false)), Return(Error::None()))); clock_.Advance(std::chrono::milliseconds(120)); std::chrono::seconds interval{1}; while (interval < std::chrono::hours(1)) { - EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(1); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce( + DoAll(WithArgs<0>(VerifyTruncated(false)), Return(Error::None()))); clock_.Advance(interval); interval *= 2; } @@ -331,9 +470,15 @@ TEST_F(MdnsTrackerTest, QuestionTrackerSendsMessage) { std::unique_ptr<MdnsQuestionTracker> tracker = CreateQuestionTracker(a_question_); - EXPECT_CALL(socket_, SendMessage(_, _, _)) - .WillOnce(WithArgs<0, 1>(VerifyMessageBytesWithoutId( - kQuestionQueryBytes.data(), kQuestionQueryBytes.size()))); + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce(DoAll( + WithArgs<0>(VerifyTruncated(false)), + [this](const MdnsMessage& message) -> Error { + EXPECT_EQ(message.questions().size(), size_t{1}); + EXPECT_EQ(message.questions()[0], a_question_); + return Error::None(); + }, + Return(Error::None()))); clock_.Advance(std::chrono::milliseconds(120)); } @@ -344,5 +489,30 @@ TEST_F(MdnsTrackerTest, QuestionTrackerNoQueryAfterDestruction) { TrackerNoQueryAfterDestruction(std::move(tracker)); } +TEST_F(MdnsTrackerTest, QuestionTrackerSendsMultipleMessages) { + std::unique_ptr<MdnsQuestionTracker> tracker = + CreateQuestionTracker(a_question_); + + std::vector<std::unique_ptr<MdnsRecordTracker>> answers; + for (int i = 0; i < 100; i++) { + auto record = CreateRecordTracker(a_record_); + tracker->AddAssociatedRecord(record.get()); + answers.push_back(std::move(record)); + } + + EXPECT_CALL(sender_, SendMulticast(_)) + .WillOnce(DoAll(WithArgs<0>(VerifyTruncated(true)), + WithArgs<0>(VerifyRecordCount(49)), + Return(Error::None()))) + .WillOnce(DoAll(WithArgs<0>(VerifyTruncated(true)), + WithArgs<0>(VerifyRecordCount(50)), + Return(Error::None()))) + .WillOnce(DoAll(WithArgs<0>(VerifyTruncated(false)), + WithArgs<0>(VerifyRecordCount(1)), + Return(Error::None()))); + + clock_.Advance(std::chrono::milliseconds(120)); +} + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.cc index 58037f7e9a9..fa79dc5d44c 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.cc @@ -6,6 +6,7 @@ #include "absl/hash/hash.h" #include "absl/strings/ascii.h" +#include "util/hashing.h" #include "util/logging.h" namespace openscreen { @@ -14,25 +15,14 @@ namespace discovery { namespace { std::vector<uint64_t> ComputeDomainNameSubhashes(const DomainName& name) { - // Based on absl Hash128to64 that combines two 64-bit hashes into one - auto hash_combiner = [](uint64_t seed, const std::string& value) -> uint64_t { - static const uint64_t kMultiplier = UINT64_C(0x9ddfea08eb382d69); - const uint64_t hash_value = absl::Hash<std::string>{}(value); - uint64_t a = (hash_value ^ seed) * kMultiplier; - a ^= (a >> 47); - uint64_t b = (seed ^ a) * kMultiplier; - b ^= (b >> 47); - b *= kMultiplier; - return b; - }; - const std::vector<std::string>& labels = name.labels(); // Use a large prime between 2^63 and 2^64 as a starting value. // This is taken from absl::Hash implementation. uint64_t hash_value = UINT64_C(0xc3a5c85c97cb3127); std::vector<uint64_t> subhashes(labels.size()); for (size_t i = labels.size(); i-- > 0;) { - hash_value = hash_combiner(hash_value, absl::AsciiStrToLower(labels[i])); + hash_value = + ComputeAggregateHash(hash_value, absl::AsciiStrToLower(labels[i])); subhashes[i] = hash_value; } return subhashes; @@ -201,6 +191,17 @@ bool MdnsWriter::Write(const TxtRecordRdata& rdata) { return true; } +bool MdnsWriter::Write(const NsecRecordRdata& rdata) { + Cursor cursor(this); + if (Skip(sizeof(uint16_t)) && Write(rdata.next_domain_name()) && + Write(rdata.encoded_types()) && + UpdateRecordLength(current(), cursor.origin())) { + cursor.Commit(); + return true; + } + return false; +} + bool MdnsWriter::Write(const MdnsRecord& record) { Cursor cursor(this); if (Write(record.name()) && Write(static_cast<uint16_t>(record.dns_type())) && @@ -229,7 +230,7 @@ bool MdnsWriter::Write(const MdnsMessage& message) { Cursor cursor(this); Header header; header.id = message.id(); - header.flags = MakeFlags(message.type()); + header.flags = MakeFlags(message.type(), message.is_truncated()); header.question_count = message.questions().size(); header.answer_count = message.answers().size(); header.authority_record_count = message.authority_records().size(); diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.h index 66c1b036126..3b8a3b08026 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.h @@ -13,7 +13,7 @@ namespace openscreen { namespace discovery { -class MdnsWriter : public openscreen::BigEndianWriter { +class MdnsWriter : public BigEndianWriter { public: using BigEndianWriter::BigEndianWriter; using BigEndianWriter::Write; @@ -31,6 +31,7 @@ class MdnsWriter : public openscreen::BigEndianWriter { bool Write(const AAAARecordRdata& rdata); bool Write(const PtrRecordRdata& rdata); bool Write(const TxtRecordRdata& rdata); + bool Write(const NsecRecordRdata& rdata); // Writes a DNS resource record with its RDATA. // The correct type of RDATA to be written is contained in the type // specified in the record. @@ -59,7 +60,7 @@ class MdnsWriter : public openscreen::BigEndianWriter { // Domain name compression dictionary. // Maps hashes of previously written domain (sub)names - // to the label pointers of the first occurences in the underlying buffer. + // to the label pointers of the first occurrences in the underlying buffer. // Compression of multiple domain names is supported on the same instance of // the MdnsWriter. Underlying buffer may contain other data in addition to the // domain names. The compression dictionary persists between calls to diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer_unittest.cc index ad0c2464060..03ddbe8e50f 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer_unittest.cc @@ -5,6 +5,7 @@ #include "discovery/mdns/mdns_writer.h" #include <memory> +#include <vector> #include "discovery/mdns/testing/mdns_test_util.h" #include "gmock/gmock.h" @@ -244,15 +245,55 @@ TEST(MdnsWriterTest, WriteAAAARecordRdata) { TEST(MdnsWriterTest, WriteAAAARecordRdata_InsufficientBuffer) { // clang-format off - constexpr uint8_t kAAAARdata[] = { + constexpr uint16_t kAAAARdata[] = { // ADDRESS = FE80:0000:0000:0000:0202:B3FF:FE1E:8329 - 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x02, 0x02, 0xb3, 0xff, 0xfe, 0x1e, 0x83, 0x29, + 0xfe80, 0x0000, 0x0000, 0x0000, + 0x0202, 0xb3ff, 0xfe1e, 0x8329, }; // clang-format on TestWriteEntryInsufficientBuffer(AAAARecordRdata(IPAddress(kAAAARdata))); } +TEST(MdnsWriterTest, WriteNSECRecordRdata) { + const DomainName domain{"testing", "local"}; + NsecRecordRdata(DomainName{"mydevice", "testing", "local"}, DnsType::kA, + DnsType::kTXT, DnsType::kSRV, DnsType::kNSEC); + + // clang-format off + constexpr uint8_t kExpectedRdata[] = { + 0x00, 0x20, // RDLENGTH = 32 + 0x08, 'm', 'y', 'd', 'e', 'v', 'i', 'c', 'e', + 0x07, 't', 'e', 's', 't', 'i', 'n', 'g', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + // It takes 8 bytes to encode the kA and kSRV records because: + // - Both record types have value less than 256, so they are both in window + // block 1. + // - The bitmap length for this block is always a single byte + // - DnsTypes have the following values: + // - kA = 1 (encoded in byte 1) + // kTXT = 16 (encoded in byte 3) + // - kSRV = 33 (encoded in byte 5) + // - kNSEC = 47 (encoded in 6 bytes) + // - The largest of these is 47, so 6 bytes are needed to encode this data. + // So the full encoded version is: + // 00000000 00000110 01000000 00000000 10000000 00000000 0100000 00000001 + // |window| | size | | 0-7 | | 8-15 | |16-23 | |24-31 | |32-39 | |40-47 | + 0x00, 0x06, 0x40, 0x00, 0x80, 0x00, 0x40, 0x01 + }; + // clang-format on + TestWriteEntrySucceeds( + NsecRecordRdata(DomainName{"mydevice", "testing", "local"}, DnsType::kA, + DnsType::kTXT, DnsType::kSRV, DnsType::kNSEC), + kExpectedRdata, sizeof(kExpectedRdata)); +} + +TEST(MdnsWriterTest, WriteNSECRecordRdata_InsufficientBuffer) { + TestWriteEntryInsufficientBuffer( + NsecRecordRdata(DomainName{"mydevice", "testing", "local"}, DnsType::kA, + DnsType::kTXT, DnsType::kSRV, DnsType::kNSEC)); +} + TEST(MdnsWriterTest, WritePtrRecordRdata) { // clang-format off constexpr uint8_t kExpectedRdata[] = { diff --git a/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_constants.h b/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_constants.h index 30847e42643..e80e68268b8 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_constants.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_constants.h @@ -16,6 +16,13 @@ #include <stddef.h> #include <stdint.h> +#include <array> +#include <chrono> + +#include "platform/api/time.h" +#include "platform/base/ip_address.h" +#include "util/logging.h" + namespace openscreen { namespace discovery { @@ -28,18 +35,25 @@ namespace discovery { // RFC 5771: https://www.ietf.org/rfc/rfc5771.txt // RFC 7346: https://www.ietf.org/rfc/rfc7346.txt -// IPv4 group address for joining mDNS multicast group, given as byte array in -// network-order. This is a link-local multicast address, so messages will not -// be forwarded outside local network. See RFC 6762, section 3. -constexpr uint8_t kDefaultMulticastGroupIPv4[4] = {224, 0, 0, 251}; +// Default multicast port used by mDNS protocol. On some systems there may be +// multiple processes binding to same port, so prefer to allow address re-use. +// See RFC 6762, Section 2 +constexpr uint16_t kDefaultMulticastPort = 5353; -// IPv6 group address for joining mDNS multicast group, given as byte array in +// IPv4 group address for joining mDNS multicast group, given as byte array in // network-order. This is a link-local multicast address, so messages will not // be forwarded outside local network. See RFC 6762, section 3. -constexpr uint8_t kDefaultMulticastGroupIPv6[16] = { - 0xFF, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFB, +const IPAddress kDefaultMulticastGroupIPv4{224, 0, 0, 251}; +const IPEndpoint kDefaultMulticastGroupIPv4Endpoint{{}, kDefaultMulticastPort}; + +// IPv6 group address for joining mDNS multicast group. This is a link-local +// multicast address, so messages will not be forwarded outside local network. +// See RFC 6762, section 3. +const IPAddress kDefaultMulticastGroupIPv6{ + 0xFF02, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x00FB, }; +const IPEndpoint kDefaultMulticastGroupIPv6Endpoint{{0, 0, 0, 0, 0, 0, 0, 0}, + kDefaultMulticastPort}; // IPv4 group address for joining cast-specific site-local mDNS multicast group, // given as byte array in network-order. This is a site-local multicast address, @@ -57,7 +71,9 @@ constexpr uint8_t kDefaultMulticastGroupIPv6[16] = { // NOTE: For now the group address is the same group address used for SSDP // discovery, albeit using the MDNS port rather than SSDP port. -constexpr uint8_t kDefaultSiteLocalGroupIPv4[4] = {239, 255, 255, 250}; +const IPAddress kDefaultSiteLocalGroupIPv4{239, 255, 255, 250}; +const IPEndpoint kDefaultSiteLocalGroupIPv4Endpoint{kDefaultSiteLocalGroupIPv4, + kDefaultMulticastPort}; // IPv6 group address for joining cast-specific site-local mDNS multicast group, // give as byte array in network-order. See comments for IPv4 group address for @@ -65,15 +81,11 @@ constexpr uint8_t kDefaultSiteLocalGroupIPv4[4] = {239, 255, 255, 250}; // 0xFF05 is site-local. See RFC 7346. // FF0X:0:0:0:0:0:0:C is variable scope multicast addresses for SSDP. See // https://www.iana.org/assignments/ipv6-multicast-addresses -constexpr uint8_t kDefaultSiteLocalGroupIPv6[16] = { - 0xFF, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0C, +const IPAddress kDefaultSiteLocalGroupIPv6{ + 0xFF05, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x000C, }; - -// Default multicast port used by mDNS protocol. On some systems there may be -// multiple processes binding to same port, so prefer to allow address re-use. -// See RFC 6762, Section 2 -constexpr uint16_t kDefaultMulticastPort = 5353; +const IPEndpoint kDefaultSiteLocalGroupIPv6Endpoint{kDefaultSiteLocalGroupIPv6, + kDefaultMulticastPort}; // Maximum MTU size (1500) minus the UDP header size (8) and IP header size // (20). If any packets are larger than this size, the responder or sender @@ -149,9 +161,18 @@ constexpr MessageType GetMessageType(uint16_t flags) { return (flags & kFlagResponse) ? MessageType::Response : MessageType::Query; } -constexpr uint16_t MakeFlags(MessageType type) { +constexpr bool IsMessageTruncated(uint16_t flags) { + return flags & kFlagTC; +} + +constexpr uint16_t MakeFlags(MessageType type, bool is_truncated) { // RFC 6762 Section 18.2 and Section 18.4 - return (type == MessageType::Response) ? (kFlagResponse | kFlagAA) : 0; + uint16_t flags = + (type == MessageType::Response) ? (kFlagResponse | kFlagAA) : 0; + if (is_truncated) { + flags |= kFlagTC; + } + return flags; } constexpr bool IsValidFlagsSection(uint16_t flags) { @@ -282,9 +303,36 @@ enum class DnsType : uint16_t { kTXT = 16, kAAAA = 28, kSRV = 33, + kNSEC = 47, kANY = 255, // Only allowed for QTYPE }; +inline std::ostream& operator<<(std::ostream& output, DnsType type) { + switch (type) { + case DnsType::kA: + return output << "A"; + case DnsType::kPTR: + return output << "PTR"; + case DnsType::kTXT: + return output << "TXT"; + case DnsType::kAAAA: + return output << "AAAA"; + case DnsType::kSRV: + return output << "SRV"; + case DnsType::kNSEC: + return output << "NSEC"; + case DnsType::kANY: + return output << "ANY"; + } + + OSP_NOTREACHED(); + return output; +} + +constexpr std::array<DnsType, 7> kSupportedDnsTypes = { + DnsType::kA, DnsType::kPTR, DnsType::kTXT, DnsType::kAAAA, + DnsType::kSRV, DnsType::kNSEC, DnsType::kANY}; + enum class DnsClass : uint16_t { kIN = 1, kANY = 255, // Only allowed for QCLASS @@ -305,6 +353,14 @@ enum class ResponseType { kUnicast = 1, }; +// These are the default TTL values for supported DNS Record types as specified +// by RFC 6762 section 10. +constexpr std::chrono::seconds kPtrRecordTtl(120); +constexpr std::chrono::seconds kSrvRecordTtl(120); +constexpr std::chrono::seconds kARecordTtl(120); +constexpr std::chrono::seconds kAAAARecordTtl(120); +constexpr std::chrono::seconds kTXTRecordTtl(120); + // DNS CLASS masks and values. // In mDNS the most significant bit of the RRCLASS for response records is // designated as the "cache-flush bit", as described in @@ -366,6 +422,19 @@ constexpr size_t kTXTMaxEntrySize = 255; // See RFC: https://tools.ietf.org/html/rfc6763#section-6.1 constexpr uint8_t kTXTEmptyRdata = 0; +// ============================================================================ +// Probing Constants +// ============================================================================ + +// RFC 6762 section 8.1 specifies that a probe should wait 250 ms between +// subsequent probe queries. +constexpr Clock::duration kDelayBetweenProbeQueries = + std::chrono::duration_cast<Clock::duration>(std::chrono::milliseconds{250}); + +// RFC 6762 section 8.1 specifies that the probing phase should send out probe +// requests 3 times before treating the probe as completed. +constexpr int kProbeIterationCountBeforeSuccess = 3; + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_service.cc b/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_service.cc new file mode 100644 index 00000000000..89c079b0190 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_service.cc @@ -0,0 +1,15 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "discovery/mdns/public/mdns_service.h" + +namespace openscreen { +namespace discovery { + +MdnsService::MdnsService() = default; + +MdnsService::~MdnsService() = default; + +} // namespace discovery +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_service.h b/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_service.h index 817e57dc87d..b32deb780e7 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_service.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_service.h @@ -5,44 +5,107 @@ #ifndef DISCOVERY_MDNS_PUBLIC_MDNS_SERVICE_H_ #define DISCOVERY_MDNS_PUBLIC_MDNS_SERVICE_H_ +#include <functional> #include <memory> #include "discovery/mdns/public/mdns_constants.h" +#include "platform/base/error.h" +#include "platform/base/interface_info.h" +#include "platform/base/ip_address.h" namespace openscreen { -struct IPEndpoint; class TaskRunner; namespace discovery { +struct Config; class DomainName; +class MdnsDomainConfirmedProvider; class MdnsRecord; class MdnsRecordChangedCallback; +class ReportingClient; class MdnsService { public: - virtual ~MdnsService() = default; + enum SupportedNetworkAddressFamily : uint8_t { + kNoAddressFamily = 0, + kUseIpV4Multicast = 0x01 << 0, + kUseIpV6Multicast = 0x01 << 1 + }; + + MdnsService(); + virtual ~MdnsService(); // Creates a new MdnsService instance, to be owned by the caller. On failure, - // returns nullptr. - static std::unique_ptr<MdnsService> Create(TaskRunner* task_runner); + // returns nullptr. |task_runner|, |reporting_client|, and |config| must exist + // for the duration of the resulting instance's life. + static std::unique_ptr<MdnsService> Create( + TaskRunner* task_runner, + ReportingClient* reporting_client, + const Config& config, + NetworkInterfaceIndex network_interface, + SupportedNetworkAddressFamily supported_address_types); + // Starts an mDNS query with the given properties. Updated records are passed + // to |callback|. The caller must ensure |callback| remains alive while it is + // registered with a query. virtual void StartQuery(const DomainName& name, DnsType dns_type, DnsClass dns_class, MdnsRecordChangedCallback* callback) = 0; + // Stops an mDNS query with the given properties. |callback| must be the same + // callback pointer that was previously passed to StartQuery. virtual void StopQuery(const DomainName& name, DnsType dns_type, DnsClass dns_class, MdnsRecordChangedCallback* callback) = 0; - virtual void RegisterRecord(const MdnsRecord& record) = 0; + // Re-initializes the process of service discovery for the provided domain + // name. All ongoing queries for this domain are restarted and any previously + // received query results are discarded. + virtual void ReinitializeQueries(const DomainName& name) = 0; + + // Starts probing for a valid domain name based on the given one. |callback| + // will be called once a valid domain is found, and the instance must persist + // until that call is received. + virtual Error StartProbe(MdnsDomainConfirmedProvider* callback, + DomainName requested_name, + IPAddress address) = 0; + + // Registers a new mDNS record for advertisement by this service. For A, AAAA, + // SRV, and TXT records, the domain name must have already been claimed by the + // ClaimExclusiveOwnership() method and for PTR records the name being pointed + // to must have been claimed in the same fashion, but the domain name in the + // top-level MdnsRecord entity does not. + virtual Error RegisterRecord(const MdnsRecord& record) = 0; - virtual void DeregisterRecord(const MdnsRecord& record) = 0; + // Updates the existing record with name matching the name of the new record. + // NOTE: This method is not valid for PTR records. + virtual Error UpdateRegisteredRecord(const MdnsRecord& old_record, + const MdnsRecord& new_record) = 0; + + // Stops advertising the provided record. If no more records with the provided + // name are bing advertised after this call's completion, then ownership of + // the name is released. + virtual Error UnregisterRecord(const MdnsRecord& record) = 0; }; +inline MdnsService::SupportedNetworkAddressFamily operator&( + MdnsService::SupportedNetworkAddressFamily lhs, + MdnsService::SupportedNetworkAddressFamily rhs) { + return static_cast<MdnsService::SupportedNetworkAddressFamily>( + static_cast<uint8_t>(lhs) & static_cast<uint8_t>(rhs)); +} + +inline MdnsService::SupportedNetworkAddressFamily operator|( + MdnsService::SupportedNetworkAddressFamily lhs, + MdnsService::SupportedNetworkAddressFamily rhs) { + return static_cast<MdnsService::SupportedNetworkAddressFamily>( + static_cast<uint8_t>(lhs) | static_cast<uint8_t>(rhs)); +} + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/testing/DEPS b/chromium/third_party/openscreen/src/discovery/mdns/testing/DEPS new file mode 100644 index 00000000000..efbb593a6d5 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/mdns/testing/DEPS @@ -0,0 +1,5 @@ +# -*- Mode: Python; -*- + +include_rules = [ + '+discovery/mdns', +] diff --git a/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.cc b/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.cc index 6d99fced114..6b154174a62 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.cc @@ -18,5 +18,37 @@ TxtRecordRdata MakeTxtRecord(std::initializer_list<absl::string_view> strings) { return TxtRecordRdata(std::move(texts)); } +MdnsRecord GetFakePtrRecord(const DomainName& target, + std::chrono::seconds ttl) { + DomainName name(++target.labels().begin(), target.labels().end()); + PtrRecordRdata rdata(target); + return MdnsRecord(std::move(name), DnsType::kPTR, DnsClass::kIN, + RecordType::kShared, ttl, rdata); +} + +MdnsRecord GetFakeSrvRecord(const DomainName& name, std::chrono::seconds ttl) { + SrvRecordRdata rdata(0, 0, 80, name); + return MdnsRecord(name, DnsType::kSRV, DnsClass::kIN, RecordType::kUnique, + ttl, rdata); +} + +MdnsRecord GetFakeTxtRecord(const DomainName& name, std::chrono::seconds ttl) { + TxtRecordRdata rdata; + return MdnsRecord(name, DnsType::kTXT, DnsClass::kIN, RecordType::kUnique, + ttl, rdata); +} + +MdnsRecord GetFakeARecord(const DomainName& name, std::chrono::seconds ttl) { + ARecordRdata rdata(IPAddress(192, 168, 0, 0)); + return MdnsRecord(name, DnsType::kA, DnsClass::kIN, RecordType::kUnique, ttl, + rdata); +} + +MdnsRecord GetFakeAAAARecord(const DomainName& name, std::chrono::seconds ttl) { + AAAARecordRdata rdata(IPAddress(1, 2, 3, 4, 5, 6, 7, 8)); + return MdnsRecord(name, DnsType::kAAAA, DnsClass::kIN, RecordType::kUnique, + ttl, rdata); +} + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.h b/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.h index 0ccf3eeba71..419e1f06d9b 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.h @@ -15,6 +15,19 @@ namespace discovery { TxtRecordRdata MakeTxtRecord(std::initializer_list<absl::string_view> strings); +// Methods to create fake MdnsRecord entities for use in UnitTests. +MdnsRecord GetFakePtrRecord(const DomainName& target, + std::chrono::seconds ttl = std::chrono::seconds(1)); +MdnsRecord GetFakeSrvRecord(const DomainName& name, + std::chrono::seconds ttl = std::chrono::seconds(1)); +MdnsRecord GetFakeTxtRecord(const DomainName& name, + std::chrono::seconds ttl = std::chrono::seconds(1)); +MdnsRecord GetFakeARecord(const DomainName& name, + std::chrono::seconds ttl = std::chrono::seconds(1)); +MdnsRecord GetFakeAAAARecord( + const DomainName& name, + std::chrono::seconds ttl = std::chrono::seconds(1)); + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/public/DEPS b/chromium/third_party/openscreen/src/discovery/public/DEPS new file mode 100644 index 00000000000..eb84cf6bcff --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/public/DEPS @@ -0,0 +1,5 @@ +# -*- Mode: Python; -*- + +include_rules = [ + '+discovery/dnssd/public', +] diff --git a/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_factory.h b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_factory.h new file mode 100644 index 00000000000..e2dcb8361ab --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_factory.h @@ -0,0 +1,28 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef DISCOVERY_PUBLIC_DNS_SD_SERVICE_FACTORY_H_ +#define DISCOVERY_PUBLIC_DNS_SD_SERVICE_FACTORY_H_ + +#include "discovery/dnssd/public/dns_sd_service.h" +#include "util/serial_delete_ptr.h" + +namespace openscreen { + +class TaskRunner; + +namespace discovery { + +struct Config; +class ReportingClient; + +SerialDeletePtr<DnsSdService> CreateDnsSdService( + TaskRunner* task_runner, + ReportingClient* reporting_client, + const Config& config); + +} // namespace discovery +} // namespace openscreen + +#endif // DISCOVERY_PUBLIC_DNS_SD_SERVICE_FACTORY_H_ diff --git a/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_publisher.h b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_publisher.h new file mode 100644 index 00000000000..f94d4e4bc4e --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_publisher.h @@ -0,0 +1,93 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef DISCOVERY_PUBLIC_DNS_SD_SERVICE_PUBLISHER_H_ +#define DISCOVERY_PUBLIC_DNS_SD_SERVICE_PUBLISHER_H_ + +#include <string> + +#include "discovery/dnssd/public/dns_sd_instance_record.h" +#include "discovery/dnssd/public/dns_sd_publisher.h" +#include "discovery/dnssd/public/dns_sd_service.h" +#include "platform/base/error.h" +#include "util/logging.h" + +namespace openscreen { +namespace discovery { + +// This class represents a top-level discovery API which sits on top of DNS-SD. +// The main purpose of this class is to hide DNS-SD internals from embedders who +// do not care about the specific functionality and do not need to understand +// DNS-SD Internals. +// T is the service-specific type which stores information regarding a specific +// service instance. +// NOTE: This class is not thread-safe and calls will be made to DnsSdService in +// the same sequence and on the same threads from which these methods are +// called. This is to avoid forcing design decisions on embedders who write +// their own implementations of the DNS-SD layer. +template <typename T> +class DnsSdServicePublisher : public DnsSdPublisher::Client { + public: + // This function type is responsible for converting from a T type to a + // DNS service instance (to publish to the network). + using ServiceInstanceConverter = std::function<DnsSdInstanceRecord(const T&)>; + + DnsSdServicePublisher(DnsSdService* service, + std::string service_name, + ServiceInstanceConverter conversion) + : conversion_(conversion), + service_name_(std::move(service_name)), + publisher_(service ? service->GetPublisher() : nullptr) { + OSP_DCHECK(publisher_); + } + + ~DnsSdServicePublisher() = default; + + Error Register(const T& instance) { + if (!instance.IsValid()) { + return Error::Code::kParameterInvalid; + } + + DnsSdInstanceRecord record = conversion_(instance); + return publisher_->Register(record, this); + } + + Error UpdateRegistration(const T& instance) { + if (!instance.IsValid()) { + return Error::Code::kParameterInvalid; + } + + DnsSdInstanceRecord record = conversion_(instance); + return publisher_->UpdateRegistration(record); + } + + ErrorOr<int> DeregisterAll() { + return publisher_->DeregisterAll(service_name_); + } + + protected: + // DnsSdPublisher::Client overrides. + // + // Embedders who care about the instance id with which the service was + // published may override this method. + void OnInstanceClaimed(const DnsSdInstanceRecord& requested_record, + const DnsSdInstanceRecord& claimed_record) { + OSP_DVLOG << "Instance ID '" << claimed_record.instance_id() + << "' claimed for requested ID '" + << requested_record.instance_id() << "'"; + OnInstanceClaimed(requested_record.instance_id()); + } + + virtual void OnInstanceClaimed(const std::string& requested_instance_id) {} + + private: + ServiceInstanceConverter conversion_; + std::string service_name_; + DnsSdPublisher* const publisher_; +}; + +} // namespace discovery +} // namespace openscreen + +#endif // DISCOVERY_PUBLIC_DNS_SD_SERVICE_PUBLISHER_H_ diff --git a/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher.h b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher.h new file mode 100644 index 00000000000..56c9a51b321 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher.h @@ -0,0 +1,215 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef DISCOVERY_PUBLIC_DNS_SD_SERVICE_WATCHER_H_ +#define DISCOVERY_PUBLIC_DNS_SD_SERVICE_WATCHER_H_ + +#include <memory> +#include <string> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "discovery/dnssd/public/dns_sd_instance_record.h" +#include "discovery/dnssd/public/dns_sd_querier.h" +#include "discovery/dnssd/public/dns_sd_service.h" +#include "platform/base/error.h" +#include "util/logging.h" + +namespace openscreen { +namespace discovery { +namespace { + +// NOTE: Must be inlined to avoid compilation failure for unused function when +// DLOGs are disabled. +template <typename T> +inline std::string GetInstanceNames( + const std::unordered_map<std::string, T>& map) { + std::string s; + auto it = map.begin(); + if (it == map.end()) { + return s; + } + + s += it->first; + while (++it != map.end()) { + s += ", " + it->first; + } + return s; +} + +} // namespace + +// This class represents a top-level discovery API which sits on top of DNS-SD. +// T is the service-specific type which stores information regarding a specific +// service instance. +// TODO(rwkeane): Include reporting client as ctor parameter once parallel CLs +// are in. +// NOTE: This class is not thread-safe and calls will be made to DnsSdService in +// the same sequence and on the same threads from which these methods are +// called. This is to avoid forcing design decisions on embedders who write +// their own implementations of the DNS-SD layer. +template <typename T> +class DnsSdServiceWatcher : public DnsSdQuerier::Callback { + public: + using ConstRefT = std::reference_wrapper<const T>; + + // The method which will be called when any new service instance is + // discovered, a service instance changes its data (such as TXT or A data), or + // a previously discovered service instance ceases to be available. The vector + // is the set of all currently active service instances which have been + // discovered so far. + using ServicesUpdatedCallback = + std::function<void(std::vector<ConstRefT> services)>; + + // This function type is responsible for converting from a DNS service + // instance (received from another mDNS endpoint) to a T type to be returned + // to the caller. + using ServiceConverter = + std::function<ErrorOr<T>(const DnsSdInstanceRecord&)>; + + DnsSdServiceWatcher(DnsSdService* service, + std::string service_name, + ServiceConverter conversion, + ServicesUpdatedCallback callback) + : conversion_(conversion), + service_name_(std::move(service_name)), + callback_(std::move(callback)), + querier_(service ? service->GetQuerier() : nullptr) { + OSP_DCHECK(querier_); + } + + ~DnsSdServiceWatcher() = default; + + // Starts service discovery. + void StartDiscovery() { + OSP_DCHECK(!is_running_); + is_running_ = true; + + querier_->StartQuery(service_name_, this); + } + + // Stops service discovery. + void StopDiscovery() { + OSP_DCHECK(is_running_); + is_running_ = false; + + querier_->StopQuery(service_name_, this); + } + + // Returns whether or not discovery is currently ongoing. + bool is_running() const { return is_running_; } + + // Re-initializes the process of service discovery, even if the underlying + // implementation would not normally do so at this time. All previously + // received service data is discarded. + // NOTE: This call will return an error if StartDiscovery has not yet been + // called. + Error ForceRefresh() { + if (!is_running_) { + return Error::Code::kOperationInvalid; + } + + querier_->ReinitializeQueries(service_name_); + records_.clear(); + return Error::None(); + } + + // Re-initializes the process of service discovery, even if the underlying + // implementation would not normally do so at this time. All previously + // received service data is persisted. + // NOTE: This call will return an error if StartDiscovery has not yet been + // called. + Error DiscoverNow() { + if (!is_running_) { + return Error::Code::kOperationInvalid; + } + + querier_->ReinitializeQueries(service_name_); + return Error::None(); + } + + // Returns the set of services which have been received so far. + std::vector<ConstRefT> GetServices() const { + std::vector<ConstRefT> refs; + for (const auto& pair : records_) { + refs.push_back(*pair.second.get()); + } + + OSP_DVLOG << "Currently " << records_.size() + << " known service instances: [" << GetInstanceNames(records_) + << "]"; + + return refs; + } + + private: + friend class TestServiceWatcher; + + // DnsSdQuerier::Callback overrides. + void OnInstanceCreated(const DnsSdInstanceRecord& new_record) override { + // NOTE: existence is not checked because records may be overwritten after + // querier_->ReinitializeQueries() is called. + ErrorOr<T> record = conversion_(new_record); + if (record.is_error()) { + OSP_LOG << "Conversion of received record failed with error: " + << record.error(); + return; + } + records_[new_record.instance_id()] = + std::make_unique<T>(std::move(record.value())); + callback_(GetServices()); + } + + void OnInstanceUpdated(const DnsSdInstanceRecord& modified_record) override { + auto it = records_.find(modified_record.instance_id()); + if (it != records_.end()) { + ErrorOr<T> record = conversion_(modified_record); + if (record.is_error()) { + OSP_LOG << "Conversion of received record failed with error: " + << record.error(); + return; + } + auto ptr = std::make_unique<T>(std::move(record.value())); + it->second.swap(ptr); + + callback_(GetServices()); + } else { + OSP_LOG << "Received modified record for non-existent DNS-SD Instance " + << modified_record.instance_id(); + } + } + + void OnInstanceDeleted(const DnsSdInstanceRecord& old_record) override { + if (records_.erase(old_record.instance_id())) { + callback_(GetServices()); + } else { + OSP_LOG << "Received deletion of record for non-existent DNS-SD Instance " + << old_record.instance_id(); + } + } + + // Set of all instance ids found so far, mapped to the T type that it + // represents. unique_ptr<T> entities are used so that the const refs returned + // from GetServices() and the ServicesUpdatedCallback can persist even once + // this map is resized. NOTE: Unordered map is used because this set is in + // many cases expected to be large. + std::unordered_map<std::string, std::unique_ptr<T>> records_; + + // Represents whether discovery is currently running or not. + bool is_running_ = false; + + // Converts from the DNS-SD representation of a service to the outside + // representation. + ServiceConverter conversion_; + + std::string service_name_; + ServicesUpdatedCallback callback_; + DnsSdQuerier* const querier_; +}; + +} // namespace discovery +} // namespace openscreen + +#endif // DISCOVERY_PUBLIC_DNS_SD_SERVICE_WATCHER_H_ diff --git a/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher_unittest.cc b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher_unittest.cc new file mode 100644 index 00000000000..a22558f5cd7 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher_unittest.cc @@ -0,0 +1,335 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "discovery/public/dns_sd_service_watcher.h" + +#include <algorithm> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::_; +using testing::ContainerEq; +using testing::IsSubsetOf; +using testing::IsSupersetOf; +using testing::StrictMock; + +namespace openscreen { +namespace discovery { +namespace { + +std::vector<std::string> ConvertRefs( + const std::vector<std::reference_wrapper<const std::string>>& value) { + std::vector<std::string> strings; + + // This loop is required to unwrap reference_wrapper objects. + for (const std::string& val : value) { + strings.push_back(val); + } + return strings; +} + +static const IPAddress kAddressV4(192, 168, 0, 0); +static const IPEndpoint kEndpointV4{kAddressV4, 0}; +static const std::string kCastServiceId = "_googlecast._tcp"; +static const std::string kCastDomainId = "local"; + +class MockDnsSdService : public DnsSdService { + public: + MockDnsSdService() : querier_(this) {} + + DnsSdQuerier* GetQuerier() override { return &querier_; } + DnsSdPublisher* GetPublisher() override { return nullptr; } + + MOCK_METHOD2(StartQuery, + void(const std::string& service, DnsSdQuerier::Callback* cb)); + MOCK_METHOD2(StopQuery, + void(const std::string& service, DnsSdQuerier::Callback* cb)); + MOCK_METHOD1(ReinitializeQueries, void(const std::string& service)); + + private: + class MockQuerier : public DnsSdQuerier { + public: + explicit MockQuerier(MockDnsSdService* service) : mock_service_(service) { + OSP_DCHECK(service); + } + + void StartQuery(const std::string& service, + DnsSdQuerier::Callback* cb) override { + mock_service_->StartQuery(service, cb); + } + + void StopQuery(const std::string& service, + DnsSdQuerier::Callback* cb) override { + mock_service_->StopQuery(service, cb); + } + + void ReinitializeQueries(const std::string& service) override { + mock_service_->ReinitializeQueries(service); + } + + private: + MockDnsSdService* const mock_service_; + }; + + MockQuerier querier_; +}; + +} // namespace + +class TestServiceWatcher : public DnsSdServiceWatcher<std::string> { + public: + using DnsSdServiceWatcher<std::string>::ConstRefT; + + explicit TestServiceWatcher(MockDnsSdService* service) + : DnsSdServiceWatcher<std::string>( + service, + kCastServiceId, + [this](const DnsSdInstanceRecord& record) { + return Convert(record); + }, + [this](std::vector<ConstRefT> ref) { Callback(std::move(ref)); }) {} + + MOCK_METHOD1(Callback, void(std::vector<ConstRefT>)); + + using DnsSdServiceWatcher<std::string>::OnInstanceCreated; + using DnsSdServiceWatcher<std::string>::OnInstanceUpdated; + using DnsSdServiceWatcher<std::string>::OnInstanceDeleted; + + private: + std::string Convert(const DnsSdInstanceRecord& record) { + return record.instance_id(); + } +}; + +class DnsSdServiceWatcherTests : public testing::Test { + public: + DnsSdServiceWatcherTests() : watcher_(&service_) { + // Start service discovery, since all other tests need it + EXPECT_FALSE(watcher_.is_running()); + EXPECT_CALL(service_, StartQuery(kCastServiceId, _)); + watcher_.StartDiscovery(); + testing::Mock::VerifyAndClearExpectations(&service_); + } + + protected: + void CreateNewInstance(const DnsSdInstanceRecord& record) { + const std::vector<std::string> services_before = + ConvertRefs(watcher_.GetServices()); + const size_t count = services_before.size(); + + std::vector<std::string> callbacked_services; + EXPECT_CALL(watcher_, Callback(_)) + .WillOnce([services = &callbacked_services]( + std::vector<TestServiceWatcher::ConstRefT> value) { + *services = ConvertRefs(value); + }); + watcher_.OnInstanceCreated(record); + testing::Mock::VerifyAndClearExpectations(&watcher_); + + std::vector<std::string> fetched_services = + ConvertRefs(watcher_.GetServices()); + EXPECT_EQ(fetched_services.size(), count + 1); + + EXPECT_THAT(fetched_services, ContainerEq(callbacked_services)); + EXPECT_THAT(fetched_services, IsSupersetOf(services_before)); + } + + void CreateExistingInstance(const DnsSdInstanceRecord& record) { + const std::vector<std::string> services_before = + ConvertRefs(watcher_.GetServices()); + const size_t count = services_before.size(); + + std::vector<std::string> callbacked_services; + EXPECT_CALL(watcher_, Callback(_)) + .WillOnce([services = &callbacked_services]( + std::vector<TestServiceWatcher::ConstRefT> value) { + *services = ConvertRefs(value); + }); + watcher_.OnInstanceCreated(record); + testing::Mock::VerifyAndClearExpectations(&watcher_); + + const std::vector<std::string> fetched_services = + ConvertRefs(watcher_.GetServices()); + EXPECT_EQ(fetched_services.size(), count); + + EXPECT_THAT(fetched_services, ContainerEq(callbacked_services)); + EXPECT_THAT(fetched_services, ContainerEq(services_before)); + } + + void UpdateExistingInstance(const DnsSdInstanceRecord& record) { + const std::vector<std::string> services_before = + ConvertRefs(watcher_.GetServices()); + const size_t count = services_before.size(); + + std::vector<std::string> callbacked_services; + EXPECT_CALL(watcher_, Callback(_)) + .WillOnce([services = &callbacked_services]( + std::vector<TestServiceWatcher::ConstRefT> value) { + *services = ConvertRefs(value); + }); + watcher_.OnInstanceUpdated(record); + testing::Mock::VerifyAndClearExpectations(&watcher_); + + const std::vector<std::string> fetched_services = + ConvertRefs(watcher_.GetServices()); + EXPECT_EQ(fetched_services.size(), count); + + EXPECT_THAT(fetched_services, ContainerEq(callbacked_services)); + EXPECT_THAT(fetched_services, ContainerEq(services_before)); + } + + void DeleteExistingInstance(const DnsSdInstanceRecord& record) { + const std::vector<std::string> services_before = + ConvertRefs(watcher_.GetServices()); + const size_t count = services_before.size(); + + std::vector<std::string> callbacked_services; + EXPECT_CALL(watcher_, Callback(_)) + .WillOnce([services = &callbacked_services]( + std::vector<TestServiceWatcher::ConstRefT> value) { + *services = ConvertRefs(value); + }); + watcher_.OnInstanceDeleted(record); + testing::Mock::VerifyAndClearExpectations(&watcher_); + + const std::vector<std::string> fetched_services = + ConvertRefs(watcher_.GetServices()); + EXPECT_EQ(fetched_services.size(), count - 1); + } + + void UpdateNonExistingInstance(const DnsSdInstanceRecord& record) { + const std::vector<std::string> services_before = + ConvertRefs(watcher_.GetServices()); + const size_t count = services_before.size(); + + EXPECT_CALL(watcher_, Callback(_)).Times(0); + watcher_.OnInstanceUpdated(record); + testing::Mock::VerifyAndClearExpectations(&watcher_); + + const std::vector<std::string> fetched_services = + ConvertRefs(watcher_.GetServices()); + EXPECT_EQ(fetched_services.size(), count); + + EXPECT_THAT(services_before, ContainerEq(fetched_services)); + } + + void DeleteNonExistingInstance(const DnsSdInstanceRecord& record) { + const std::vector<std::string> services_before = + ConvertRefs(watcher_.GetServices()); + const size_t count = services_before.size(); + + EXPECT_CALL(watcher_, Callback(_)).Times(0); + watcher_.OnInstanceDeleted(record); + testing::Mock::VerifyAndClearExpectations(&watcher_); + + const std::vector<std::string> fetched_services = + ConvertRefs(watcher_.GetServices()); + EXPECT_EQ(fetched_services.size(), count); + + EXPECT_THAT(services_before, ContainerEq(fetched_services)); + } + + bool ContainsService(const DnsSdInstanceRecord& record) { + const std::string& service = record.instance_id(); + const std::vector<TestServiceWatcher::ConstRefT> services = + watcher_.GetServices(); + return std::find_if(services.begin(), services.end(), + [&service](const std::string& ref) { + return service == ref; + }) != services.end(); + } + + StrictMock<MockDnsSdService> service_; + StrictMock<TestServiceWatcher> watcher_; + std::vector<std::string> fetched_services; +}; + +TEST_F(DnsSdServiceWatcherTests, StartStopDiscoveryWorks) { + EXPECT_TRUE(watcher_.is_running()); + EXPECT_CALL(service_, StopQuery(kCastServiceId, _)); + watcher_.StopDiscovery(); + EXPECT_FALSE(watcher_.is_running()); +} + +TEST(DnsSdServiceWatcherTest, RefreshFailsBeforeDiscoveryStarts) { + StrictMock<MockDnsSdService> service; + StrictMock<TestServiceWatcher> watcher(&service); + EXPECT_FALSE(watcher.DiscoverNow().ok()); + EXPECT_FALSE(watcher.ForceRefresh().ok()); +} + +TEST_F(DnsSdServiceWatcherTests, RefreshDiscoveryWorks) { + const DnsSdInstanceRecord record("Instance", kCastServiceId, kCastDomainId, + kEndpointV4, DnsSdTxtRecord{}); + CreateNewInstance(record); + + // Refresh services. + EXPECT_CALL(service_, ReinitializeQueries(kCastServiceId)); + EXPECT_TRUE(watcher_.DiscoverNow().ok()); + EXPECT_EQ(watcher_.GetServices().size(), size_t{1}); + testing::Mock::VerifyAndClearExpectations(&service_); + + EXPECT_CALL(service_, ReinitializeQueries(kCastServiceId)); + EXPECT_TRUE(watcher_.ForceRefresh().ok()); + EXPECT_EQ(watcher_.GetServices().size(), size_t{0}); + testing::Mock::VerifyAndClearExpectations(&service_); +} + +TEST_F(DnsSdServiceWatcherTests, CreatingUpdatingDeletingInstancesWork) { + const DnsSdInstanceRecord record("Instance", kCastServiceId, kCastDomainId, + kEndpointV4, DnsSdTxtRecord{}); + const DnsSdInstanceRecord record2("Instance2", kCastServiceId, kCastDomainId, + kEndpointV4, DnsSdTxtRecord{}); + + EXPECT_FALSE(ContainsService(record)); + EXPECT_FALSE(ContainsService(record2)); + + CreateNewInstance(record); + EXPECT_TRUE(ContainsService(record)); + EXPECT_FALSE(ContainsService(record2)); + + CreateExistingInstance(record); + EXPECT_TRUE(ContainsService(record)); + EXPECT_FALSE(ContainsService(record2)); + + UpdateNonExistingInstance(record2); + EXPECT_TRUE(ContainsService(record)); + EXPECT_FALSE(ContainsService(record2)); + + DeleteNonExistingInstance(record2); + EXPECT_TRUE(ContainsService(record)); + EXPECT_FALSE(ContainsService(record2)); + + CreateNewInstance(record2); + EXPECT_TRUE(ContainsService(record)); + EXPECT_TRUE(ContainsService(record2)); + + UpdateExistingInstance(record2); + EXPECT_TRUE(ContainsService(record)); + EXPECT_TRUE(ContainsService(record2)); + + UpdateExistingInstance(record); + EXPECT_TRUE(ContainsService(record)); + EXPECT_TRUE(ContainsService(record2)); + + DeleteExistingInstance(record); + EXPECT_FALSE(ContainsService(record)); + EXPECT_TRUE(ContainsService(record2)); + + UpdateNonExistingInstance(record); + EXPECT_FALSE(ContainsService(record)); + EXPECT_TRUE(ContainsService(record2)); + + DeleteNonExistingInstance(record); + EXPECT_FALSE(ContainsService(record)); + EXPECT_TRUE(ContainsService(record2)); + + DeleteExistingInstance(record2); + EXPECT_FALSE(ContainsService(record)); + EXPECT_FALSE(ContainsService(record2)); +} + +} // namespace discovery +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/docs/style_guide.md b/chromium/third_party/openscreen/src/docs/style_guide.md index 85e550e7af2..08a807c8ae0 100644 --- a/chromium/third_party/openscreen/src/docs/style_guide.md +++ b/chromium/third_party/openscreen/src/docs/style_guide.md @@ -9,10 +9,23 @@ C++14 language and library features are allowed in the Open Screen Library according to the [C++14 use in Chromium](https://chromium-cpp.appspot.com#core-whitelist) guidelines. +## Modifications to the Chromium C++ Guidelines + +- `<functional>` and `std::function` objects are allowed. +- `<chrono>` is allowed and encouraged for representation of time. +- Abseil types are allowed based on the whitelist in [DEPS](https://chromium.googlesource.com/openscreen/+/refs/heads/master/DEPS). +- **Do not** use Abseil types in public APIs. +- `<thread>` and `<mutex>` are allowed, but discouraged from general use as the + library only needs to handle threading in very specific places; + see [threading.md](threading.md). + ## Open Screen Library Features - For public API functions that return values or errors, please return - [`ErrorOr<T>`](https://chromium.googlesource.com/openscreen/+/master/base/error.h). + [`ErrorOr<T>`](https://chromium.googlesource.com/openscreen/+/master/platform/base/error.h). +- In the implementation of public APIs invoked by the embedder, use + `OSP_DCHECK(TaskRunner::IsRunningOnTaskRunner())` to catch thread safety + problems early. ## Style Addenda diff --git a/chromium/third_party/openscreen/src/docs/threading.md b/chromium/third_party/openscreen/src/docs/threading.md new file mode 100644 index 00000000000..78a55b2d7d8 --- /dev/null +++ b/chromium/third_party/openscreen/src/docs/threading.md @@ -0,0 +1,21 @@ +# Threading + +The Open Screen Library is **single-threaded**; all of its code is intended to be +run on a single sequence, with a few exceptions noted below. + +A library client **must** invoke all library APIs on the same sequence that is +used to run tasks on the client's +[TaskRunner implementation](https://chromium.googlesource.com/openscreen/+/refs/heads/master/platform/api/task_runner.h). + +## Exceptions + +* The [trace logging](trace_logging.md) framework is thread-safe. +* The TaskRunner itself is thread-safe. +* The [POSIX platform implementation](https://chromium.googlesource.com/openscreen/+/refs/heads/master/platform/impl/) + starts a network thread, and handles interactions between that thread and the + TaskRunner internally. + + + + + diff --git a/chromium/third_party/openscreen/src/docs/trace_logging.md b/chromium/third_party/openscreen/src/docs/trace_logging.md index 0139b3f2348..8bb60915656 100644 --- a/chromium/third_party/openscreen/src/docs/trace_logging.md +++ b/chromium/third_party/openscreen/src/docs/trace_logging.md @@ -83,31 +83,27 @@ call. As with scoped traces, the result must be some Error::Code enum value. ## Tracing Functions All of the below functions rely on the Platform Layer's IsTraceLoggingEnabled() function. When logging is disabled, either for the specific category of trace -logging which the Macro specifies or for TraceCategory::Any in all other caes, +logging which the Macro specifies or for TraceCategory::Any in all other cases, the below functions will be treated as a NoOp. ### Synchronous Tracing - `TRACE_SCOPED(category, name)` - If logging is enabled for the provided category, this function will trace - the current function until the current scope ends with name as provided. - When this call is used, the Trace ID Hierarchy will be determined - automatically and the caller does not need to worry about it and, as such, - **this call should be used in the majority of synchronous tracing cases**. - - `TRACE_SCOPED(category, name, traceId, parentId, rootId)` - If logging is enabled for the provided category, this function will trace - the current function until the current scope ends with name as provided. The - Trace ID used for tracing this function will be set to the one provided, as - will the parent and root ids. Each of Trace ID, Parent ID, and Root ID is - optional, so providing only a subset of these values is also valid if the - caller only desires to set specific ones. - - `TRACE_SCOPED(category, name, traceIdHierarchy)` - This call is intended for use in conjunction with the TRACE_HIERARCHY macro - (as described below). If logging is enabled for the provided category, this - function will trace the current function until the current scope ends with - name as provided. The Trace ID Hierarchy will be set as provided in the - provided Trace ID Hierarchy parameter. + ```c++ + TRACE_SCOPED(category, name) + TRACE_SCOPED(category, name, traceId, parentId, rootId) + TRACE_SCOPED(category, name, traceIdHierarchy) + ``` + If logging is enabled for the provided |category|, trace the current scope. The scope + should encompass the operation described by |name|. The latter two uses of this macro are + for manually providing the trace ID hierarchy; the first auto-generates a new trace ID for + this scope and sets its parent trace ID to that of the encompassing scope (if any). + + ```c++ + TRACE_DEFAULT_SCOPED(category) + TRACE_DEFAULT_SCOPED(category, traceId, parentId, rootId) + TRACE_DEFAULT_SCOPED(category, traceIdHierarchy) + ``` + Same as TRACE_SCOPED(), but use the current function/method signature as the operation + name. ### Asynchronous Tracing `TRACE_ASYNC_START(category, name)` @@ -229,7 +225,7 @@ For an embedder to create a custom TraceLogging implementation: called frequently (especially `IsLoggingEnabled(TraceCategory)`) and are often in the critical execution path of the library's code. -2. *Call `openscreen::platform::StartTracing()` and `StopTracing()`* +2. *Call `openscreen::StartTracing()` and `StopTracing()`* These activate/deactivate tracing by providing the TraceLoggingPlatform instance and later clearing references to it. diff --git a/chromium/third_party/openscreen/src/infra/config/global/commit-queue.cfg b/chromium/third_party/openscreen/src/infra/config/global/commit-queue.cfg index 7f3792bd8ad..d51fcc00bee 100644 --- a/chromium/third_party/openscreen/src/infra/config/global/commit-queue.cfg +++ b/chromium/third_party/openscreen/src/infra/config/global/commit-queue.cfg @@ -32,21 +32,20 @@ config_groups { name: "openscreen/try/linux64_tsan" } builders { + name: "openscreen/try/linux_arm64_debug" + } + builders { name: "openscreen/try/mac_debug" } builders { name: "openscreen/try/openscreen_presubmit" } - # Chromium bots are declared "experimental" but at 100% so they always run - # but aren't considered part of the commit queue pass/fail. builders { name: "openscreen/try/chromium_linux64_debug" - experiment_percentage: 100 } builders { name: "openscreen/try/chromium_mac_debug" - experiment_percentage: 100 } retry_config { diff --git a/chromium/third_party/openscreen/src/infra/config/global/cr-buildbucket.cfg b/chromium/third_party/openscreen/src/infra/config/global/cr-buildbucket.cfg index 0b35cbf4365..447571db82f 100644 --- a/chromium/third_party/openscreen/src/infra/config/global/cr-buildbucket.cfg +++ b/chromium/third_party/openscreen/src/infra/config/global/cr-buildbucket.cfg @@ -74,7 +74,22 @@ builder_mixins { builder_mixins { name: "mac" + + # NOTE: The OS version here will determine which version of XCode is being + # used. Relevant links; so you and I never have to spend hours finding this + # stuff all over again to fix things like https://crbug.com/openscreen/86: + # + # 1. The recipe code that uses the "osx_sdk" recipe module: + # + # https://cs.chromium.org/chromium/build/scripts/slave/recipes + # /openscreen.py?rcl=671f9f1c5f5bef81d0a39973aa8729cc83bb290e&l=74 + # + # 2. The XCode version look-up table in the "osx_sdk" recipe module: + # + # https://cs.chromium.org/chromium/tools/depot_tools/recipes/recipe_modules + # /osx_sdk/api.py?rcl=fe18a43d590a5eac0d58e7e555b024746ba290ad&l=26 dimensions: "os:Mac-10.13" + caches: { # Cache for mac_toolchain tool and XCode.app used in recipes. name: "osx_sdk" @@ -91,6 +106,14 @@ builder_mixins { } builder_mixins { + name: "arm64" + dimensions: "cpu:x86-64" + recipe { + properties: "target_cpu:arm64" + } +} + +builder_mixins { name: "chromium" recipe: { name: "chromium" @@ -137,6 +160,13 @@ buckets { } builders { + name: "linux_arm64_debug" + mixins: "linux" + mixins: "arm64" + mixins: "debug" + } + + builders { name: "mac_debug" mixins: "mac" mixins: "debug" @@ -201,6 +231,13 @@ buckets: { } builders { + name: "linux_arm64_debug" + mixins: "linux" + mixins: "arm64" + mixins: "debug" + } + + builders { name: "mac_debug" mixins: "mac" mixins: "debug" @@ -212,6 +249,7 @@ buckets: { recipe { name: "run_presubmit" properties: "repo_name:openscreen" + properties: "runhooks:true" } mixins: "linux" mixins: "x64" diff --git a/chromium/third_party/openscreen/src/infra/config/global/luci-milo.cfg b/chromium/third_party/openscreen/src/infra/config/global/luci-milo.cfg index 458a1c80c79..b50a49ce08a 100644 --- a/chromium/third_party/openscreen/src/infra/config/global/luci-milo.cfg +++ b/chromium/third_party/openscreen/src/infra/config/global/luci-milo.cfg @@ -26,6 +26,12 @@ consoles { } builders { + name: "buildbucket/luci.openscreen.ci/linux_arm64_debug" + category: "linux|arm64" + short_name: "arm64" + } + + builders { name: "buildbucket/luci.openscreen.ci/mac_debug" category: "mac" short_name: "dbg" @@ -70,8 +76,26 @@ consoles { } builders { + name: "buildbucket/luci.openscreen.ci/linux_arm64_debug" + category: "linux|arm64" + short_name: "arm64" + } + + builders { name: "buildbucket/luci.openscreen.try/mac_debug" category: "mac" short_name: "dbg" } + + builders { + name: "buildbucket/luci.openscreen.try/chromium_linux64_debug" + category: "chromium fyi" + short_name: "linux64" + } + + builders { + name: "buildbucket/luci.openscreen.try/chromium_mac_debug" + category: "chromium fyi" + short_name: "mac" + } } diff --git a/chromium/third_party/openscreen/src/infra/config/global/luci-scheduler.cfg b/chromium/third_party/openscreen/src/infra/config/global/luci-scheduler.cfg index 4c2822dc4c3..39efee3f7b7 100644 --- a/chromium/third_party/openscreen/src/infra/config/global/luci-scheduler.cfg +++ b/chromium/third_party/openscreen/src/infra/config/global/luci-scheduler.cfg @@ -25,6 +25,7 @@ trigger { triggers: "linux64_debug" triggers: "linux64_gcc_debug" triggers: "linux64_tsan" + triggers: "linux_arm64_debug" triggers: "mac_debug" } @@ -71,6 +72,16 @@ job { } job { + id: "linux_arm64_debug" + acl_sets: "default" + buildbucket: { + server: "cr-buildbucket.appspot.com" + bucket: "luci.openscreen.ci" + builder: "linux64_arm64_debug" + } +} + +job { id: "mac_debug" acl_sets: "default" buildbucket: { diff --git a/chromium/third_party/openscreen/src/osp/demo/osp_demo.cc b/chromium/third_party/openscreen/src/osp/demo/osp_demo.cc index 37c6bb432fa..daa5fa0c596 100644 --- a/chromium/third_party/openscreen/src/osp/demo/osp_demo.cc +++ b/chromium/third_party/openscreen/src/osp/demo/osp_demo.cc @@ -36,9 +36,6 @@ #include "third_party/tinycbor/src/src/cbor.h" #include "util/trace_logging.h" -using openscreen::platform::Clock; -using openscreen::platform::PlatformClientPosix; - namespace { const char* kReceiverLogFilename = "_recv_fifo"; @@ -431,8 +428,7 @@ void ListenerDemo() { listener_config, &listener_observer, PlatformClientPosix::GetInstance()->GetTaskRunner()); - MessageDemuxer demuxer(platform::Clock::now, - MessageDemuxer::kDefaultBufferLimit); + MessageDemuxer demuxer(Clock::now, MessageDemuxer::kDefaultBufferLimit); DemoConnectionClientObserver client_observer; auto connection_client = ProtocolConnectionClientFactory::Create( &demuxer, &client_observer, @@ -440,7 +436,7 @@ void ListenerDemo() { auto* network_service = NetworkServiceManager::Create( std::move(mdns_listener), nullptr, std::move(connection_client), nullptr); - auto controller = std::make_unique<Controller>(platform::Clock::now); + auto controller = std::make_unique<Controller>(Clock::now); network_service->GetMdnsServiceListener()->Start(); network_service->GetProtocolConnectionClient()->Start(); @@ -524,8 +520,7 @@ void PublisherDemo(absl::string_view friendly_name) { PlatformClientPosix::GetInstance()->GetTaskRunner()); ServerConfig server_config; - for (const platform::InterfaceInfo& interface : - platform::GetNetworkInterfaces()) { + for (const InterfaceInfo& interface : GetNetworkInterfaces()) { OSP_VLOG << "Found interface: " << interface; if (!interface.addresses.empty()) { server_config.connection_endpoints.push_back( @@ -535,8 +530,7 @@ void PublisherDemo(absl::string_view friendly_name) { OSP_LOG_IF(WARN, server_config.connection_endpoints.empty()) << "No network interfaces had usable addresses for mDNS publishing."; - MessageDemuxer demuxer(platform::Clock::now, - MessageDemuxer::kDefaultBufferLimit); + MessageDemuxer demuxer(Clock::now, MessageDemuxer::kDefaultBufferLimit); DemoConnectionServerObserver server_observer; auto connection_server = ProtocolConnectionServerFactory::Create( server_config, &demuxer, &server_observer, @@ -588,7 +582,9 @@ InputArgs GetInputArgs(int argc, char** argv) { } int main(int argc, char** argv) { - using openscreen::platform::LogLevel; + using openscreen::Clock; + using openscreen::LogLevel; + using openscreen::PlatformClientPosix; std::cout << "Usage: osp_demo [-v] [friendly_name]" << std::endl << "-v: enable more verbose logging" << std::endl @@ -602,11 +598,11 @@ int main(int argc, char** argv) { const bool is_receiver_demo = !args.friendly_server_name.empty(); const char* log_filename = is_receiver_demo ? kReceiverLogFilename : kControllerLogFilename; - openscreen::platform::SetLogFifoOrDie(log_filename); + openscreen::SetLogFifoOrDie(log_filename); LogLevel level = args.is_verbose ? LogLevel::kVerbose : LogLevel::kInfo; - openscreen::platform::SetLogLevel(level); - openscreen::platform::TextTraceLoggingPlatform text_logging_platform; + openscreen::SetLogLevel(level); + openscreen::TextTraceLoggingPlatform text_logging_platform; PlatformClientPosix::Create(Clock::duration{50}, Clock::duration{50}); diff --git a/chromium/third_party/openscreen/src/osp/impl/BUILD.gn b/chromium/third_party/openscreen/src/osp/impl/BUILD.gn index b446aef9bd3..779300bb483 100644 --- a/chromium/third_party/openscreen/src/osp/impl/BUILD.gn +++ b/chromium/third_party/openscreen/src/osp/impl/BUILD.gn @@ -43,6 +43,7 @@ source_set("impl") { ] deps = [ "../../platform", + "../../third_party/abseil", "../../util", "quic", ] @@ -64,6 +65,7 @@ if (use_chromium_quic) { deps = [ "../../osp/msgs", "../../platform", + "../../third_party/abseil", "../../third_party/chromium_quic", "../../util", "quic", diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/BUILD.gn b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/BUILD.gn index 4779f5264c0..cd051d1750d 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/BUILD.gn +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/BUILD.gn @@ -15,6 +15,7 @@ source_set("mdns_interface") { public_deps = [ "../../../../platform", + "../../../../third_party/abseil", "../../../../util", ] } diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_demo.cc b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_demo.cc index 691658e0d8a..a6c0e508165 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_demo.cc +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_demo.cc @@ -37,9 +37,6 @@ // shouldn't be expected to be a source of truth, nor should it be expected to // be correct after running for a long time. -using openscreen::platform::Clock; -using openscreen::platform::PlatformClientPosix; - namespace openscreen { namespace osp { namespace { @@ -59,22 +56,21 @@ struct Service { std::vector<std::string> txt; }; -class DemoSocketClient : public platform::UdpSocket::Client { +class DemoSocketClient : public UdpSocket::Client { public: DemoSocketClient(MdnsResponderAdapterImpl* mdns) : mdns_(mdns) {} - void OnError(platform::UdpSocket* socket, Error error) override { + void OnError(UdpSocket* socket, Error error) override { // TODO(crbug.com/openscreen/66): Change to OSP_LOG_FATAL. OSP_LOG_ERROR << "configuration failed for interface " << error.message(); OSP_CHECK(false); } - void OnSendError(platform::UdpSocket* socket, Error error) override { + void OnSendError(UdpSocket* socket, Error error) override { OSP_UNIMPLEMENTED(); } - void OnRead(platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> packet) override { + void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override { mdns_->OnRead(socket, std::move(packet)); } @@ -128,20 +124,20 @@ void SignalThings() { OSP_LOG << "signal handlers setup" << std::endl << "pid: " << getpid(); } -std::vector<platform::UdpSocketUniquePtr> SetUpMulticastSockets( - platform::TaskRunner* task_runner, - const std::vector<platform::NetworkInterfaceIndex>& index_list, - platform::UdpSocket::Client* client) { - std::vector<platform::UdpSocketUniquePtr> sockets; +std::vector<std::unique_ptr<UdpSocket>> SetUpMulticastSockets( + TaskRunner* task_runner, + const std::vector<NetworkInterfaceIndex>& index_list, + UdpSocket::Client* client) { + std::vector<std::unique_ptr<UdpSocket>> sockets; for (const auto ifindex : index_list) { auto create_result = - platform::UdpSocket::Create(task_runner, client, IPEndpoint{{}, 5353}); + UdpSocket::Create(task_runner, client, IPEndpoint{{}, 5353}); if (!create_result) { OSP_LOG_ERROR << "failed to create IPv4 socket for interface " << ifindex << ": " << create_result.error().message(); continue; } - platform::UdpSocketUniquePtr socket = std::move(create_result.value()); + std::unique_ptr<UdpSocket> socket = std::move(create_result.value()); socket->JoinMulticastGroup(IPAddress{224, 0, 0, 251}, ifindex); socket->SetMulticastOutboundInterface(ifindex); @@ -250,7 +246,7 @@ void HandleEvents(MdnsResponderAdapterImpl* mdns_adapter) { } } -void BrowseDemo(platform::TaskRunner* task_runner, +void BrowseDemo(TaskRunner* task_runner, const std::string& service_name, const std::string& service_protocol, const std::string& service_instance) { @@ -269,9 +265,8 @@ void BrowseDemo(platform::TaskRunner* task_runner, auto mdns_adapter = std::make_unique<MdnsResponderAdapterImpl>(); mdns_adapter->Init(); mdns_adapter->SetHostLabel("gigliorononomicon"); - const std::vector<platform::InterfaceInfo> interfaces = - platform::GetNetworkInterfaces(); - std::vector<platform::NetworkInterfaceIndex> index_list; + const std::vector<InterfaceInfo> interfaces = GetNetworkInterfaces(); + std::vector<NetworkInterfaceIndex> index_list; for (const auto& interface : interfaces) { OSP_LOG << "Found interface: " << interface; if (!interface.addresses.empty()) { @@ -290,10 +285,10 @@ void BrowseDemo(platform::TaskRunner* task_runner, // Listen on all interfaces. auto socket_it = sockets.begin(); - for (platform::NetworkInterfaceIndex index : index_list) { + for (NetworkInterfaceIndex index : index_list) { const auto& interface = *std::find_if(interfaces.begin(), interfaces.end(), - [index](const openscreen::platform::InterfaceInfo& info) { + [index](const openscreen::InterfaceInfo& info) { return info.index == index; }); // Pick any address for the given interface. @@ -308,7 +303,7 @@ void BrowseDemo(platform::TaskRunner* task_runner, {{"k1", "yurtle"}, {"k2", "turtle"}}); } - for (const platform::UdpSocketUniquePtr& socket : sockets) { + for (const std::unique_ptr<UdpSocket>& socket : sockets) { mdns_adapter->StartPtrQuery(socket.get(), service_type.value()); } @@ -332,7 +327,7 @@ void BrowseDemo(platform::TaskRunner* task_runner, for (const auto& s : *g_services) { LogService(s.second); } - for (const platform::UdpSocketUniquePtr& socket : sockets) { + for (const std::unique_ptr<UdpSocket>& socket : sockets) { mdns_adapter->DeregisterInterface(socket.get()); } mdns_adapter->Close(); @@ -343,7 +338,10 @@ void BrowseDemo(platform::TaskRunner* task_runner, } // namespace openscreen int main(int argc, char** argv) { - openscreen::platform::SetLogLevel(openscreen::platform::LogLevel::kVerbose); + using openscreen::Clock; + using openscreen::PlatformClientPosix; + + openscreen::SetLogLevel(openscreen::LogLevel::kVerbose); std::string service_instance; std::string service_type("_openscreen._udp"); diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.cc b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.cc index 810c7bbdfce..edf25503dd1 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.cc +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.cc @@ -9,7 +9,7 @@ namespace osp { QueryEventHeader::QueryEventHeader() = default; QueryEventHeader::QueryEventHeader(QueryEventHeader::Type response_type, - platform::UdpSocket* socket) + UdpSocket* socket) : response_type(response_type), socket(socket) {} QueryEventHeader::QueryEventHeader(QueryEventHeader&&) noexcept = default; QueryEventHeader::~QueryEventHeader() = default; diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.h b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.h index 376a9899fa4..66083d57a72 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.h +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.h @@ -29,13 +29,13 @@ struct QueryEventHeader { }; QueryEventHeader(); - QueryEventHeader(Type response_type, platform::UdpSocket* socket); + QueryEventHeader(Type response_type, UdpSocket* socket); QueryEventHeader(QueryEventHeader&&) noexcept; ~QueryEventHeader(); QueryEventHeader& operator=(QueryEventHeader&&) noexcept; Type response_type; - platform::UdpSocket* socket; + UdpSocket* socket; }; struct PtrEvent { @@ -160,7 +160,7 @@ enum class MdnsResponderErrorCode { // called after any sequence of calls to mDNSResponder. It also returns a // timeout value, after which it must be called again (e.g. for maintaining its // cache). -class MdnsResponderAdapter : public platform::UdpSocket::Client { +class MdnsResponderAdapter : public UdpSocket::Client { public: MdnsResponderAdapter(); virtual ~MdnsResponderAdapter() = 0; @@ -184,14 +184,14 @@ class MdnsResponderAdapter : public platform::UdpSocket::Client { // mDNSResponder. |socket| will be used to identify which interface received // the data in OnDataReceived and will be used to send data via the platform // layer. - virtual Error RegisterInterface(const platform::InterfaceInfo& interface_info, - const platform::IPSubnet& interface_address, - platform::UdpSocket* socket) = 0; - virtual Error DeregisterInterface(platform::UdpSocket* socket) = 0; + virtual Error RegisterInterface(const InterfaceInfo& interface_info, + const IPSubnet& interface_address, + UdpSocket* socket) = 0; + virtual Error DeregisterInterface(UdpSocket* socket) = 0; // Returns the time period after which this method must be called again, if // any. - virtual platform::Clock::duration RunTasks() = 0; + virtual Clock::duration RunTasks() = 0; virtual std::vector<PtrEvent> TakePtrResponses() = 0; virtual std::vector<SrvEvent> TakeSrvResponses() = 0; @@ -200,33 +200,33 @@ class MdnsResponderAdapter : public platform::UdpSocket::Client { virtual std::vector<AaaaEvent> TakeAaaaResponses() = 0; virtual MdnsResponderErrorCode StartPtrQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_type) = 0; virtual MdnsResponderErrorCode StartSrvQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) = 0; virtual MdnsResponderErrorCode StartTxtQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) = 0; - virtual MdnsResponderErrorCode StartAQuery(platform::UdpSocket* socket, + virtual MdnsResponderErrorCode StartAQuery(UdpSocket* socket, const DomainName& domain_name) = 0; virtual MdnsResponderErrorCode StartAaaaQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& domain_name) = 0; virtual MdnsResponderErrorCode StopPtrQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_type) = 0; virtual MdnsResponderErrorCode StopSrvQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) = 0; virtual MdnsResponderErrorCode StopTxtQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) = 0; - virtual MdnsResponderErrorCode StopAQuery(platform::UdpSocket* socket, + virtual MdnsResponderErrorCode StopAQuery(UdpSocket* socket, const DomainName& domain_name) = 0; virtual MdnsResponderErrorCode StopAaaaQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& domain_name) = 0; // The following methods concern advertising a service via mDNS. The diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc index 6b986ded342..3feb8e7aaa5 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc @@ -13,8 +13,6 @@ #include "util/logging.h" #include "util/trace_logging.h" -using openscreen::platform::TraceCategory; - namespace openscreen { namespace osp { namespace { @@ -206,7 +204,7 @@ Error MdnsResponderAdapterImpl::Init() { } void MdnsResponderAdapterImpl::Close() { - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::Close"); + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::Close"); mDNS_StartExit(&mdns_); // Let all services send goodbyes. while (!service_records_.empty()) { @@ -228,7 +226,7 @@ void MdnsResponderAdapterImpl::Close() { } Error MdnsResponderAdapterImpl::SetHostLabel(const std::string& host_label) { - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::SetHostLabel"); + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::SetHostLabel"); if (host_label.size() > DomainName::kDomainNameMaxLabelLength) return Error::Code::kDomainNameTooLong; @@ -242,10 +240,10 @@ Error MdnsResponderAdapterImpl::SetHostLabel(const std::string& host_label) { } Error MdnsResponderAdapterImpl::RegisterInterface( - const platform::InterfaceInfo& interface_info, - const platform::IPSubnet& interface_address, - platform::UdpSocket* socket) { - TRACE_SCOPED(TraceCategory::mDNS, + const InterfaceInfo& interface_info, + const IPSubnet& interface_address, + UdpSocket* socket) { + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::RegisterInterface"); OSP_DCHECK(socket); @@ -272,8 +270,9 @@ Error MdnsResponderAdapterImpl::RegisterInterface( } static_assert(sizeof(info.MAC.b) == sizeof(interface_info.hardware_address), - "MAC addresss size mismatch."); - memcpy(info.MAC.b, interface_info.hardware_address, sizeof(info.MAC.b)); + "MAC address size mismatch."); + memcpy(info.MAC.b, interface_info.hardware_address.data(), + sizeof(info.MAC.b)); info.McastTxRx = 1; platform_storage_.sockets.push_back(socket); auto result = mDNS_RegisterInterface(&mdns_, &info, mDNSfalse); @@ -284,9 +283,8 @@ Error MdnsResponderAdapterImpl::RegisterInterface( : Error::Code::kMdnsRegisterFailure; } -Error MdnsResponderAdapterImpl::DeregisterInterface( - platform::UdpSocket* socket) { - TRACE_SCOPED(TraceCategory::mDNS, +Error MdnsResponderAdapterImpl::DeregisterInterface(UdpSocket* socket) { + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::DeregisterInterface"); const auto info_it = responder_interface_info_.find(socket); if (info_it == responder_interface_info_.end()) @@ -304,15 +302,14 @@ Error MdnsResponderAdapterImpl::DeregisterInterface( responder_interface_info_.erase(info_it); return Error::None(); } -void MdnsResponderAdapterImpl::OnRead( - platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> packet_or_error) { +void MdnsResponderAdapterImpl::OnRead(UdpSocket* socket, + ErrorOr<UdpPacket> packet_or_error) { if (packet_or_error.is_error()) { return; } - platform::UdpPacket packet = std::move(packet_or_error.value()); - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::OnRead"); + UdpPacket packet = std::move(packet_or_error.value()); + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::OnRead"); mDNSAddr src; if (packet.source().address.IsV4()) { src.type = mDNSAddrType_IPv4; @@ -341,20 +338,18 @@ void MdnsResponderAdapterImpl::OnRead( reinterpret_cast<mDNSInterfaceID>(packet.socket())); } -void MdnsResponderAdapterImpl::OnSendError(platform::UdpSocket* socket, - Error error) { +void MdnsResponderAdapterImpl::OnSendError(UdpSocket* socket, Error error) { // TODO(crbug.com/openscreen/67): Implement this method. OSP_UNIMPLEMENTED(); } -void MdnsResponderAdapterImpl::OnError(platform::UdpSocket* socket, - Error error) { +void MdnsResponderAdapterImpl::OnError(UdpSocket* socket, Error error) { // TODO(crbug.com/openscreen/67): Implement this method. OSP_UNIMPLEMENTED(); } -platform::Clock::duration MdnsResponderAdapterImpl::RunTasks() { - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::RunTasks"); +Clock::duration MdnsResponderAdapterImpl::RunTasks() { + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::RunTasks"); mDNS_Execute(&mdns_); @@ -409,9 +404,9 @@ std::vector<AaaaEvent> MdnsResponderAdapterImpl::TakeAaaaResponses() { } MdnsResponderErrorCode MdnsResponderAdapterImpl::StartPtrQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_type) { - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StartPtrQuery"); + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StartPtrQuery"); auto& ptr_questions = socket_to_questions_[socket].ptr; if (ptr_questions.find(service_type) != ptr_questions.end()) return MdnsResponderErrorCode::kNoError; @@ -456,9 +451,9 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StartPtrQuery( } MdnsResponderErrorCode MdnsResponderAdapterImpl::StartSrvQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) { - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StartSrvQuery"); + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StartSrvQuery"); if (!service_instance.EndsWithLocalDomain()) return MdnsResponderErrorCode::kInvalidParameters; @@ -494,9 +489,9 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StartSrvQuery( } MdnsResponderErrorCode MdnsResponderAdapterImpl::StartTxtQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) { - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StartTxtQuery"); + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StartTxtQuery"); if (!service_instance.EndsWithLocalDomain()) return MdnsResponderErrorCode::kInvalidParameters; @@ -532,9 +527,9 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StartTxtQuery( } MdnsResponderErrorCode MdnsResponderAdapterImpl::StartAQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& domain_name) { - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StartAQuery"); + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StartAQuery"); if (!domain_name.EndsWithLocalDomain()) return MdnsResponderErrorCode::kInvalidParameters; @@ -570,9 +565,10 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StartAQuery( } MdnsResponderErrorCode MdnsResponderAdapterImpl::StartAaaaQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& domain_name) { - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StartAaaaQuery"); + TRACE_SCOPED(TraceCategory::kMdns, + "MdnsResponderAdapterImpl::StartAaaaQuery"); if (!domain_name.EndsWithLocalDomain()) return MdnsResponderErrorCode::kInvalidParameters; @@ -608,9 +604,9 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StartAaaaQuery( } MdnsResponderErrorCode MdnsResponderAdapterImpl::StopPtrQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_type) { - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StopPtrQuery"); + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StopPtrQuery"); auto interface_entry = socket_to_questions_.find(socket); if (interface_entry == socket_to_questions_.end()) return MdnsResponderErrorCode::kNoError; @@ -626,9 +622,9 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StopPtrQuery( } MdnsResponderErrorCode MdnsResponderAdapterImpl::StopSrvQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) { - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StopSrvQuery"); + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StopSrvQuery"); auto interface_entry = socket_to_questions_.find(socket); if (interface_entry == socket_to_questions_.end()) return MdnsResponderErrorCode::kNoError; @@ -644,9 +640,9 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StopSrvQuery( } MdnsResponderErrorCode MdnsResponderAdapterImpl::StopTxtQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) { - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StopTxtQuery"); + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StopTxtQuery"); auto interface_entry = socket_to_questions_.find(socket); if (interface_entry == socket_to_questions_.end()) return MdnsResponderErrorCode::kNoError; @@ -662,9 +658,9 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StopTxtQuery( } MdnsResponderErrorCode MdnsResponderAdapterImpl::StopAQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& domain_name) { - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StopAQuery"); + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StopAQuery"); auto interface_entry = socket_to_questions_.find(socket); if (interface_entry == socket_to_questions_.end()) return MdnsResponderErrorCode::kNoError; @@ -680,9 +676,9 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StopAQuery( } MdnsResponderErrorCode MdnsResponderAdapterImpl::StopAaaaQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& domain_name) { - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StopAaaaQuery"); + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StopAaaaQuery"); auto interface_entry = socket_to_questions_.find(socket); if (interface_entry == socket_to_questions_.end()) return MdnsResponderErrorCode::kNoError; @@ -704,7 +700,7 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::RegisterService( const DomainName& target_host, uint16_t target_port, const std::map<std::string, std::string>& txt_data) { - TRACE_SCOPED(TraceCategory::mDNS, + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::RegisterService"); OSP_DCHECK(IsValidServiceName(service_name)); OSP_DCHECK(IsValidServiceProtocol(service_protocol)); @@ -749,7 +745,7 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::DeregisterService( const std::string& service_instance, const std::string& service_name, const std::string& service_protocol) { - TRACE_SCOPED(TraceCategory::mDNS, + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::DeregisterService"); domainlabel instance; domainlabel name; @@ -779,7 +775,7 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::UpdateTxtData( const std::string& service_name, const std::string& service_protocol, const std::map<std::string, std::string>& txt_data) { - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::UpdateTxtData"); + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::UpdateTxtData"); domainlabel instance; domainlabel name; domainlabel protocol; @@ -813,7 +809,8 @@ void MdnsResponderAdapterImpl::AQueryCallback(mDNS* m, DNSQuestion* question, const ResourceRecord* answer, QC_result added) { - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::AQueryCallback"); + TRACE_SCOPED(TraceCategory::kMdns, + "MdnsResponderAdapterImpl::AQueryCallback"); OSP_DCHECK(question); OSP_DCHECK(answer); OSP_DCHECK_EQ(answer->rrtype, kDNSType_A); @@ -834,8 +831,8 @@ void MdnsResponderAdapterImpl::AQueryCallback(mDNS* m, OSP_DCHECK_EQ(added, QC_addnocache); } adapter->a_responses_.emplace_back( - QueryEventHeader{event_type, reinterpret_cast<platform::UdpSocket*>( - answer->InterfaceID)}, + QueryEventHeader{event_type, + reinterpret_cast<UdpSocket*>(answer->InterfaceID)}, std::move(domain), address); } @@ -844,7 +841,7 @@ void MdnsResponderAdapterImpl::AaaaQueryCallback(mDNS* m, DNSQuestion* question, const ResourceRecord* answer, QC_result added) { - TRACE_SCOPED(TraceCategory::mDNS, + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::AaaaQueryCallback"); OSP_DCHECK(question); OSP_DCHECK(answer); @@ -866,8 +863,8 @@ void MdnsResponderAdapterImpl::AaaaQueryCallback(mDNS* m, OSP_DCHECK_EQ(added, QC_addnocache); } adapter->aaaa_responses_.emplace_back( - QueryEventHeader{event_type, reinterpret_cast<platform::UdpSocket*>( - answer->InterfaceID)}, + QueryEventHeader{event_type, + reinterpret_cast<UdpSocket*>(answer->InterfaceID)}, std::move(domain), address); } @@ -876,7 +873,7 @@ void MdnsResponderAdapterImpl::PtrQueryCallback(mDNS* m, DNSQuestion* question, const ResourceRecord* answer, QC_result added) { - TRACE_SCOPED(TraceCategory::mDNS, + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::PtrQueryCallback"); OSP_DCHECK(question); OSP_DCHECK(answer); @@ -897,8 +894,8 @@ void MdnsResponderAdapterImpl::PtrQueryCallback(mDNS* m, OSP_DCHECK_EQ(added, QC_addnocache); } adapter->ptr_responses_.emplace_back( - QueryEventHeader{event_type, reinterpret_cast<platform::UdpSocket*>( - answer->InterfaceID)}, + QueryEventHeader{event_type, + reinterpret_cast<UdpSocket*>(answer->InterfaceID)}, std::move(result)); } @@ -907,7 +904,7 @@ void MdnsResponderAdapterImpl::SrvQueryCallback(mDNS* m, DNSQuestion* question, const ResourceRecord* answer, QC_result added) { - TRACE_SCOPED(TraceCategory::mDNS, + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::SrvQueryCallback"); OSP_DCHECK(question); OSP_DCHECK(answer); @@ -932,8 +929,8 @@ void MdnsResponderAdapterImpl::SrvQueryCallback(mDNS* m, OSP_DCHECK_EQ(added, QC_addnocache); } adapter->srv_responses_.emplace_back( - QueryEventHeader{event_type, reinterpret_cast<platform::UdpSocket*>( - answer->InterfaceID)}, + QueryEventHeader{event_type, + reinterpret_cast<UdpSocket*>(answer->InterfaceID)}, std::move(service), std::move(result), GetNetworkOrderPort(answer->rdata->u.srv.port)); } @@ -963,8 +960,8 @@ void MdnsResponderAdapterImpl::TxtQueryCallback(mDNS* m, OSP_DCHECK_EQ(added, QC_addnocache); } adapter->txt_responses_.emplace_back( - QueryEventHeader{event_type, reinterpret_cast<platform::UdpSocket*>( - answer->InterfaceID)}, + QueryEventHeader{event_type, + reinterpret_cast<UdpSocket*>(answer->InterfaceID)}, std::move(service), std::move(lines)); } @@ -992,10 +989,10 @@ void MdnsResponderAdapterImpl::ServiceCallback(mDNS* m, } void MdnsResponderAdapterImpl::AdvertiseInterfaces() { - TRACE_SCOPED(TraceCategory::mDNS, + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::AdvertiseInterfaces"); for (auto& info : responder_interface_info_) { - platform::UdpSocket* socket = info.first; + UdpSocket* socket = info.first; NetworkInterfaceInfo& interface_info = info.second; mDNS_SetupResourceRecord(&interface_info.RR_A, /** RDataStorage */ nullptr, reinterpret_cast<mDNSInterfaceID>(socket), @@ -1027,8 +1024,7 @@ void MdnsResponderAdapterImpl::DeadvertiseInterfaces() { } } -void MdnsResponderAdapterImpl::RemoveQuestionsIfEmpty( - platform::UdpSocket* socket) { +void MdnsResponderAdapterImpl::RemoveQuestionsIfEmpty(UdpSocket* socket) { auto entry = socket_to_questions_.find(socket); bool empty = entry->second.a.empty() || entry->second.aaaa.empty() || entry->second.ptr.empty() || entry->second.srv.empty() || diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h index 1e5ddeb0e39..80669e577d0 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h @@ -29,17 +29,16 @@ class MdnsResponderAdapterImpl final : public MdnsResponderAdapter { Error SetHostLabel(const std::string& host_label) override; - Error RegisterInterface(const platform::InterfaceInfo& interface_info, - const platform::IPSubnet& interface_address, - platform::UdpSocket* socket) override; - Error DeregisterInterface(platform::UdpSocket* socket) override; + Error RegisterInterface(const InterfaceInfo& interface_info, + const IPSubnet& interface_address, + UdpSocket* socket) override; + Error DeregisterInterface(UdpSocket* socket) override; - void OnRead(platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> packet) override; - void OnSendError(platform::UdpSocket* socket, Error error) override; - void OnError(platform::UdpSocket* socket, Error error) override; + void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override; + void OnSendError(UdpSocket* socket, Error error) override; + void OnError(UdpSocket* socket, Error error) override; - platform::Clock::duration RunTasks() override; + Clock::duration RunTasks() override; std::vector<PtrEvent> TakePtrResponses() override; std::vector<SrvEvent> TakeSrvResponses() override; @@ -47,29 +46,29 @@ class MdnsResponderAdapterImpl final : public MdnsResponderAdapter { std::vector<AEvent> TakeAResponses() override; std::vector<AaaaEvent> TakeAaaaResponses() override; - MdnsResponderErrorCode StartPtrQuery(platform::UdpSocket* socket, + MdnsResponderErrorCode StartPtrQuery(UdpSocket* socket, const DomainName& service_type) override; MdnsResponderErrorCode StartSrvQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) override; MdnsResponderErrorCode StartTxtQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) override; - MdnsResponderErrorCode StartAQuery(platform::UdpSocket* socket, + MdnsResponderErrorCode StartAQuery(UdpSocket* socket, const DomainName& domain_name) override; - MdnsResponderErrorCode StartAaaaQuery(platform::UdpSocket* socket, + MdnsResponderErrorCode StartAaaaQuery(UdpSocket* socket, const DomainName& domain_name) override; - MdnsResponderErrorCode StopPtrQuery(platform::UdpSocket* socket, + MdnsResponderErrorCode StopPtrQuery(UdpSocket* socket, const DomainName& service_type) override; MdnsResponderErrorCode StopSrvQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) override; MdnsResponderErrorCode StopTxtQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) override; - MdnsResponderErrorCode StopAQuery(platform::UdpSocket* socket, + MdnsResponderErrorCode StopAQuery(UdpSocket* socket, const DomainName& domain_name) override; - MdnsResponderErrorCode StopAaaaQuery(platform::UdpSocket* socket, + MdnsResponderErrorCode StopAaaaQuery(UdpSocket* socket, const DomainName& domain_name) override; MdnsResponderErrorCode RegisterService( @@ -124,7 +123,7 @@ class MdnsResponderAdapterImpl final : public MdnsResponderAdapter { void AdvertiseInterfaces(); void DeadvertiseInterfaces(); - void RemoveQuestionsIfEmpty(platform::UdpSocket* socket); + void RemoveQuestionsIfEmpty(UdpSocket* socket); CacheEntity rr_cache_[kRrCacheSize]; @@ -136,10 +135,9 @@ class MdnsResponderAdapterImpl final : public MdnsResponderAdapter { // platform sockets. mDNS_PlatformSupport platform_storage_; - std::map<platform::UdpSocket*, Questions> socket_to_questions_; + std::map<UdpSocket*, Questions> socket_to_questions_; - std::map<platform::UdpSocket*, NetworkInterfaceInfo> - responder_interface_info_; + std::map<UdpSocket*, NetworkInterfaceInfo> responder_interface_info_; std::vector<AEvent> a_responses_; std::vector<AaaaEvent> aaaa_responses_; diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl_unittest.cc index 759fe44d470..29b76679665 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl_unittest.cc @@ -49,7 +49,7 @@ TEST(MdnsResponderAdapterImplTest, ExampleData) { 'p', 5, 'l', 'o', 'c', 'a', 'l', 0}}; const IPEndpoint mdns_endpoint{{224, 0, 0, 251}, 5353}; - platform::UdpPacket packet(std::begin(data), std::end(data)); + UdpPacket packet(std::begin(data), std::end(data)); packet.set_source({{192, 168, 0, 2}, 6556}); packet.set_destination(mdns_endpoint); packet.set_socket(nullptr); diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.cc b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.cc index 01603be4f16..4e204e37055 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.cc +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.cc @@ -18,7 +18,6 @@ #include "third_party/mDNSResponder/src/mDNSCore/mDNSEmbeddedAPI.h" #include "util/logging.h" -using openscreen::platform::Clock; using std::chrono::duration_cast; using std::chrono::hours; using std::chrono::milliseconds; @@ -44,8 +43,7 @@ mStatus mDNSPlatformSendUDP(const mDNS* m, UDPSocket* src, const mDNSAddr* dst, mDNSIPPort dstport) { - auto* const socket = - reinterpret_cast<openscreen::platform::UdpSocket*>(InterfaceID); + auto* const socket = reinterpret_cast<openscreen::UdpSocket*>(InterfaceID); const auto socket_it = std::find(m->p->sockets.begin(), m->p->sockets.end(), socket); if (socket_it == m->p->sockets.end()) @@ -111,6 +109,8 @@ mStatus mDNSPlatformTimeInit() { } mDNSs32 mDNSPlatformRawTime() { + using openscreen::Clock; + const Clock::time_point now = Clock::now(); // A signed 32-bit integer counting milliseconds only gives ~24.8 days of @@ -129,8 +129,7 @@ mDNSs32 mDNSPlatformRawTime() { mDNSs32 mDNSPlatformUTC() { const auto seconds_since_epoch = - duration_cast<seconds>(openscreen::platform::GetWallTimeSinceUnixEpoch()) - .count(); + duration_cast<seconds>(openscreen::GetWallTimeSinceUnixEpoch()).count(); // The return type will cause overflow in early 2038. Warn future developers // a year ahead of time. @@ -244,7 +243,11 @@ void mDNSPlatformSetDNSConfig(mDNS* const m, mDNSBool setsearch, domainname* const fqdn, DNameListElem** RegDomains, - DNameListElem** BrowseDomains) {} + DNameListElem** BrowseDomains) { + if (fqdn) { + std::memset(fqdn, 0, sizeof(*fqdn)); + } +} mStatus mDNSPlatformGetPrimaryInterface(mDNS* const m, mDNSAddr* v4, diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.h b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.h index afe04cd8653..342913fe620 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.h +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.h @@ -10,7 +10,7 @@ #include "platform/api/udp_socket.h" struct mDNS_PlatformSupport_struct { - std::vector<openscreen::platform::UdpSocket*> sockets; + std::vector<openscreen::UdpSocket*> sockets; }; #endif // OSP_IMPL_DISCOVERY_MDNS_MDNS_RESPONDER_PLATFORM_H_ diff --git a/chromium/third_party/openscreen/src/osp/impl/internal_services.cc b/chromium/third_party/openscreen/src/osp/impl/internal_services.cc index 983fcd7810d..5441472fbca 100644 --- a/chromium/third_party/openscreen/src/osp/impl/internal_services.cc +++ b/chromium/third_party/openscreen/src/osp/impl/internal_services.cc @@ -5,6 +5,7 @@ #include "osp/impl/internal_services.h" #include <algorithm> +#include <utility> #include "osp/impl/discovery/mdns/mdns_responder_adapter_impl.h" #include "osp/impl/mdns_responder_service.h" @@ -21,8 +22,7 @@ constexpr char kServiceProtocol[] = "_udp"; const IPAddress kMulticastAddress{224, 0, 0, 251}; const IPAddress kMulticastIPv6Address{ // ff02::fb - 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb, + 0xff02, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x00fb, }; const uint16_t kMulticastListeningPort = 5353; @@ -37,8 +37,7 @@ class MdnsResponderAdapterImplFactory final } }; -Error SetUpMulticastSocket(platform::UdpSocket* socket, - platform::NetworkInterfaceIndex ifindex) { +Error SetUpMulticastSocket(UdpSocket* socket, NetworkInterfaceIndex ifindex) { const IPAddress broadcast_address = socket->IsIPv6() ? kMulticastIPv6Address : kMulticastAddress; @@ -60,7 +59,7 @@ int g_instance_ref_count = 0; std::unique_ptr<ServiceListener> InternalServices::CreateListener( const MdnsServiceListenerConfig& config, ServiceListener::Observer* observer, - platform::TaskRunner* task_runner) { + TaskRunner* task_runner) { auto* services = ReferenceSingleton(task_runner); auto listener = std::make_unique<ServiceListenerImpl>(&services->mdns_service_); @@ -74,7 +73,7 @@ std::unique_ptr<ServiceListener> InternalServices::CreateListener( std::unique_ptr<ServicePublisher> InternalServices::CreatePublisher( const ServicePublisher::Config& config, ServicePublisher::Observer* observer, - platform::TaskRunner* task_runner) { + TaskRunner* task_runner) { auto* services = ReferenceSingleton(task_runner); services->mdns_service_.SetServiceConfig( config.hostname, config.service_instance_name, @@ -99,11 +98,10 @@ InternalServices::InternalPlatformLinkage::~InternalPlatformLinkage() { std::vector<MdnsPlatformService::BoundInterface> InternalServices::InternalPlatformLinkage::RegisterInterfaces( - const std::vector<platform::NetworkInterfaceIndex>& whitelist) { - const std::vector<platform::InterfaceInfo> interfaces = - platform::GetNetworkInterfaces(); + const std::vector<NetworkInterfaceIndex>& whitelist) { + const std::vector<InterfaceInfo> interfaces = GetNetworkInterfaces(); const bool do_filter_using_whitelist = !whitelist.empty(); - std::vector<platform::NetworkInterfaceIndex> index_list; + std::vector<NetworkInterfaceIndex> index_list; for (const auto& interface : interfaces) { OSP_VLOG << "Found interface: " << interface; if (do_filter_using_whitelist && @@ -121,28 +119,26 @@ InternalServices::InternalPlatformLinkage::RegisterInterfaces( // Set up sockets to send and listen to mDNS multicast traffic on all // interfaces. std::vector<BoundInterface> result; - for (platform::NetworkInterfaceIndex index : index_list) { - const auto& interface = - *std::find_if(interfaces.begin(), interfaces.end(), - [index](const platform::InterfaceInfo& info) { - return info.index == index; - }); + for (NetworkInterfaceIndex index : index_list) { + const auto& interface = *std::find_if( + interfaces.begin(), interfaces.end(), + [index](const InterfaceInfo& info) { return info.index == index; }); if (interface.addresses.empty()) { continue; } // Pick any address for the given interface. - const platform::IPSubnet& primary_subnet = interface.addresses.front(); + const IPSubnet& primary_subnet = interface.addresses.front(); auto create_result = - platform::UdpSocket::Create(parent_->task_runner_, parent_, - IPEndpoint{{}, kMulticastListeningPort}); + UdpSocket::Create(parent_->task_runner_, parent_, + IPEndpoint{{}, kMulticastListeningPort}); if (!create_result) { OSP_LOG_ERROR << "failed to create socket for interface " << index << ": " << create_result.error().message(); continue; } - platform::UdpSocketUniquePtr socket = std::move(create_result.value()); + std::unique_ptr<UdpSocket> socket = std::move(create_result.value()); if (!SetUpMulticastSocket(socket.get(), index).ok()) { continue; } @@ -158,21 +154,20 @@ InternalServices::InternalPlatformLinkage::RegisterInterfaces( void InternalServices::InternalPlatformLinkage::DeregisterInterfaces( const std::vector<BoundInterface>& registered_interfaces) { for (const auto& interface : registered_interfaces) { - platform::UdpSocket* const socket = interface.socket; + UdpSocket* const socket = interface.socket; parent_->DeregisterMdnsSocket(socket); - const auto it = - std::find_if(open_sockets_.begin(), open_sockets_.end(), - [socket](const platform::UdpSocketUniquePtr& s) { - return s.get() == socket; - }); + const auto it = std::find_if(open_sockets_.begin(), open_sockets_.end(), + [socket](const std::unique_ptr<UdpSocket>& s) { + return s.get() == socket; + }); OSP_DCHECK(it != open_sockets_.end()); open_sockets_.erase(it); } } -InternalServices::InternalServices(platform::ClockNowFunctionPtr now_function, - platform::TaskRunner* task_runner) +InternalServices::InternalServices(ClockNowFunctionPtr now_function, + TaskRunner* task_runner) : mdns_service_(now_function, task_runner, kServiceName, @@ -183,23 +178,23 @@ InternalServices::InternalServices(platform::ClockNowFunctionPtr now_function, InternalServices::~InternalServices() = default; -void InternalServices::RegisterMdnsSocket(platform::UdpSocket* socket) { +void InternalServices::RegisterMdnsSocket(UdpSocket* socket) { OSP_CHECK(g_instance) << "No listener or publisher is alive."; // TODO(rwkeane): Hook this up to the new mDNS library once we swap out the // mDNSResponder. } -void InternalServices::DeregisterMdnsSocket(platform::UdpSocket* socket) { +void InternalServices::DeregisterMdnsSocket(UdpSocket* socket) { // TODO(rwkeane): Hook this up to the new mDNS library once we swap out the // mDNSResponder. } // static InternalServices* InternalServices::ReferenceSingleton( - platform::TaskRunner* task_runner) { + TaskRunner* task_runner) { if (!g_instance) { OSP_CHECK_EQ(g_instance_ref_count, 0); - g_instance = new InternalServices(&platform::Clock::now, task_runner); + g_instance = new InternalServices(&Clock::now, task_runner); } ++g_instance_ref_count; return g_instance; @@ -216,18 +211,17 @@ void InternalServices::DereferenceSingleton(void* instance) { } } -void InternalServices::OnError(platform::UdpSocket* socket, Error error) { +void InternalServices::OnError(UdpSocket* socket, Error error) { OSP_LOG_ERROR << "failed to configure socket " << error.message(); this->DeregisterMdnsSocket(socket); } -void InternalServices::OnSendError(platform::UdpSocket* socket, Error error) { +void InternalServices::OnSendError(UdpSocket* socket, Error error) { // TODO(crbug.com/openscreen/67): Implement this method. OSP_UNIMPLEMENTED(); } -void InternalServices::OnRead(platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> packet) { +void InternalServices::OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) { g_instance->mdns_service_.OnRead(socket, std::move(packet)); } diff --git a/chromium/third_party/openscreen/src/osp/impl/internal_services.h b/chromium/third_party/openscreen/src/osp/impl/internal_services.h index 5a1f015a030..364c5963c65 100644 --- a/chromium/third_party/openscreen/src/osp/impl/internal_services.h +++ b/chromium/third_party/openscreen/src/osp/impl/internal_services.h @@ -25,9 +25,7 @@ namespace openscreen { -namespace platform { class TaskRunner; -} // namespace platform namespace osp { @@ -36,22 +34,21 @@ namespace osp { // event loop. // TODO(btolsch): This may be renamed and/or split up once QUIC code lands and // this use case is more concrete. -class InternalServices : platform::UdpSocket::Client { +class InternalServices : UdpSocket::Client { public: static std::unique_ptr<ServiceListener> CreateListener( const MdnsServiceListenerConfig& config, ServiceListener::Observer* observer, - platform::TaskRunner* task_runner); + TaskRunner* task_runner); static std::unique_ptr<ServicePublisher> CreatePublisher( const ServicePublisher::Config& config, ServicePublisher::Observer* observer, - platform::TaskRunner* task_runner); + TaskRunner* task_runner); // UdpSocket::Client overrides. - void OnError(platform::UdpSocket* socket, Error error) override; - void OnSendError(platform::UdpSocket* socket, Error error) override; - void OnRead(platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> packet) override; + void OnError(UdpSocket* socket, Error error) override; + void OnSendError(UdpSocket* socket, Error error) override; + void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override; private: class InternalPlatformLinkage final : public MdnsPlatformService { @@ -60,31 +57,29 @@ class InternalServices : platform::UdpSocket::Client { ~InternalPlatformLinkage() override; std::vector<BoundInterface> RegisterInterfaces( - const std::vector<platform::NetworkInterfaceIndex>& whitelist) override; + const std::vector<NetworkInterfaceIndex>& whitelist) override; void DeregisterInterfaces( const std::vector<BoundInterface>& registered_interfaces) override; private: InternalServices* const parent_; - std::vector<platform::UdpSocketUniquePtr> open_sockets_; + std::vector<std::unique_ptr<UdpSocket>> open_sockets_; }; // The TaskRunner provided here should live for the duration of this // InternalService object's lifetime. - InternalServices(platform::ClockNowFunctionPtr now_function, - platform::TaskRunner* task_runner); + InternalServices(ClockNowFunctionPtr now_function, TaskRunner* task_runner); ~InternalServices() override; - void RegisterMdnsSocket(platform::UdpSocket* socket); - void DeregisterMdnsSocket(platform::UdpSocket* socket); + void RegisterMdnsSocket(UdpSocket* socket); + void DeregisterMdnsSocket(UdpSocket* socket); - static InternalServices* ReferenceSingleton( - platform::TaskRunner* task_runner); + static InternalServices* ReferenceSingleton(TaskRunner* task_runner); static void DereferenceSingleton(void* instance); MdnsResponderService mdns_service_; - platform::TaskRunner* const task_runner_; + TaskRunner* const task_runner_; OSP_DISALLOW_COPY_AND_ASSIGN(InternalServices); }; diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.cc b/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.cc index b4e6d8b85c5..e46829b8043 100644 --- a/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.cc +++ b/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.cc @@ -12,9 +12,9 @@ namespace openscreen { namespace osp { MdnsPlatformService::BoundInterface::BoundInterface( - const platform::InterfaceInfo& interface_info, - const platform::IPSubnet& subnet, - platform::UdpSocket* socket) + const InterfaceInfo& interface_info, + const IPSubnet& subnet, + UdpSocket* socket) : interface_info(interface_info), subnet(subnet), socket(socket) { OSP_DCHECK(socket); } diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.h b/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.h index 7cb6271fb5b..ec6ca31b481 100644 --- a/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.h +++ b/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.h @@ -16,23 +16,23 @@ namespace osp { class MdnsPlatformService { public: struct BoundInterface { - BoundInterface(const platform::InterfaceInfo& interface_info, - const platform::IPSubnet& subnet, - platform::UdpSocket* socket); + BoundInterface(const InterfaceInfo& interface_info, + const IPSubnet& subnet, + UdpSocket* socket); ~BoundInterface(); bool operator==(const BoundInterface& other) const; bool operator!=(const BoundInterface& other) const; - platform::InterfaceInfo interface_info; - platform::IPSubnet subnet; - platform::UdpSocket* socket; + InterfaceInfo interface_info; + IPSubnet subnet; + UdpSocket* socket; }; virtual ~MdnsPlatformService() = default; virtual std::vector<BoundInterface> RegisterInterfaces( - const std::vector<platform::NetworkInterfaceIndex>& whitelist) = 0; + const std::vector<NetworkInterfaceIndex>& whitelist) = 0; virtual void DeregisterInterfaces( const std::vector<BoundInterface>& registered_interfaces) = 0; }; diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.cc b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.cc index 347f1254269..f1aedbdcaba 100644 --- a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.cc +++ b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.cc @@ -13,8 +13,6 @@ #include "util/logging.h" #include "util/trace_logging.h" -using openscreen::platform::TraceCategory; - namespace openscreen { namespace osp { namespace { @@ -33,8 +31,8 @@ std::string ServiceIdFromServiceInstanceName( } // namespace MdnsResponderService::MdnsResponderService( - platform::ClockNowFunctionPtr now_function, - platform::TaskRunner* task_runner, + ClockNowFunctionPtr now_function, + TaskRunner* task_runner, const std::string& service_name, const std::string& service_protocol, std::unique_ptr<MdnsResponderAdapterFactory> mdns_responder_factory, @@ -51,7 +49,7 @@ void MdnsResponderService::SetServiceConfig( const std::string& hostname, const std::string& instance, uint16_t port, - const std::vector<platform::NetworkInterfaceIndex> whitelist, + const std::vector<NetworkInterfaceIndex> whitelist, const std::map<std::string, std::string>& txt_data) { OSP_DCHECK(!hostname.empty()); OSP_DCHECK(!instance.empty()); @@ -63,9 +61,9 @@ void MdnsResponderService::SetServiceConfig( service_txt_data_ = txt_data; } -void MdnsResponderService::OnRead(platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> packet) { - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderService::OnRead"); +void MdnsResponderService::OnRead(UdpSocket* socket, + ErrorOr<UdpPacket> packet) { + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderService::OnRead"); if (!mdns_responder_) { return; } @@ -74,12 +72,11 @@ void MdnsResponderService::OnRead(platform::UdpSocket* socket, HandleMdnsEvents(); } -void MdnsResponderService::OnSendError(platform::UdpSocket* socket, - Error error) { +void MdnsResponderService::OnSendError(UdpSocket* socket, Error error) { mdns_responder_->OnSendError(socket, std::move(error)); } -void MdnsResponderService::OnError(platform::UdpSocket* socket, Error error) { +void MdnsResponderService::OnError(UdpSocket* socket, Error error) { mdns_responder_->OnError(socket, std::move(error)); } @@ -214,7 +211,7 @@ bool MdnsResponderService::NetworkScopedDomainNameComparator::operator()( } void MdnsResponderService::HandleMdnsEvents() { - TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderService::HandleMdnsEvents"); + TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderService::HandleMdnsEvents"); // NOTE: In the common case, we will get a single combined packet for // PTR/SRV/TXT/A and then no other packets. If we don't loop here, we would // start SRV/TXT queries based on the PTR response, but never check for events @@ -334,7 +331,7 @@ void MdnsResponderService::StopListening() { } network_scoped_domain_to_host_.clear(); for (const auto& service : service_by_name_) { - platform::UdpSocket* const socket = service.second->ptr_socket; + UdpSocket* const socket = service.second->ptr_socket; mdns_responder_->StopSrvQuery(socket, service.first); mdns_responder_->StopTxtQuery(socket, service.first); } @@ -429,7 +426,7 @@ bool MdnsResponderService::HandlePtrEvent( InstanceNameSet* modified_instance_names) { bool events_possible = false; const auto& instance_name = ptr_event.service_instance; - platform::UdpSocket* const socket = ptr_event.header.socket; + UdpSocket* const socket = ptr_event.header.socket; auto entry = service_by_name_.find(ptr_event.service_instance); switch (ptr_event.header.response_type) { case QueryEventHeader::Type::kAddedNoCache: @@ -481,7 +478,7 @@ bool MdnsResponderService::HandleSrvEvent( bool events_possible = false; auto& domain_name = srv_event.domain_name; const auto& instance_name = srv_event.service_instance; - platform::UdpSocket* const socket = srv_event.header.socket; + UdpSocket* const socket = srv_event.header.socket; auto entry = service_by_name_.find(srv_event.service_instance); if (entry == service_by_name_.end()) return events_possible; @@ -568,7 +565,7 @@ bool MdnsResponderService::HandleTxtEvent( } bool MdnsResponderService::HandleAddressEvent( - platform::UdpSocket* socket, + UdpSocket* socket, QueryEventHeader::Type response_type, const DomainName& domain_name, bool a_event, @@ -619,14 +616,14 @@ bool MdnsResponderService::HandleAaaaEvent( } MdnsResponderService::HostInfo* MdnsResponderService::AddOrGetHostInfo( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& domain_name) { return &network_scoped_domain_to_host_[NetworkScopedDomainName{socket, domain_name}]; } MdnsResponderService::HostInfo* MdnsResponderService::GetHostInfo( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& domain_name) { auto kv = network_scoped_domain_to_host_.find( NetworkScopedDomainName{socket, domain_name}); @@ -642,16 +639,15 @@ bool MdnsResponderService::IsServiceReady(const ServiceInstance& instance, !instance.txt_info.empty() && (host->v4_address || host->v6_address)); } -platform::NetworkInterfaceIndex -MdnsResponderService::GetNetworkInterfaceIndexFromSocket( - const platform::UdpSocket* socket) const { +NetworkInterfaceIndex MdnsResponderService::GetNetworkInterfaceIndexFromSocket( + const UdpSocket* socket) const { auto it = std::find_if( bound_interfaces_.begin(), bound_interfaces_.end(), [socket](const MdnsPlatformService::BoundInterface& interface) { return interface.socket == socket; }); if (it == bound_interfaces_.end()) - return platform::kInvalidNetworkInterfaceIndex; + return kInvalidNetworkInterfaceIndex; return it->interface_info.index; } diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.h b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.h index c6ec126f523..a010f5a6065 100644 --- a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.h +++ b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.h @@ -34,29 +34,27 @@ class MdnsResponderAdapterFactory { class MdnsResponderService : public ServiceListenerImpl::Delegate, public ServicePublisherImpl::Delegate, - public platform::UdpSocket::Client { + public UdpSocket::Client { public: MdnsResponderService( - platform::ClockNowFunctionPtr now_function, - platform::TaskRunner* task_runner, + ClockNowFunctionPtr now_function, + TaskRunner* task_runner, const std::string& service_name, const std::string& service_protocol, std::unique_ptr<MdnsResponderAdapterFactory> mdns_responder_factory, std::unique_ptr<MdnsPlatformService> platform); virtual ~MdnsResponderService() override; - void SetServiceConfig( - const std::string& hostname, - const std::string& instance, - uint16_t port, - const std::vector<platform::NetworkInterfaceIndex> whitelist, - const std::map<std::string, std::string>& txt_data); + void SetServiceConfig(const std::string& hostname, + const std::string& instance, + uint16_t port, + const std::vector<NetworkInterfaceIndex> whitelist, + const std::map<std::string, std::string>& txt_data); // UdpSocket::Client overrides. - void OnRead(platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> packet) override; - void OnSendError(platform::UdpSocket* socket, Error error) override; - void OnError(platform::UdpSocket* socket, Error error) override; + void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override; + void OnSendError(UdpSocket* socket, Error error) override; + void OnError(UdpSocket* socket, Error error) override; // ServiceListenerImpl::Delegate overrides. void StartListener() override; @@ -98,7 +96,7 @@ class MdnsResponderService : public ServiceListenerImpl::Delegate, // NOTE: service_instance implicit in map key. struct ServiceInstance { - platform::UdpSocket* ptr_socket = nullptr; + UdpSocket* ptr_socket = nullptr; DomainName domain_name; uint16_t port = 0; bool has_ptr_record = false; @@ -116,7 +114,7 @@ class MdnsResponderService : public ServiceListenerImpl::Delegate, }; struct NetworkScopedDomainName { - platform::UdpSocket* socket; + UdpSocket* socket; DomainName domain_name; }; @@ -144,7 +142,7 @@ class MdnsResponderService : public ServiceListenerImpl::Delegate, InstanceNameSet* modified_instance_names); bool HandleTxtEvent(const TxtEvent& txt_event, InstanceNameSet* modified_instance_names); - bool HandleAddressEvent(platform::UdpSocket* socket, + bool HandleAddressEvent(UdpSocket* socket, QueryEventHeader::Type response_type, const DomainName& domain_name, bool a_event, @@ -155,13 +153,11 @@ class MdnsResponderService : public ServiceListenerImpl::Delegate, bool HandleAaaaEvent(const AaaaEvent& aaaa_event, InstanceNameSet* modified_instance_names); - HostInfo* AddOrGetHostInfo(platform::UdpSocket* socket, - const DomainName& domain_name); - HostInfo* GetHostInfo(platform::UdpSocket* socket, - const DomainName& domain_name); + HostInfo* AddOrGetHostInfo(UdpSocket* socket, const DomainName& domain_name); + HostInfo* GetHostInfo(UdpSocket* socket, const DomainName& domain_name); bool IsServiceReady(const ServiceInstance& instance, HostInfo* host) const; - platform::NetworkInterfaceIndex GetNetworkInterfaceIndexFromSocket( - const platform::UdpSocket* socket) const; + NetworkInterfaceIndex GetNetworkInterfaceIndexFromSocket( + const UdpSocket* socket) const; // Runs background tasks to manage the internal mDNS state. void RunBackgroundTasks(); @@ -175,7 +171,7 @@ class MdnsResponderService : public ServiceListenerImpl::Delegate, std::string service_hostname_; std::string service_instance_name_; uint16_t service_port_; - std::vector<platform::NetworkInterfaceIndex> interface_index_whitelist_; + std::vector<NetworkInterfaceIndex> interface_index_whitelist_; std::map<std::string, std::string> service_txt_data_; std::unique_ptr<MdnsResponderAdapterFactory> mdns_responder_factory_; @@ -199,7 +195,7 @@ class MdnsResponderService : public ServiceListenerImpl::Delegate, std::map<std::string, ServiceInfo> receiver_info_; - platform::TaskRunner* const task_runner_; + TaskRunner* const task_runner_; // Scheduled to run periodic background tasks. Alarm background_tasks_alarm_; diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service_unittest.cc index fe8e1239dc4..5a9891acf5b 100644 --- a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service_unittest.cc @@ -24,12 +24,12 @@ namespace osp { class TestingMdnsResponderService final : public MdnsResponderService { public: TestingMdnsResponderService( - platform::FakeTaskRunner* task_runner, + FakeTaskRunner* task_runner, const std::string& service_name, const std::string& service_protocol, std::unique_ptr<MdnsResponderAdapterFactory> mdns_responder_factory, std::unique_ptr<MdnsPlatformService> platform_service) - : MdnsResponderService(&platform::FakeClock::now, + : MdnsResponderService(&FakeClock::now, task_runner, service_name, service_protocol, @@ -172,10 +172,10 @@ class MockServicePublisherObserver final : public ServicePublisher::Observer { MOCK_METHOD1(OnMetrics, void(ServicePublisher::Metrics)); }; -platform::UdpSocket* const kDefaultSocket = - reinterpret_cast<platform::UdpSocket*>(static_cast<uintptr_t>(16)); -platform::UdpSocket* const kSecondSocket = - reinterpret_cast<platform::UdpSocket*>(static_cast<uintptr_t>(24)); +UdpSocket* const kDefaultSocket = + reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(16)); +UdpSocket* const kSecondSocket = + reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(24)); class MdnsResponderServiceTest : public ::testing::Test { protected: @@ -184,8 +184,8 @@ class MdnsResponderServiceTest : public ::testing::Test { std::make_unique<FakeMdnsResponderAdapterFactory>(); auto wrapper_factory = std::make_unique<WrapperMdnsResponderAdapterFactory>( mdns_responder_factory_.get()); - clock_ = std::make_unique<platform::FakeClock>(platform::Clock::now()); - task_runner_ = std::make_unique<platform::FakeTaskRunner>(clock_.get()); + clock_ = std::make_unique<FakeClock>(Clock::now()); + task_runner_ = std::make_unique<FakeTaskRunner>(clock_.get()); auto platform_service = std::make_unique<FakeMdnsPlatformService>(); fake_platform_service_ = platform_service.get(); fake_platform_service_->set_interfaces(bound_interfaces_); @@ -202,8 +202,8 @@ class MdnsResponderServiceTest : public ::testing::Test { &publisher_observer_, mdns_service_.get()); } - std::unique_ptr<platform::FakeClock> clock_; - std::unique_ptr<platform::FakeTaskRunner> task_runner_; + std::unique_ptr<FakeClock> clock_; + std::unique_ptr<FakeTaskRunner> task_runner_; MockServiceListenerObserver observer_; FakeMdnsPlatformService* fake_platform_service_; std::unique_ptr<FakeMdnsResponderAdapterFactory> mdns_responder_factory_; @@ -213,22 +213,22 @@ class MdnsResponderServiceTest : public ::testing::Test { std::unique_ptr<ServicePublisherImpl> service_publisher_; const uint8_t default_mac_[6] = {0, 11, 22, 33, 44, 55}; const uint8_t second_mac_[6] = {55, 33, 22, 33, 44, 77}; - const platform::IPSubnet default_subnet_{IPAddress{192, 168, 3, 2}, 24}; - const platform::IPSubnet second_subnet_{IPAddress{10, 0, 0, 3}, 24}; + const IPSubnet default_subnet_{IPAddress{192, 168, 3, 2}, 24}; + const IPSubnet second_subnet_{IPAddress{10, 0, 0, 3}, 24}; std::vector<MdnsPlatformService::BoundInterface> bound_interfaces_{ MdnsPlatformService::BoundInterface{ - platform::InterfaceInfo{1, - default_mac_, - "eth0", - platform::InterfaceInfo::Type::kEthernet, - {default_subnet_}}, + InterfaceInfo{1, + default_mac_, + "eth0", + InterfaceInfo::Type::kEthernet, + {default_subnet_}}, default_subnet_, kDefaultSocket}, MdnsPlatformService::BoundInterface{ - platform::InterfaceInfo{2, - second_mac_, - "eth1", - platform::InterfaceInfo::Type::kEthernet, - {second_subnet_}}, + InterfaceInfo{2, + second_mac_, + "eth1", + InterfaceInfo::Type::kEthernet, + {second_subnet_}}, second_subnet_, kSecondSocket}, }; }; @@ -284,10 +284,9 @@ TEST_F(MdnsResponderServiceTest, BasicServiceStates) { TEST_F(MdnsResponderServiceTest, NetworkNetworkInterfaceIndex) { constexpr uint8_t mac[6] = {12, 34, 56, 78, 90}; - const platform::IPSubnet subnet{IPAddress{10, 0, 0, 2}, 24}; + const IPSubnet subnet{IPAddress{10, 0, 0, 2}, 24}; bound_interfaces_.emplace_back( - platform::InterfaceInfo{ - 2, mac, "wlan0", platform::InterfaceInfo::Type::kWifi, {subnet}}, + InterfaceInfo{2, mac, "wlan0", InterfaceInfo::Type::kWifi, {subnet}}, subnet, kSecondSocket); fake_platform_service_->set_interfaces(bound_interfaces_); EXPECT_CALL(observer_, OnStarted()); diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_service_listener_factory.cc b/chromium/third_party/openscreen/src/osp/impl/mdns_service_listener_factory.cc index a94d81081ac..cae4a341ec7 100644 --- a/chromium/third_party/openscreen/src/osp/impl/mdns_service_listener_factory.cc +++ b/chromium/third_party/openscreen/src/osp/impl/mdns_service_listener_factory.cc @@ -8,9 +8,7 @@ namespace openscreen { -namespace platform { class TaskRunner; -} // namespace platform namespace osp { @@ -18,7 +16,7 @@ namespace osp { std::unique_ptr<ServiceListener> MdnsServiceListenerFactory::Create( const MdnsServiceListenerConfig& config, ServiceListener::Observer* observer, - platform::TaskRunner* task_runner) { + TaskRunner* task_runner) { return InternalServices::CreateListener(config, observer, task_runner); } diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_service_publisher_factory.cc b/chromium/third_party/openscreen/src/osp/impl/mdns_service_publisher_factory.cc index cbd8d122465..f055e772afb 100644 --- a/chromium/third_party/openscreen/src/osp/impl/mdns_service_publisher_factory.cc +++ b/chromium/third_party/openscreen/src/osp/impl/mdns_service_publisher_factory.cc @@ -8,9 +8,7 @@ namespace openscreen { -namespace platform { class TaskRunner; -} // namespace platform namespace osp { @@ -18,7 +16,7 @@ namespace osp { std::unique_ptr<ServicePublisher> MdnsServicePublisherFactory::Create( const ServicePublisher::Config& config, ServicePublisher::Observer* observer, - platform::TaskRunner* task_runner) { + TaskRunner* task_runner) { return InternalServices::CreatePublisher(config, observer, task_runner); } diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection.cc index 1bd6f408a34..9220b63b724 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection.cc +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection.cc @@ -181,13 +181,12 @@ void ConnectionManager::RemoveConnection(Connection* connection) { // TODO(jophba): refine the RegisterWatch/OnStreamMessage API. We // should add a layer between the message logic and the parse/dispatch // logic, and remove the CBOR information from ConnectionManager. -ErrorOr<size_t> ConnectionManager::OnStreamMessage( - uint64_t endpoint_id, - uint64_t connection_id, - msgs::Type message_type, - const uint8_t* buffer, - size_t buffer_size, - platform::Clock::time_point now) { +ErrorOr<size_t> ConnectionManager::OnStreamMessage(uint64_t endpoint_id, + uint64_t connection_id, + msgs::Type message_type, + const uint8_t* buffer, + size_t buffer_size, + Clock::time_point now) { switch (message_type) { case msgs::Type::kPresentationConnectionMessage: { msgs::PresentationConnectionMessage message; diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection_unittest.cc index 6f611dda25a..f4c829a4117 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection_unittest.cc @@ -58,12 +58,11 @@ class MockConnectRequest final class ConnectionTest : public ::testing::Test { public: ConnectionTest() { - fake_clock_ = std::make_unique<platform::FakeClock>( - platform::Clock::time_point(std::chrono::milliseconds(1298424))); - task_runner_ = - std::make_unique<platform::FakeTaskRunner>(fake_clock_.get()); - quic_bridge_ = std::make_unique<FakeQuicBridge>(task_runner_.get(), - platform::FakeClock::now); + fake_clock_ = std::make_unique<FakeClock>( + Clock::time_point(std::chrono::milliseconds(1298424))); + task_runner_ = std::make_unique<FakeTaskRunner>(fake_clock_.get()); + quic_bridge_ = + std::make_unique<FakeQuicBridge>(task_runner_.get(), FakeClock::now); controller_connection_manager_ = std::make_unique<ConnectionManager>( quic_bridge_->controller_demuxer.get()); receiver_connection_manager_ = std::make_unique<ConnectionManager>( @@ -89,8 +88,8 @@ class ConnectionTest : public ::testing::Test { return response; } - std::unique_ptr<platform::FakeClock> fake_clock_; - std::unique_ptr<platform::FakeTaskRunner> task_runner_; + std::unique_ptr<FakeClock> fake_clock_; + std::unique_ptr<FakeTaskRunner> task_runner_; std::unique_ptr<FakeQuicBridge> quic_bridge_; std::unique_ptr<ConnectionManager> controller_connection_manager_; std::unique_ptr<ConnectionManager> receiver_connection_manager_; diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller.cc index 5a54cb2b99c..94ee4e2d1f4 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller.cc +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller.cc @@ -418,7 +418,7 @@ void swap(Controller::ConnectRequest& a, Controller::ConnectRequest& b) { swap(a.controller_, b.controller_); } -Controller::Controller(platform::ClockNowFunctionPtr now_function) { +Controller::Controller(ClockNowFunctionPtr now_function) { availability_requester_ = std::make_unique<UrlAvailabilityRequester>(now_function); connection_manager_ = @@ -621,7 +621,7 @@ class Controller::TerminationListener final msgs::Type message_type, const uint8_t* buffer, size_t buffer_size, - platform::Clock::time_point now) override; + Clock::time_point now) override; private: Controller* const controller_; @@ -650,7 +650,7 @@ ErrorOr<size_t> Controller::TerminationListener::OnStreamMessage( msgs::Type message_type, const uint8_t* buffer, size_t buffer_size, - platform::Clock::time_point now) { + Clock::time_point now) { OSP_CHECK_EQ(static_cast<int>(msgs::Type::kPresentationTerminationEvent), static_cast<int>(message_type)); msgs::PresentationTerminationEvent event; diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller_unittest.cc index 5b32c12f458..bad4c66616a 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller_unittest.cc @@ -77,12 +77,11 @@ class MockRequestDelegate final : public RequestDelegate { class ControllerTest : public ::testing::Test { public: ControllerTest() { - fake_clock_ = std::make_unique<platform::FakeClock>( - platform::Clock::time_point(seconds(11111))); - task_runner_ = - std::make_unique<platform::FakeTaskRunner>(fake_clock_.get()); - quic_bridge_ = std::make_unique<FakeQuicBridge>(task_runner_.get(), - platform::FakeClock::now); + fake_clock_ = + std::make_unique<FakeClock>(Clock::time_point(seconds(11111))); + task_runner_ = std::make_unique<FakeTaskRunner>(fake_clock_.get()); + quic_bridge_ = + std::make_unique<FakeQuicBridge>(task_runner_.get(), FakeClock::now); receiver_info1 = { "service-id1", "lucas-auer", 1, quic_bridge_->kReceiverEndpoint, {}}; } @@ -94,7 +93,7 @@ class ControllerTest : public ::testing::Test { NetworkServiceManager::Create(std::move(service_listener), nullptr, std::move(quic_bridge_->quic_client), std::move(quic_bridge_->quic_server)); - controller_ = std::make_unique<Controller>(platform::FakeClock::now); + controller_ = std::make_unique<Controller>(FakeClock::now); ON_CALL(quic_bridge_->mock_server_observer, OnIncomingConnectionMock(_)) .WillByDefault( Invoke([this](std::unique_ptr<ProtocolConnection>& connection) { @@ -117,16 +116,15 @@ class ControllerTest : public ::testing::Test { ssize_t decode_result = -1; msgs::Type msg_type; EXPECT_CALL(mock_callback_, OnStreamMessage(_, _, _, _, _, _)) - .WillOnce( - Invoke([request, &msg_type, &decode_result]( - uint64_t endpoint_id, uint64_t cid, - msgs::Type message_type, const uint8_t* buffer, - size_t buffer_size, platform::Clock::time_point now) { - msg_type = message_type; - decode_result = msgs::DecodePresentationUrlAvailabilityRequest( - buffer, buffer_size, request); - return decode_result; - })); + .WillOnce(Invoke([request, &msg_type, &decode_result]( + uint64_t endpoint_id, uint64_t cid, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + msg_type = message_type; + decode_result = msgs::DecodePresentationUrlAvailabilityRequest( + buffer, buffer_size, request); + return decode_result; + })); quic_bridge_->RunTasksUntilIdle(); ASSERT_EQ(msg_type, msgs::Type::kPresentationUrlAvailabilityRequest); ASSERT_GT(decode_result, 0); @@ -206,16 +204,15 @@ class ControllerTest : public ::testing::Test { ssize_t decode_result = -1; msgs::Type msg_type; EXPECT_CALL(*mock_callback, OnStreamMessage(_, _, _, _, _, _)) - .WillOnce( - Invoke([request, &msg_type, &decode_result]( - uint64_t endpoint_id, uint64_t cid, - msgs::Type message_type, const uint8_t* buffer, - size_t buffer_size, platform::Clock::time_point now) { - msg_type = message_type; - decode_result = msgs::DecodePresentationConnectionCloseRequest( - buffer, buffer_size, request); - return decode_result; - })); + .WillOnce(Invoke([request, &msg_type, &decode_result]( + uint64_t endpoint_id, uint64_t cid, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + msg_type = message_type; + decode_result = msgs::DecodePresentationConnectionCloseRequest( + buffer, buffer_size, request); + return decode_result; + })); connection->Close(Connection::CloseReason::kClosed); EXPECT_EQ(connection->state(), Connection::State::kClosed); quic_bridge_->RunTasksUntilIdle(); @@ -264,16 +261,15 @@ class ControllerTest : public ::testing::Test { msgs::PresentationStartRequest request; msgs::Type msg_type; EXPECT_CALL(*mock_callback, OnStreamMessage(_, _, _, _, _, _)) - .WillOnce( - Invoke([&request, &msg_type]( - uint64_t endpoint_id, uint64_t cid, - msgs::Type message_type, const uint8_t* buffer, - size_t buffer_size, platform::Clock::time_point now) { - msg_type = message_type; - ssize_t result = msgs::DecodePresentationStartRequest( - buffer, buffer_size, &request); - return result; - })); + .WillOnce(Invoke([&request, &msg_type]( + uint64_t endpoint_id, uint64_t cid, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + msg_type = message_type; + ssize_t result = msgs::DecodePresentationStartRequest( + buffer, buffer_size, &request); + return result; + })); Controller::ConnectRequest connect_request = controller_->StartPresentation( "https://example.com/receiver.html", receiver_info1.service_id, &mock_request_delegate, mock_connection_delegate); @@ -297,8 +293,8 @@ class ControllerTest : public ::testing::Test { ASSERT_TRUE(*connection); } - std::unique_ptr<platform::FakeClock> fake_clock_; - std::unique_ptr<platform::FakeTaskRunner> task_runner_; + std::unique_ptr<FakeClock> fake_clock_; + std::unique_ptr<FakeTaskRunner> task_runner_; MessageDemuxer::MessageWatch availability_watch_; MockMessageCallback mock_callback_; std::unique_ptr<FakeQuicBridge> quic_bridge_; @@ -413,16 +409,15 @@ TEST_F(ControllerTest, TerminatePresentationFromController) { msgs::PresentationTerminationRequest termination_request; msgs::Type msg_type; EXPECT_CALL(mock_callback, OnStreamMessage(_, _, _, _, _, _)) - .WillOnce( - Invoke([&termination_request, &msg_type]( - uint64_t endpoint_id, uint64_t cid, - msgs::Type message_type, const uint8_t* buffer, - size_t buffer_size, platform::Clock::time_point now) { - msg_type = message_type; - ssize_t result = msgs::DecodePresentationTerminationRequest( - buffer, buffer_size, &termination_request); - return result; - })); + .WillOnce(Invoke([&termination_request, &msg_type]( + uint64_t endpoint_id, uint64_t cid, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + msg_type = message_type; + ssize_t result = msgs::DecodePresentationTerminationRequest( + buffer, buffer_size, &termination_request); + return result; + })); connection->Terminate(TerminationReason::kControllerTerminateCalled); quic_bridge_->RunTasksUntilIdle(); @@ -505,16 +500,15 @@ TEST_F(ControllerTest, Reconnect) { ssize_t decode_result = -1; msgs::Type msg_type; EXPECT_CALL(mock_callback, OnStreamMessage(_, _, _, _, _, _)) - .WillOnce( - Invoke([&open_request, &msg_type, &decode_result]( - uint64_t endpoint_id, uint64_t cid, - msgs::Type message_type, const uint8_t* buffer, - size_t buffer_size, platform::Clock::time_point now) { - msg_type = message_type; - decode_result = msgs::DecodePresentationConnectionOpenRequest( - buffer, buffer_size, &open_request); - return decode_result; - })); + .WillOnce(Invoke([&open_request, &msg_type, &decode_result]( + uint64_t endpoint_id, uint64_t cid, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + msg_type = message_type; + decode_result = msgs::DecodePresentationConnectionOpenRequest( + buffer, buffer_size, &open_request); + return decode_result; + })); quic_bridge_->RunTasksUntilIdle(); ASSERT_FALSE(connection); diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver.cc index 17babae038b..eb60e0bb085 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver.cc +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver.cc @@ -16,8 +16,6 @@ #include "util/logging.h" #include "util/trace_logging.h" -using openscreen::platform::TraceCategory; - namespace openscreen { namespace osp { namespace { @@ -107,11 +105,11 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id, msgs::Type message_type, const uint8_t* buffer, size_t buffer_size, - platform::Clock::time_point now) { - TRACE_SCOPED(TraceCategory::Presentation, "Receiver::OnStreamMessage"); + Clock::time_point now) { + TRACE_SCOPED(TraceCategory::kPresentation, "Receiver::OnStreamMessage"); switch (message_type) { case msgs::Type::kPresentationUrlAvailabilityRequest: { - TRACE_SCOPED(TraceCategory::Presentation, + TRACE_SCOPED(TraceCategory::kPresentation, "kPresentationUrlAvailabilityRequest"); OSP_VLOG << "got presentation-url-availability-request"; msgs::PresentationUrlAvailabilityRequest request; @@ -137,7 +135,7 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id, } case msgs::Type::kPresentationStartRequest: { - TRACE_SCOPED(TraceCategory::Presentation, "kPresentationStartRequest"); + TRACE_SCOPED(TraceCategory::kPresentation, "kPresentationStartRequest"); OSP_VLOG << "got presentation-start-request"; msgs::PresentationStartRequest request; const ssize_t result = @@ -198,7 +196,7 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id, } case msgs::Type::kPresentationConnectionOpenRequest: { - TRACE_SCOPED(TraceCategory::Presentation, + TRACE_SCOPED(TraceCategory::kPresentation, "kPresentationConnectionOpenRequest"); OSP_VLOG << "Got a presentation-connection-open-request"; msgs::PresentationConnectionOpenRequest request; @@ -266,7 +264,7 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id, } case msgs::Type::kPresentationTerminationRequest: { - TRACE_SCOPED(TraceCategory::Presentation, + TRACE_SCOPED(TraceCategory::kPresentation, "kPresentationTerminationRequest"); OSP_VLOG << "got presentation-termination-request"; msgs::PresentationTerminationRequest request; diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver_unittest.cc index 098df0d3436..173b9fd561a 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver_unittest.cc @@ -69,12 +69,11 @@ class MockReceiverDelegate final : public ReceiverDelegate { class PresentationReceiverTest : public ::testing::Test { public: PresentationReceiverTest() { - fake_clock_ = std::make_unique<platform::FakeClock>( - platform::Clock::time_point(std::chrono::milliseconds(1298424))); - task_runner_ = - std::make_unique<platform::FakeTaskRunner>(fake_clock_.get()); - quic_bridge_ = std::make_unique<FakeQuicBridge>(task_runner_.get(), - platform::FakeClock::now); + fake_clock_ = std::make_unique<FakeClock>( + Clock::time_point(std::chrono::milliseconds(1298424))); + task_runner_ = std::make_unique<FakeTaskRunner>(fake_clock_.get()); + quic_bridge_ = + std::make_unique<FakeQuicBridge>(task_runner_.get(), FakeClock::now); } protected: @@ -103,8 +102,8 @@ class PresentationReceiverTest : public ::testing::Test { NetworkServiceManager::Dispose(); } - std::unique_ptr<platform::FakeClock> fake_clock_; - std::unique_ptr<platform::FakeTaskRunner> task_runner_; + std::unique_ptr<FakeClock> fake_clock_; + std::unique_ptr<FakeTaskRunner> task_runner_; const std::string url1_{"https://www.example.com/receiver.html"}; std::unique_ptr<FakeQuicBridge> quic_bridge_; MockReceiverDelegate mock_receiver_delegate_; @@ -142,14 +141,14 @@ TEST_F(PresentationReceiverTest, QueryAvailability) { msgs::PresentationUrlAvailabilityResponse response; EXPECT_CALL(mock_callback, OnStreamMessage(_, _, _, _, _, _)) - .WillOnce(Invoke([&response](uint64_t endpoint_id, uint64_t cid, - msgs::Type message_type, - const uint8_t* buffer, size_t buffer_size, - platform::Clock::time_point now) { - ssize_t result = msgs::DecodePresentationUrlAvailabilityResponse( - buffer, buffer_size, &response); - return result; - })); + .WillOnce( + Invoke([&response](uint64_t endpoint_id, uint64_t cid, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + ssize_t result = msgs::DecodePresentationUrlAvailabilityResponse( + buffer, buffer_size, &response); + return result; + })); quic_bridge_->RunTasksUntilIdle(); EXPECT_EQ(request.request_id, response.request_id); EXPECT_EQ( @@ -190,14 +189,14 @@ TEST_F(PresentationReceiverTest, StartPresentation) { ResponseResult::kSuccess); msgs::PresentationStartResponse response; EXPECT_CALL(mock_callback, OnStreamMessage(_, _, _, _, _, _)) - .WillOnce(Invoke([&response](uint64_t endpoint_id, uint64_t cid, - msgs::Type message_type, - const uint8_t* buffer, size_t buffer_size, - platform::Clock::time_point now) { - ssize_t result = msgs::DecodePresentationStartResponse( - buffer, buffer_size, &response); - return result; - })); + .WillOnce( + Invoke([&response](uint64_t endpoint_id, uint64_t cid, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + ssize_t result = msgs::DecodePresentationStartResponse( + buffer, buffer_size, &response); + return result; + })); quic_bridge_->RunTasksUntilIdle(); EXPECT_EQ(msgs::Result::kSuccess, response.result); EXPECT_EQ(connection.connection_id(), response.connection_id); diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.cc index 9213447c26e..b5e3bae7730 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.cc +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.cc @@ -12,7 +12,6 @@ #include "osp/public/network_service_manager.h" #include "util/logging.h" -using openscreen::platform::Clock; using std::chrono::seconds; namespace openscreen { @@ -48,7 +47,7 @@ uint64_t GetNextRequestId(const uint64_t endpoint_id) { } // namespace UrlAvailabilityRequester::UrlAvailabilityRequester( - platform::ClockNowFunctionPtr now_function) + ClockNowFunctionPtr now_function) : now_function_(now_function) { OSP_DCHECK(now_function_); } diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.h b/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.h index 82f5bb6bca6..dd0309a5970 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.h +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.h @@ -29,7 +29,7 @@ namespace osp { // given URL. class UrlAvailabilityRequester { public: - explicit UrlAvailabilityRequester(platform::ClockNowFunctionPtr now_function); + explicit UrlAvailabilityRequester(ClockNowFunctionPtr now_function); ~UrlAvailabilityRequester(); // Adds a persistent availability request for |urls| to all known receivers. @@ -63,7 +63,7 @@ class UrlAvailabilityRequester { // Ensures that all open availability watches (to all receivers) that are // about to expire are refreshed by sending a new request with the same URLs. // Returns the time point at which this should next be scheduled to run. - platform::Clock::time_point RefreshWatches(); + Clock::time_point RefreshWatches(); private: // Handles Presentation API URL availability requests and watches for one @@ -82,7 +82,7 @@ class UrlAvailabilityRequester { }; struct Watch { - platform::Clock::time_point deadline; + Clock::time_point deadline; std::vector<std::string> urls; }; @@ -97,7 +97,7 @@ class UrlAvailabilityRequester { void RequestUrlAvailabilities(std::vector<std::string> urls); ErrorOr<uint64_t> SendRequest(uint64_t request_id, const std::vector<std::string>& urls); - platform::Clock::time_point RefreshWatches(platform::Clock::time_point now); + Clock::time_point RefreshWatches(Clock::time_point now); Error::Code UpdateAvailabilities( const std::vector<std::string>& urls, const std::vector<msgs::UrlAvailability>& availabilities); @@ -117,7 +117,7 @@ class UrlAvailabilityRequester { msgs::Type message_type, const uint8_t* buffer, size_t buffer_size, - platform::Clock::time_point now) override; + Clock::time_point now) override; UrlAvailabilityRequester* const listener; @@ -138,7 +138,7 @@ class UrlAvailabilityRequester { std::map<std::string, msgs::UrlAvailability> known_availability_by_url; }; - const platform::ClockNowFunctionPtr now_function_; + const ClockNowFunctionPtr now_function_; std::map<std::string, std::vector<ReceiverObserver*>> observers_by_url_; diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester_unittest.cc index a0e19639b84..578c861220a 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester_unittest.cc @@ -46,12 +46,11 @@ class MockReceiverObserver : public ReceiverObserver { class UrlAvailabilityRequesterTest : public Test { public: UrlAvailabilityRequesterTest() { - fake_clock_ = std::make_unique<platform::FakeClock>( - platform::Clock::time_point(milliseconds(1298424))); - task_runner_ = - std::make_unique<platform::FakeTaskRunner>(fake_clock_.get()); - quic_bridge_ = std::make_unique<FakeQuicBridge>(task_runner_.get(), - platform::FakeClock::now); + fake_clock_ = + std::make_unique<FakeClock>(Clock::time_point(milliseconds(1298424))); + task_runner_ = std::make_unique<FakeTaskRunner>(fake_clock_.get()); + quic_bridge_ = + std::make_unique<FakeQuicBridge>(task_runner_.get(), FakeClock::now); info1_ = {service_id_, friendly_name_, 1, quic_bridge_->kReceiverEndpoint}; } @@ -86,16 +85,16 @@ class UrlAvailabilityRequesterTest : public Test { void ExpectStreamMessage(MockMessageCallback* mock_callback, msgs::PresentationUrlAvailabilityRequest* request) { EXPECT_CALL(*mock_callback, OnStreamMessage(_, _, _, _, _, _)) - .WillOnce(Invoke([request](uint64_t endpoint_id, uint64_t cid, - msgs::Type message_type, - const uint8_t* buffer, size_t buffer_size, - platform::Clock::time_point now) { - ssize_t request_result_size = - msgs::DecodePresentationUrlAvailabilityRequest( - buffer, buffer_size, request); - OSP_DCHECK_GT(request_result_size, 0); - return request_result_size; - })); + .WillOnce( + Invoke([request](uint64_t endpoint_id, uint64_t cid, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + ssize_t request_result_size = + msgs::DecodePresentationUrlAvailabilityRequest( + buffer, buffer_size, request); + OSP_DCHECK_GT(request_result_size, 0); + return request_result_size; + })); } void SendAvailabilityResponse( @@ -126,12 +125,12 @@ class UrlAvailabilityRequesterTest : public Test { stream->Write(buffer.data(), buffer.size()); } - std::unique_ptr<platform::FakeClock> fake_clock_; - std::unique_ptr<platform::FakeTaskRunner> task_runner_; + std::unique_ptr<FakeClock> fake_clock_; + std::unique_ptr<FakeTaskRunner> task_runner_; MockMessageCallback mock_callback_; MessageDemuxer::MessageWatch availability_watch_; std::unique_ptr<FakeQuicBridge> quic_bridge_; - UrlAvailabilityRequester listener_{platform::FakeClock::now}; + UrlAvailabilityRequester listener_{FakeClock::now}; std::string url1_{"https://example.com/foo.html"}; std::string url2_{"https://example.com/bar.html"}; diff --git a/chromium/third_party/openscreen/src/osp/impl/protocol_connection_client_factory.cc b/chromium/third_party/openscreen/src/osp/impl/protocol_connection_client_factory.cc index 90b14590396..c4a72b7970b 100644 --- a/chromium/third_party/openscreen/src/osp/impl/protocol_connection_client_factory.cc +++ b/chromium/third_party/openscreen/src/osp/impl/protocol_connection_client_factory.cc @@ -20,10 +20,10 @@ std::unique_ptr<ProtocolConnectionClient> ProtocolConnectionClientFactory::Create( MessageDemuxer* demuxer, ProtocolConnectionServiceObserver* observer, - platform::TaskRunner* task_runner) { + TaskRunner* task_runner) { return std::make_unique<QuicClient>( demuxer, std::make_unique<QuicConnectionFactoryImpl>(task_runner), - observer, &platform::Clock::now, task_runner); + observer, &Clock::now, task_runner); } } // namespace osp diff --git a/chromium/third_party/openscreen/src/osp/impl/protocol_connection_server_factory.cc b/chromium/third_party/openscreen/src/osp/impl/protocol_connection_server_factory.cc index e0a3b03455d..2984429c5b8 100644 --- a/chromium/third_party/openscreen/src/osp/impl/protocol_connection_server_factory.cc +++ b/chromium/third_party/openscreen/src/osp/impl/protocol_connection_server_factory.cc @@ -21,10 +21,10 @@ ProtocolConnectionServerFactory::Create( const ServerConfig& config, MessageDemuxer* demuxer, ProtocolConnectionServer::Observer* observer, - platform::TaskRunner* task_runner) { + TaskRunner* task_runner) { return std::make_unique<QuicServer>( config, demuxer, std::make_unique<QuicConnectionFactoryImpl>(task_runner), - observer, &platform::Clock::now, task_runner); + observer, &Clock::now, task_runner); } } // namespace osp diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/BUILD.gn b/chromium/third_party/openscreen/src/osp/impl/quic/BUILD.gn index 9b0c80288ec..221af394bdb 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/BUILD.gn +++ b/chromium/third_party/openscreen/src/osp/impl/quic/BUILD.gn @@ -16,6 +16,7 @@ source_set("quic") { deps = [ "../../../platform", + "../../../third_party/abseil", "../../../util", "../../public", ] @@ -34,6 +35,7 @@ source_set("test_support") { deps = [ "../../../platform", + "../../../third_party/abseil", "../../../third_party/googletest:gmock", "../../../util", "../../msgs", diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.cc index 41303ad7629..7d0a310301e 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.cc @@ -19,8 +19,8 @@ QuicClient::QuicClient( MessageDemuxer* demuxer, std::unique_ptr<QuicConnectionFactory> connection_factory, ProtocolConnectionServiceObserver* observer, - platform::ClockNowFunctionPtr now_function, - platform::TaskRunner* task_runner) + ClockNowFunctionPtr now_function, + TaskRunner* task_runner) : ProtocolConnectionClient(demuxer, observer), connection_factory_(std::move(connection_factory)), cleanup_alarm_(now_function, task_runner) {} @@ -63,8 +63,7 @@ void QuicClient::Cleanup() { } delete_connections_.clear(); - constexpr platform::Clock::duration kQuicCleanupPeriod = - std::chrono::milliseconds(500); + constexpr Clock::duration kQuicCleanupPeriod = std::chrono::milliseconds(500); if (state_ != State::kStopped) { cleanup_alarm_.ScheduleFromNow([this] { Cleanup(); }, kQuicCleanupPeriod); } diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.h index 50fb5503971..548cbff51f5 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.h @@ -43,8 +43,8 @@ class QuicClient final : public ProtocolConnectionClient, QuicClient(MessageDemuxer* demuxer, std::unique_ptr<QuicConnectionFactory> connection_factory, ProtocolConnectionServiceObserver* observer, - platform::ClockNowFunctionPtr now_function, - platform::TaskRunner* task_runner); + ClockNowFunctionPtr now_function, + TaskRunner* task_runner); ~QuicClient() override; // ProtocolConnectionClient overrides. @@ -103,7 +103,7 @@ class QuicClient final : public ProtocolConnectionClient, // Maps an IPEndpoint to a generated endpoint ID. This is used to insulate // callers from post-handshake changes to a connections actual peer endpoint. - std::map<IPEndpoint, uint64_t, IPEndpointComparator> endpoint_map_; + std::map<IPEndpoint, uint64_t> endpoint_map_; // Value that will be used for the next new endpoint in a Connect call. uint64_t next_endpoint_id_ = 0; @@ -119,8 +119,7 @@ class QuicClient final : public ProtocolConnectionClient, // Maps endpoint addresses to data about connections that haven't successfully // completed the QUIC handshake. - std::map<IPEndpoint, PendingConnectionData, IPEndpointComparator> - pending_connections_; + std::map<IPEndpoint, PendingConnectionData> pending_connections_; // Maps endpoint IDs to data about connections that have successfully // completed the QUIC handshake. diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client_unittest.cc index 954fa11ce5d..71a74ab4165 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client_unittest.cc @@ -60,12 +60,11 @@ class ConnectionCallback final class QuicClientTest : public ::testing::Test { public: QuicClientTest() { - fake_clock_ = std::make_unique<platform::FakeClock>( - platform::Clock::time_point(std::chrono::milliseconds(1298424))); - task_runner_ = - std::make_unique<platform::FakeTaskRunner>(fake_clock_.get()); - quic_bridge_ = std::make_unique<FakeQuicBridge>(task_runner_.get(), - platform::FakeClock::now); + fake_clock_ = std::make_unique<FakeClock>( + Clock::time_point(std::chrono::milliseconds(1298424))); + task_runner_ = std::make_unique<FakeTaskRunner>(fake_clock_.get()); + quic_bridge_ = + std::make_unique<FakeQuicBridge>(task_runner_.get(), FakeClock::now); } protected: @@ -100,17 +99,16 @@ class QuicClientTest : public ::testing::Test { mock_message_callback, OnStreamMessage(0, connection->id(), msgs::Type::kPresentationConnectionMessage, _, _, _)) - .WillOnce( - Invoke([&decode_result, &received_message]( - uint64_t endpoint_id, uint64_t connection_id, - msgs::Type message_type, const uint8_t* buffer, - size_t buffer_size, platform::Clock::time_point now) { - decode_result = msgs::DecodePresentationConnectionMessage( - buffer, buffer_size, &received_message); - if (decode_result < 0) - return ErrorOr<size_t>(Error::Code::kCborParsing); - return ErrorOr<size_t>(decode_result); - })); + .WillOnce(Invoke([&decode_result, &received_message]( + uint64_t endpoint_id, uint64_t connection_id, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + decode_result = msgs::DecodePresentationConnectionMessage( + buffer, buffer_size, &received_message); + if (decode_result < 0) + return ErrorOr<size_t>(Error::Code::kCborParsing); + return ErrorOr<size_t>(decode_result); + })); quic_bridge_->RunTasksUntilIdle(); ASSERT_GT(decode_result, 0); @@ -121,8 +119,8 @@ class QuicClientTest : public ::testing::Test { EXPECT_EQ(received_message.message.str, message.message.str); } - std::unique_ptr<platform::FakeClock> fake_clock_; - std::unique_ptr<platform::FakeTaskRunner> task_runner_; + std::unique_ptr<FakeClock> fake_clock_; + std::unique_ptr<FakeTaskRunner> task_runner_; std::unique_ptr<FakeQuicBridge> quic_bridge_; QuicClient* client_; }; diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection.h index 2f1cb9360cf..e00e25a044d 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection.h @@ -37,7 +37,7 @@ class QuicStream { uint64_t id_; }; -class QuicConnection : public platform::UdpSocket::Client { +class QuicConnection : public UdpSocket::Client { public: class Delegate { public: diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory.h index 2e3e10b38d0..f396419d4d6 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory.h @@ -18,7 +18,7 @@ namespace osp { // This interface provides a way to make new QUIC connections to endpoints. It // also provides a way to receive incoming QUIC connections (as a server). -class QuicConnectionFactory : public platform::UdpSocket::Client { +class QuicConnectionFactory : public UdpSocket::Client { public: class ServerDelegate { public: diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.cc index 0ecc17e2aea..6514b79119c 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.cc @@ -18,13 +18,12 @@ #include "util/logging.h" #include "util/trace_logging.h" -using openscreen::platform::TraceCategory; - namespace openscreen { namespace osp { + class QuicTaskRunner final : public ::base::TaskRunner { public: - explicit QuicTaskRunner(platform::TaskRunner* task_runner); + explicit QuicTaskRunner(openscreen::TaskRunner* task_runner); ~QuicTaskRunner() override; void RunTasks(); @@ -37,11 +36,12 @@ class QuicTaskRunner final : public ::base::TaskRunner { bool RunsTasksInCurrentSequence() const override; private: - platform::TaskRunner* const task_runner_; + openscreen::TaskRunner* const task_runner_; }; -QuicTaskRunner::QuicTaskRunner(platform::TaskRunner* task_runner) +QuicTaskRunner::QuicTaskRunner(openscreen::TaskRunner* task_runner) : task_runner_(task_runner) {} + QuicTaskRunner::~QuicTaskRunner() = default; void QuicTaskRunner::RunTasks() {} @@ -49,8 +49,7 @@ void QuicTaskRunner::RunTasks() {} bool QuicTaskRunner::PostDelayedTask(const ::base::Location& whence, ::base::OnceClosure task, ::base::TimeDelta delay) { - platform::Clock::duration wait = - platform::Clock::duration(delay.InMilliseconds()); + Clock::duration wait = Clock::duration(delay.InMilliseconds()); task_runner_->PostTaskWithDelay( [closure = std::move(task)]() mutable { std::move(closure).Run(); }, wait); @@ -61,8 +60,7 @@ bool QuicTaskRunner::RunsTasksInCurrentSequence() const { return true; } -QuicConnectionFactoryImpl::QuicConnectionFactoryImpl( - platform::TaskRunner* task_runner) +QuicConnectionFactoryImpl::QuicConnectionFactoryImpl(TaskRunner* task_runner) : task_runner_(task_runner) { quic_task_runner_ = ::base::MakeRefCounted<QuicTaskRunner>(task_runner); alarm_factory_ = std::make_unique<::net::QuicChromiumAlarmFactory>( @@ -90,34 +88,31 @@ void QuicConnectionFactoryImpl::SetServerDelegate( // create/bind errors occur. Maybe return an Error immediately, and undo // partial progress (i.e. "unwatch" all the sockets and call // sockets_.clear() to close the sockets)? - auto create_result = - platform::UdpSocket::Create(task_runner_, this, endpoint); + auto create_result = UdpSocket::Create(task_runner_, this, endpoint); if (!create_result) { OSP_LOG_ERROR << "failed to create socket (for " << endpoint << "): " << create_result.error().message(); continue; } - platform::UdpSocketUniquePtr server_socket = - std::move(create_result.value()); + std::unique_ptr<UdpSocket> server_socket = std::move(create_result.value()); server_socket->Bind(); sockets_.emplace_back(std::move(server_socket)); } } -void QuicConnectionFactoryImpl::OnRead( - platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> packet_or_error) { - TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionFactoryImpl::OnRead"); +void QuicConnectionFactoryImpl::OnRead(UdpSocket* socket, + ErrorOr<UdpPacket> packet_or_error) { + TRACE_SCOPED(TraceCategory::kQuic, "QuicConnectionFactoryImpl::OnRead"); if (packet_or_error.is_error()) { return; } - platform::UdpPacket packet = std::move(packet_or_error.value()); + UdpPacket packet = std::move(packet_or_error.value()); // Ensure that |packet.socket| is one of the instances owned by // QuicConnectionFactoryImpl. auto packet_ptr = &packet; OSP_DCHECK(std::find_if(sockets_.begin(), sockets_.end(), - [packet_ptr](const platform::UdpSocketUniquePtr& s) { + [packet_ptr](const std::unique_ptr<UdpSocket>& s) { return s.get() == packet_ptr->socket(); }) != sockets_.end()); @@ -153,15 +148,14 @@ void QuicConnectionFactoryImpl::OnRead( std::unique_ptr<QuicConnection> QuicConnectionFactoryImpl::Connect( const IPEndpoint& endpoint, QuicConnection::Delegate* connection_delegate) { - auto create_result = - platform::UdpSocket::Create(task_runner_, this, endpoint); + auto create_result = UdpSocket::Create(task_runner_, this, endpoint); if (!create_result) { OSP_LOG_ERROR << "failed to create socket: " << create_result.error().message(); // TODO(mfoltz): This method should return ErrorOr<uni_ptr<QuicConnection>>. return nullptr; } - platform::UdpSocketUniquePtr socket = std::move(create_result.value()); + std::unique_ptr<UdpSocket> socket = std::move(create_result.value()); auto transport = std::make_unique<UdpTransport>(socket.get(), endpoint); ::quic::QuartcSessionConfig session_config; @@ -192,7 +186,7 @@ void QuicConnectionFactoryImpl::OnConnectionClosed(QuicConnection* connection) { return entry.second.connection == connection; }); OSP_DCHECK(entry != connections_.end()); - platform::UdpSocket* const socket = entry->second.socket; + UdpSocket* const socket = entry->second.socket; connections_.erase(entry); // If none of the remaining |connections_| reference the socket, close/destroy @@ -203,7 +197,7 @@ void QuicConnectionFactoryImpl::OnConnectionClosed(QuicConnection* connection) { }) == connections_.end()) { auto socket_it = std::find_if(sockets_.begin(), sockets_.end(), - [socket](const platform::UdpSocketUniquePtr& s) { + [socket](const std::unique_ptr<UdpSocket>& s) { return s.get() == socket; }); OSP_DCHECK(socket_it != sockets_.end()); @@ -211,13 +205,11 @@ void QuicConnectionFactoryImpl::OnConnectionClosed(QuicConnection* connection) { } } -void QuicConnectionFactoryImpl::OnError(platform::UdpSocket* socket, - Error error) { +void QuicConnectionFactoryImpl::OnError(UdpSocket* socket, Error error) { OSP_LOG_ERROR << "failed to configure socket " << error.message(); } -void QuicConnectionFactoryImpl::OnSendError(platform::UdpSocket* socket, - Error error) { +void QuicConnectionFactoryImpl::OnSendError(UdpSocket* socket, Error error) { // TODO(crbug.com/openscreen/67): Implement this method. OSP_UNIMPLEMENTED(); } diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.h index 8e29579d3b6..e3588f6a2f0 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.h @@ -22,14 +22,13 @@ class QuicTaskRunner; class QuicConnectionFactoryImpl final : public QuicConnectionFactory { public: - QuicConnectionFactoryImpl(platform::TaskRunner* task_runner); + QuicConnectionFactoryImpl(TaskRunner* task_runner); ~QuicConnectionFactoryImpl() override; // UdpSocket::Client overrides. - void OnError(platform::UdpSocket* socket, Error error) override; - void OnSendError(platform::UdpSocket* socket, Error error) override; - void OnRead(platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> packet) override; + void OnError(UdpSocket* socket, Error error) override; + void OnSendError(UdpSocket* socket, Error error) override; + void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override; // QuicConnectionFactory overrides. void SetServerDelegate(ServerDelegate* delegate, @@ -48,18 +47,18 @@ class QuicConnectionFactoryImpl final : public QuicConnectionFactory { ServerDelegate* server_delegate_ = nullptr; - std::vector<platform::UdpSocketUniquePtr> sockets_; + std::vector<std::unique_ptr<UdpSocket>> sockets_; struct OpenConnection { QuicConnection* connection; - platform::UdpSocket* socket; // References one of the owned |sockets_|. + UdpSocket* socket; // References one of the owned |sockets_|. }; - std::map<IPEndpoint, OpenConnection, IPEndpointComparator> connections_; + std::map<IPEndpoint, OpenConnection> connections_; // NOTE: Must be provided in constructor and stored as an instance variable // rather than using the static accessor method to allow for UTs to mock this // layer. - platform::TaskRunner* const task_runner_; + TaskRunner* const task_runner_; }; } // namespace osp diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.cc index 8ee43c332b8..778d6b4de82 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.cc @@ -14,13 +14,10 @@ #include "util/logging.h" #include "util/trace_logging.h" -using openscreen::platform::TraceCategory; - namespace openscreen { namespace osp { -UdpTransport::UdpTransport(platform::UdpSocket* socket, - const IPEndpoint& destination) +UdpTransport::UdpTransport(UdpSocket* socket, const IPEndpoint& destination) : socket_(socket), destination_(destination) { OSP_DCHECK(socket_); } @@ -33,7 +30,7 @@ UdpTransport& UdpTransport::operator=(UdpTransport&&) noexcept = default; int UdpTransport::Write(const char* buffer, size_t buffer_length, const PacketInfo& info) { - TRACE_SCOPED(TraceCategory::Quic, "UdpTransport::Write"); + TRACE_SCOPED(TraceCategory::kQuic, "UdpTransport::Write"); socket_->SendMessage(buffer, buffer_length, destination_); OSP_DCHECK_LE(buffer_length, static_cast<size_t>(std::numeric_limits<int>::max())); @@ -49,7 +46,7 @@ QuicStreamImpl::QuicStreamImpl(QuicStream::Delegate* delegate, QuicStreamImpl::~QuicStreamImpl() = default; void QuicStreamImpl::Write(const uint8_t* data, size_t data_size) { - TRACE_SCOPED(TraceCategory::Quic, "QuicStreamImpl::Write"); + TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::Write"); OSP_DCHECK(!stream_->write_side_closed()); stream_->WriteOrBufferData( ::quic::QuicStringPiece(reinterpret_cast<const char*>(data), data_size), @@ -57,7 +54,7 @@ void QuicStreamImpl::Write(const uint8_t* data, size_t data_size) { } void QuicStreamImpl::CloseWriteEnd() { - TRACE_SCOPED(TraceCategory::Quic, "QuicStreamImpl::CloseWriteEnd"); + TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::CloseWriteEnd"); if (!stream_->write_side_closed()) stream_->FinishWriting(); } @@ -65,12 +62,12 @@ void QuicStreamImpl::CloseWriteEnd() { void QuicStreamImpl::OnReceived(::quic::QuartcStream* stream, const char* data, size_t data_size) { - TRACE_SCOPED(TraceCategory::Quic, "QuicStreamImpl::OnReceived"); + TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::OnReceived"); delegate_->OnReceived(this, data, data_size); } void QuicStreamImpl::OnClose(::quic::QuartcStream* stream) { - TRACE_SCOPED(TraceCategory::Quic, "QuicStreamImpl::OnClose"); + TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::OnClose"); delegate_->OnClose(stream->id()); } @@ -79,9 +76,8 @@ void QuicStreamImpl::OnBufferChanged(::quic::QuartcStream* stream) {} // Passes a received UDP packet to the QUIC implementation. If this contains // any stream data, it will be passed automatically to the relevant // QuicStream::Delegate objects. -void QuicConnectionImpl::OnRead(platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> data) { - TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::OnRead"); +void QuicConnectionImpl::OnRead(UdpSocket* socket, ErrorOr<UdpPacket> data) { + TRACE_SCOPED(TraceCategory::kQuic, "QuicConnectionImpl::OnRead"); if (data.is_error()) { TRACE_SET_RESULT(data.error()); return; @@ -91,12 +87,12 @@ void QuicConnectionImpl::OnRead(platform::UdpSocket* socket, reinterpret_cast<const char*>(data.value().data()), data.value().size()); } -void QuicConnectionImpl::OnSendError(platform::UdpSocket* socket, Error error) { +void QuicConnectionImpl::OnSendError(UdpSocket* socket, Error error) { // TODO(crbug.com/openscreen/67): Implement this method. OSP_UNIMPLEMENTED(); } -void QuicConnectionImpl::OnError(platform::UdpSocket* socket, Error error) { +void QuicConnectionImpl::OnError(UdpSocket* socket, Error error) { // TODO(crbug.com/openscreen/67): Implement this method. OSP_UNIMPLEMENTED(); } @@ -110,7 +106,7 @@ QuicConnectionImpl::QuicConnectionImpl( parent_factory_(parent_factory), session_(std::move(session)), udp_transport_(std::move(udp_transport)) { - TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::QuicConnectionImpl"); + TRACE_SCOPED(TraceCategory::kQuic, "QuicConnectionImpl::QuicConnectionImpl"); session_->SetDelegate(this); session_->OnTransportCanWrite(); session_->StartCryptoHandshake(); @@ -120,24 +116,24 @@ QuicConnectionImpl::~QuicConnectionImpl() = default; std::unique_ptr<QuicStream> QuicConnectionImpl::MakeOutgoingStream( QuicStream::Delegate* delegate) { - TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::MakeOutgoingStream"); + TRACE_SCOPED(TraceCategory::kQuic, "QuicConnectionImpl::MakeOutgoingStream"); ::quic::QuartcStream* stream = session_->CreateOutgoingDynamicStream(); return std::make_unique<QuicStreamImpl>(delegate, stream); } void QuicConnectionImpl::Close() { - TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::Close"); + TRACE_SCOPED(TraceCategory::kQuic, "QuicConnectionImpl::Close"); session_->CloseConnection("closed"); } void QuicConnectionImpl::OnCryptoHandshakeComplete() { - TRACE_SCOPED(TraceCategory::Quic, + TRACE_SCOPED(TraceCategory::kQuic, "QuicConnectionImpl::OnCryptoHandshakeComplete"); delegate_->OnCryptoHandshakeComplete(session_->connection_id()); } void QuicConnectionImpl::OnIncomingStream(::quic::QuartcStream* stream) { - TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::OnIncomingStream"); + TRACE_SCOPED(TraceCategory::kQuic, "QuicConnectionImpl::OnIncomingStream"); auto public_stream = std::make_unique<QuicStreamImpl>( delegate_->NextStreamDelegate(session_->connection_id(), stream->id()), stream); @@ -150,7 +146,7 @@ void QuicConnectionImpl::OnConnectionClosed( ::quic::QuicErrorCode error_code, const ::quic::QuicString& error_details, ::quic::ConnectionCloseSource source) { - TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::OnConnectionClosed"); + TRACE_SCOPED(TraceCategory::kQuic, "QuicConnectionImpl::OnConnectionClosed"); parent_factory_->OnConnectionClosed(this); delegate_->OnConnectionClosed(session_->connection_id()); } diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.h index c7697ec117b..e2609deb694 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.h @@ -26,7 +26,7 @@ class QuicConnectionFactoryImpl; class UdpTransport final : public ::quic::QuartcPacketTransport { public: - UdpTransport(platform::UdpSocket* socket, const IPEndpoint& destination); + UdpTransport(UdpSocket* socket, const IPEndpoint& destination); UdpTransport(UdpTransport&&) noexcept; ~UdpTransport() override; @@ -37,10 +37,10 @@ class UdpTransport final : public ::quic::QuartcPacketTransport { size_t buffer_length, const PacketInfo& info) override; - platform::UdpSocket* socket() const { return socket_; } + UdpSocket* socket() const { return socket_; } private: - platform::UdpSocket* socket_; + UdpSocket* socket_; IPEndpoint destination_; }; @@ -76,10 +76,9 @@ class QuicConnectionImpl final : public QuicConnection, ~QuicConnectionImpl() override; // UdpSocket::Client overrides. - void OnRead(platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> data) override; - void OnError(platform::UdpSocket* socket, Error error) override; - void OnSendError(platform::UdpSocket* socket, Error error) override; + void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> data) override; + void OnError(UdpSocket* socket, Error error) override; + void OnSendError(UdpSocket* socket, Error error) override; // QuicConnection overrides. std::unique_ptr<QuicStream> MakeOutgoingStream( diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.cc index 17a90973f94..e1afc58c712 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.cc @@ -19,8 +19,8 @@ QuicServer::QuicServer( MessageDemuxer* demuxer, std::unique_ptr<QuicConnectionFactory> connection_factory, ProtocolConnectionServer::Observer* observer, - platform::ClockNowFunctionPtr now_function, - platform::TaskRunner* task_runner) + ClockNowFunctionPtr now_function, + TaskRunner* task_runner) : ProtocolConnectionServer(demuxer, observer), connection_endpoints_(config.connection_endpoints), connection_factory_(std::move(connection_factory)), @@ -80,8 +80,7 @@ void QuicServer::Cleanup() { } delete_connections_.clear(); - constexpr platform::Clock::duration kQuicCleanupPeriod = - std::chrono::milliseconds(500); + constexpr Clock::duration kQuicCleanupPeriod = std::chrono::milliseconds(500); if (state_ != State::kStopped) { cleanup_alarm_.ScheduleFromNow([this] { Cleanup(); }, kQuicCleanupPeriod); } diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.h index 48652d53b16..ad195635461 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.h @@ -37,8 +37,8 @@ class QuicServer final : public ProtocolConnectionServer, MessageDemuxer* demuxer, std::unique_ptr<QuicConnectionFactory> connection_factory, ProtocolConnectionServer::Observer* observer, - platform::ClockNowFunctionPtr now_function, - platform::TaskRunner* task_runner); + ClockNowFunctionPtr now_function, + TaskRunner* task_runner); ~QuicServer() override; // ProtocolConnectionServer overrides. @@ -84,15 +84,14 @@ class QuicServer final : public ProtocolConnectionServer, // Maps an IPEndpoint to a generated endpoint ID. This is used to insulate // callers from post-handshake changes to a connections actual peer endpoint. - std::map<IPEndpoint, uint64_t, IPEndpointComparator> endpoint_map_; + std::map<IPEndpoint, uint64_t> endpoint_map_; // Value that will be used for the next new endpoint in a Connect call. uint64_t next_endpoint_id_ = 0; // Maps endpoint addresses to data about connections that haven't successfully // completed the QUIC handshake. - std::map<IPEndpoint, ServiceConnectionData, IPEndpointComparator> - pending_connections_; + std::map<IPEndpoint, ServiceConnectionData> pending_connections_; // Maps endpoint IDs to data about connections that have successfully // completed the QUIC handshake. diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server_unittest.cc index ed1c33627a8..12ad8c33b06 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server_unittest.cc @@ -49,12 +49,11 @@ class MockConnectionObserver final : public ProtocolConnection::Observer { class QuicServerTest : public Test { public: QuicServerTest() { - fake_clock_ = std::make_unique<platform::FakeClock>( - platform::Clock::time_point(std::chrono::milliseconds(1298424))); - task_runner_ = - std::make_unique<platform::FakeTaskRunner>(fake_clock_.get()); - quic_bridge_ = std::make_unique<FakeQuicBridge>(task_runner_.get(), - platform::FakeClock::now); + fake_clock_ = std::make_unique<FakeClock>( + Clock::time_point(std::chrono::milliseconds(1298424))); + task_runner_ = std::make_unique<FakeTaskRunner>(fake_clock_.get()); + quic_bridge_ = + std::make_unique<FakeQuicBridge>(task_runner_.get(), FakeClock::now); } protected: @@ -103,17 +102,16 @@ class QuicServerTest : public Test { EXPECT_CALL(mock_message_callback, OnStreamMessage( 0, _, msgs::Type::kPresentationConnectionMessage, _, _, _)) - .WillOnce( - Invoke([&decode_result, &received_message]( - uint64_t endpoint_id, uint64_t connection_id, - msgs::Type message_type, const uint8_t* buffer, - size_t buffer_size, platform::Clock::time_point now) { - decode_result = msgs::DecodePresentationConnectionMessage( - buffer, buffer_size, &received_message); - if (decode_result < 0) - return ErrorOr<size_t>(Error::Code::kCborParsing); - return ErrorOr<size_t>(decode_result); - })); + .WillOnce(Invoke([&decode_result, &received_message]( + uint64_t endpoint_id, uint64_t connection_id, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + decode_result = msgs::DecodePresentationConnectionMessage( + buffer, buffer_size, &received_message); + if (decode_result < 0) + return ErrorOr<size_t>(Error::Code::kCborParsing); + return ErrorOr<size_t>(decode_result); + })); quic_bridge_->RunTasksUntilIdle(); ASSERT_GT(decode_result, 0); @@ -124,8 +122,8 @@ class QuicServerTest : public Test { EXPECT_EQ(received_message.message.str, message.message.str); } - std::unique_ptr<platform::FakeClock> fake_clock_; - std::unique_ptr<platform::FakeTaskRunner> task_runner_; + std::unique_ptr<FakeClock> fake_clock_; + std::unique_ptr<FakeTaskRunner> task_runner_; std::unique_ptr<FakeQuicBridge> quic_bridge_; QuicServer* server_; }; diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.cc b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.cc index c04ef4331c3..52f7241cb3a 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.cc @@ -61,16 +61,15 @@ std::unique_ptr<FakeQuicStream> FakeQuicConnection::MakeIncomingStream() { return result; } -void FakeQuicConnection::OnRead(platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> data) { +void FakeQuicConnection::OnRead(UdpSocket* socket, ErrorOr<UdpPacket> data) { OSP_NOTREACHED() << "data should go directly to fake streams"; } -void FakeQuicConnection::OnSendError(platform::UdpSocket* socket, Error error) { +void FakeQuicConnection::OnSendError(UdpSocket* socket, Error error) { OSP_NOTREACHED() << "data should go directly to fake streams"; } -void FakeQuicConnection::OnError(platform::UdpSocket* socket, Error error) { +void FakeQuicConnection::OnError(UdpSocket* socket, Error error) { OSP_NOTREACHED() << "data should go directly to fake streams"; } diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.h b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.h index 47943f0b8df..9b11d7b4d55 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.h @@ -58,10 +58,9 @@ class FakeQuicConnection final : public QuicConnection { std::unique_ptr<FakeQuicStream> MakeIncomingStream(); // UdpSocket::Client overrides. - void OnRead(platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> data) override; - void OnSendError(platform::UdpSocket* socket, Error error) override; - void OnError(platform::UdpSocket* socket, Error error) override; + void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> data) override; + void OnSendError(UdpSocket* socket, Error error) override; + void OnError(UdpSocket* socket, Error error) override; // QuicConnection overrides. std::unique_ptr<QuicStream> MakeOutgoingStream( diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.cc b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.cc index c68d256c6f6..47f3ab8757d 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.cc @@ -167,19 +167,17 @@ std::unique_ptr<QuicConnection> FakeClientQuicConnectionFactory::Connect( return bridge_->Connect(endpoint, connection_delegate); } -void FakeClientQuicConnectionFactory::OnError(platform::UdpSocket* socket, - Error error) { +void FakeClientQuicConnectionFactory::OnError(UdpSocket* socket, Error error) { OSP_UNIMPLEMENTED(); } -void FakeClientQuicConnectionFactory::OnSendError(platform::UdpSocket* socket, +void FakeClientQuicConnectionFactory::OnSendError(UdpSocket* socket, Error error) { OSP_UNIMPLEMENTED(); } -void FakeClientQuicConnectionFactory::OnRead( - platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> packet) { +void FakeClientQuicConnectionFactory::OnRead(UdpSocket* socket, + ErrorOr<UdpPacket> packet) { bridge_->RunTasks(true); idle_ = bridge_->client_idle(); } @@ -207,19 +205,17 @@ std::unique_ptr<QuicConnection> FakeServerQuicConnectionFactory::Connect( return nullptr; } -void FakeServerQuicConnectionFactory::OnError(platform::UdpSocket* socket, - Error error) { +void FakeServerQuicConnectionFactory::OnError(UdpSocket* socket, Error error) { OSP_UNIMPLEMENTED(); } -void FakeServerQuicConnectionFactory::OnSendError(platform::UdpSocket* socket, +void FakeServerQuicConnectionFactory::OnSendError(UdpSocket* socket, Error error) { OSP_UNIMPLEMENTED(); } -void FakeServerQuicConnectionFactory::OnRead( - platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> packet) { +void FakeServerQuicConnectionFactory::OnRead(UdpSocket* socket, + ErrorOr<UdpPacket> packet) { bridge_->RunTasks(false); idle_ = bridge_->server_idle(); } diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.h b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.h index 128e6372241..2aac1145359 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.h @@ -55,10 +55,9 @@ class FakeClientQuicConnectionFactory final : public QuicConnectionFactory { ~FakeClientQuicConnectionFactory() override; // UdpSocket::Client overrides. - void OnError(platform::UdpSocket* socket, Error error) override; - void OnSendError(platform::UdpSocket* socket, Error error) override; - void OnRead(platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> packet) override; + void OnError(UdpSocket* socket, Error error) override; + void OnSendError(UdpSocket* socket, Error error) override; + void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override; // QuicConnectionFactory overrides. void SetServerDelegate(ServerDelegate* delegate, @@ -69,7 +68,7 @@ class FakeClientQuicConnectionFactory final : public QuicConnectionFactory { bool idle() const { return idle_; } - std::unique_ptr<platform::UdpSocket> socket_; + std::unique_ptr<UdpSocket> socket_; private: FakeQuicConnectionFactoryBridge* bridge_; @@ -83,10 +82,9 @@ class FakeServerQuicConnectionFactory final : public QuicConnectionFactory { ~FakeServerQuicConnectionFactory() override; // UdpSocket::Client overrides. - void OnError(platform::UdpSocket* socket, Error error) override; - void OnSendError(platform::UdpSocket* socket, Error error) override; - void OnRead(platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> packet) override; + void OnError(UdpSocket* socket, Error error) override; + void OnSendError(UdpSocket* socket, Error error) override; + void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override; // QuicConnectionFactory overrides. void SetServerDelegate(ServerDelegate* delegate, diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.cc b/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.cc index ed5a7007602..2ca327139f1 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.cc @@ -14,8 +14,8 @@ namespace openscreen { namespace osp { -FakeQuicBridge::FakeQuicBridge(platform::FakeTaskRunner* task_runner, - platform::ClockNowFunctionPtr now_function) +FakeQuicBridge::FakeQuicBridge(FakeTaskRunner* task_runner, + ClockNowFunctionPtr now_function) : task_runner_(task_runner) { fake_bridge = std::make_unique<FakeQuicConnectionFactoryBridge>(kControllerEndpoint); @@ -27,8 +27,8 @@ FakeQuicBridge::FakeQuicBridge(platform::FakeTaskRunner* task_runner, auto fake_client_factory = std::make_unique<FakeClientQuicConnectionFactory>(fake_bridge.get()); - client_socket_ = std::make_unique<platform::FakeUdpSocket>( - task_runner_, fake_client_factory.get()); + client_socket_ = + std::make_unique<FakeUdpSocket>(task_runner_, fake_client_factory.get()); quic_client = std::make_unique<QuicClient>( controller_demuxer.get(), std::move(fake_client_factory), @@ -36,8 +36,8 @@ FakeQuicBridge::FakeQuicBridge(platform::FakeTaskRunner* task_runner, auto fake_server_factory = std::make_unique<FakeServerQuicConnectionFactory>(fake_bridge.get()); - server_socket_ = std::make_unique<platform::FakeUdpSocket>( - task_runner_, fake_server_factory.get()); + server_socket_ = + std::make_unique<FakeUdpSocket>(task_runner_, fake_server_factory.get()); ServerConfig config; config.connection_endpoints.push_back(kReceiverEndpoint); quic_server = std::make_unique<QuicServer>( @@ -51,13 +51,13 @@ FakeQuicBridge::FakeQuicBridge(platform::FakeTaskRunner* task_runner, FakeQuicBridge::~FakeQuicBridge() = default; void FakeQuicBridge::PostClientPacket() { - platform::UdpPacket packet; + UdpPacket packet; packet.set_socket(client_socket_.get()); client_socket_->MockReceivePacket(std::move(packet)); } void FakeQuicBridge::PostServerPacket() { - platform::UdpPacket packet; + UdpPacket packet; packet.set_socket(server_socket_.get()); server_socket_->MockReceivePacket(std::move(packet)); } diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.h b/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.h index 3e8719aa25a..a7ae7c33b63 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.h @@ -56,8 +56,7 @@ class MockServerObserver : public ProtocolConnectionServer::Observer { class FakeQuicBridge { public: - FakeQuicBridge(platform::FakeTaskRunner* task_runner, - platform::ClockNowFunctionPtr now_function); + FakeQuicBridge(FakeTaskRunner* task_runner, ClockNowFunctionPtr now_function); ~FakeQuicBridge(); const IPEndpoint kControllerEndpoint{{192, 168, 1, 3}, 4321}; @@ -79,10 +78,10 @@ class FakeQuicBridge { void PostPacketsUntilIdle(); FakeClientQuicConnectionFactory* GetClientFactory(); FakeServerQuicConnectionFactory* GetServerFactory(); - platform::FakeTaskRunner* task_runner_; + FakeTaskRunner* task_runner_; - std::unique_ptr<platform::FakeUdpSocket> client_socket_; - std::unique_ptr<platform::FakeUdpSocket> server_socket_; + std::unique_ptr<FakeUdpSocket> client_socket_; + std::unique_ptr<FakeUdpSocket> server_socket_; }; } // namespace osp diff --git a/chromium/third_party/openscreen/src/osp/impl/service_listener_impl.cc b/chromium/third_party/openscreen/src/osp/impl/service_listener_impl.cc index f0963014601..43e5479c9f2 100644 --- a/chromium/third_party/openscreen/src/osp/impl/service_listener_impl.cc +++ b/chromium/third_party/openscreen/src/osp/impl/service_listener_impl.cc @@ -4,6 +4,8 @@ #include "osp/impl/service_listener_impl.h" +#include <algorithm> + #include "platform/base/error.h" #include "util/logging.h" diff --git a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.cc b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.cc index 7930eedf779..371e8789acc 100644 --- a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.cc +++ b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.cc @@ -16,7 +16,7 @@ FakeMdnsPlatformService::~FakeMdnsPlatformService() = default; std::vector<MdnsPlatformService::BoundInterface> FakeMdnsPlatformService::RegisterInterfaces( - const std::vector<platform::NetworkInterfaceIndex>& whitelist) { + const std::vector<NetworkInterfaceIndex>& whitelist) { OSP_CHECK(registered_interfaces_.empty()); if (whitelist.empty()) { registered_interfaces_ = interfaces_; diff --git a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.h b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.h index 81e148c16c8..2c5d6b88794 100644 --- a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.h +++ b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.h @@ -23,8 +23,8 @@ class FakeMdnsPlatformService final : public MdnsPlatformService { // PlatformService overrides. std::vector<BoundInterface> RegisterInterfaces( - const std::vector<platform::NetworkInterfaceIndex>& - interface_index_whitelist) override; + const std::vector<NetworkInterfaceIndex>& interface_index_whitelist) + override; void DeregisterInterfaces( const std::vector<BoundInterface>& registered_interfaces) override; diff --git a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service_unittest.cc index 69595ac3598..c2058d4d254 100644 --- a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service_unittest.cc @@ -12,32 +12,33 @@ namespace openscreen { namespace osp { namespace { -platform::UdpSocket* const kDefaultSocket = - reinterpret_cast<platform::UdpSocket*>(static_cast<uintptr_t>(16)); -platform::UdpSocket* const kSecondSocket = - reinterpret_cast<platform::UdpSocket*>(static_cast<uintptr_t>(24)); +UdpSocket* const kDefaultSocket = + reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(16)); +UdpSocket* const kSecondSocket = + reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(24)); class FakeMdnsPlatformServiceTest : public ::testing::Test { protected: const uint8_t mac1_[6] = {11, 22, 33, 44, 55, 66}; const uint8_t mac2_[6] = {12, 23, 34, 45, 56, 67}; - const platform::IPSubnet subnet1_{IPAddress{192, 168, 3, 2}, 24}; - const platform::IPSubnet subnet2_{ - IPAddress{1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3, 4, 5, 6, 7, 8}, 24}; + const IPSubnet subnet1_{IPAddress{192, 168, 3, 2}, 24}; + const IPSubnet subnet2_{ + IPAddress{0x0102, 0x0304, 0x0504, 0x0302, 0x0102, 0x0304, 0x0506, 0x0708}, + 24}; std::vector<MdnsPlatformService::BoundInterface> bound_interfaces_{ MdnsPlatformService::BoundInterface{ - platform::InterfaceInfo{1, - mac1_, - "eth0", - platform::InterfaceInfo::Type::kEthernet, - {subnet1_}}, + InterfaceInfo{1, + mac1_, + "eth0", + InterfaceInfo::Type::kEthernet, + {subnet1_}}, subnet1_, kDefaultSocket}, MdnsPlatformService::BoundInterface{ - platform::InterfaceInfo{2, - mac2_, - "eth1", - platform::InterfaceInfo::Type::kEthernet, - {subnet2_}}, + InterfaceInfo{2, + mac2_, + "eth1", + InterfaceInfo::Type::kEthernet, + {subnet2_}}, subnet2_, kSecondSocket}}; }; diff --git a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.cc b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.cc index 37e744348ee..fad1c125b30 100644 --- a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.cc +++ b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.cc @@ -17,7 +17,7 @@ constexpr char kLocalDomain[] = "local"; PtrEvent MakePtrEvent(const std::string& service_instance, const std::string& service_type, const std::string& service_protocol, - platform::UdpSocket* socket) { + UdpSocket* socket) { const auto labels = std::vector<std::string>{service_instance, service_type, service_protocol, kLocalDomain}; ErrorOr<DomainName> full_instance_name = @@ -33,7 +33,7 @@ SrvEvent MakeSrvEvent(const std::string& service_instance, const std::string& service_protocol, const std::string& hostname, uint16_t port, - platform::UdpSocket* socket) { + UdpSocket* socket) { const auto instance_labels = std::vector<std::string>{ service_instance, service_type, service_protocol, kLocalDomain}; ErrorOr<DomainName> full_instance_name = @@ -54,7 +54,7 @@ TxtEvent MakeTxtEvent(const std::string& service_instance, const std::string& service_type, const std::string& service_protocol, const std::vector<std::string>& txt_lines, - platform::UdpSocket* socket) { + UdpSocket* socket) { const auto labels = std::vector<std::string>{service_instance, service_type, service_protocol, kLocalDomain}; ErrorOr<DomainName> domain_name = @@ -67,7 +67,7 @@ TxtEvent MakeTxtEvent(const std::string& service_instance, AEvent MakeAEvent(const std::string& hostname, IPAddress address, - platform::UdpSocket* socket) { + UdpSocket* socket) { const auto labels = std::vector<std::string>{hostname, kLocalDomain}; ErrorOr<DomainName> domain_name = DomainName::FromLabels(labels.begin(), labels.end()); @@ -79,7 +79,7 @@ AEvent MakeAEvent(const std::string& hostname, AaaaEvent MakeAaaaEvent(const std::string& hostname, IPAddress address, - platform::UdpSocket* socket) { + UdpSocket* socket) { const auto labels = std::vector<std::string>{hostname, kLocalDomain}; ErrorOr<DomainName> domain_name = DomainName::FromLabels(labels.begin(), labels.end()); @@ -97,7 +97,7 @@ void AddEventsForNewService(FakeMdnsResponderAdapter* mdns_responder, uint16_t port, const std::vector<std::string>& txt_lines, const IPAddress& address, - platform::UdpSocket* socket) { + UdpSocket* socket) { mdns_responder->AddPtrEvent( MakePtrEvent(service_instance, service_name, service_protocol, socket)); mdns_responder->AddSrvEvent(MakeSrvEvent(service_instance, service_name, @@ -202,9 +202,9 @@ Error FakeMdnsResponderAdapter::SetHostLabel(const std::string& host_label) { } Error FakeMdnsResponderAdapter::RegisterInterface( - const platform::InterfaceInfo& interface_info, - const platform::IPSubnet& interface_address, - platform::UdpSocket* socket) { + const InterfaceInfo& interface_info, + const IPSubnet& interface_address, + UdpSocket* socket) { if (!running_) return Error::Code::kOperationInvalid; @@ -218,8 +218,7 @@ Error FakeMdnsResponderAdapter::RegisterInterface( return Error::None(); } -Error FakeMdnsResponderAdapter::DeregisterInterface( - platform::UdpSocket* socket) { +Error FakeMdnsResponderAdapter::DeregisterInterface(UdpSocket* socket) { auto it = std::find_if(registered_interfaces_.begin(), registered_interfaces_.end(), [&socket](const RegisteredInterface& interface) { @@ -232,22 +231,20 @@ Error FakeMdnsResponderAdapter::DeregisterInterface( return Error::None(); } -void FakeMdnsResponderAdapter::OnRead(platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> packet) { +void FakeMdnsResponderAdapter::OnRead(UdpSocket* socket, + ErrorOr<UdpPacket> packet) { OSP_NOTREACHED() << "Tests should not drive this class with packets"; } -void FakeMdnsResponderAdapter::OnSendError(platform::UdpSocket* socket, - Error error) { +void FakeMdnsResponderAdapter::OnSendError(UdpSocket* socket, Error error) { OSP_NOTREACHED() << "Tests should not drive this class with packets"; } -void FakeMdnsResponderAdapter::OnError(platform::UdpSocket* socket, - Error error) { +void FakeMdnsResponderAdapter::OnError(UdpSocket* socket, Error error) { OSP_NOTREACHED() << "Tests should not drive this class with packets"; } -platform::Clock::duration FakeMdnsResponderAdapter::RunTasks() { +Clock::duration FakeMdnsResponderAdapter::RunTasks() { return std::chrono::seconds(1); } @@ -369,7 +366,7 @@ std::vector<AaaaEvent> FakeMdnsResponderAdapter::TakeAaaaResponses() { } MdnsResponderErrorCode FakeMdnsResponderAdapter::StartPtrQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_type) { if (!running_) return MdnsResponderErrorCode::kUnknownError; @@ -388,7 +385,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StartPtrQuery( } MdnsResponderErrorCode FakeMdnsResponderAdapter::StartSrvQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) { if (!running_) return MdnsResponderErrorCode::kUnknownError; @@ -402,7 +399,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StartSrvQuery( } MdnsResponderErrorCode FakeMdnsResponderAdapter::StartTxtQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) { if (!running_) return MdnsResponderErrorCode::kUnknownError; @@ -416,7 +413,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StartTxtQuery( } MdnsResponderErrorCode FakeMdnsResponderAdapter::StartAQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& domain_name) { if (!running_) return MdnsResponderErrorCode::kUnknownError; @@ -430,7 +427,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StartAQuery( } MdnsResponderErrorCode FakeMdnsResponderAdapter::StartAaaaQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& domain_name) { if (!running_) return MdnsResponderErrorCode::kUnknownError; @@ -444,7 +441,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StartAaaaQuery( } MdnsResponderErrorCode FakeMdnsResponderAdapter::StopPtrQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_type) { auto interface_entry = queries_.find(socket); if (interface_entry == queries_.end()) @@ -463,7 +460,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StopPtrQuery( } MdnsResponderErrorCode FakeMdnsResponderAdapter::StopSrvQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) { auto interface_entry = queries_.find(socket); if (interface_entry == queries_.end()) @@ -478,7 +475,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StopSrvQuery( } MdnsResponderErrorCode FakeMdnsResponderAdapter::StopTxtQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) { auto interface_entry = queries_.find(socket); if (interface_entry == queries_.end()) @@ -493,7 +490,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StopTxtQuery( } MdnsResponderErrorCode FakeMdnsResponderAdapter::StopAQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& domain_name) { auto interface_entry = queries_.find(socket); if (interface_entry == queries_.end()) @@ -508,7 +505,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StopAQuery( } MdnsResponderErrorCode FakeMdnsResponderAdapter::StopAaaaQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& domain_name) { auto interface_entry = queries_.find(socket); if (interface_entry == queries_.end()) diff --git a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.h b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.h index a7287e45ddb..d4fdad1fea0 100644 --- a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.h +++ b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.h @@ -18,28 +18,28 @@ class FakeMdnsResponderAdapter; PtrEvent MakePtrEvent(const std::string& service_instance, const std::string& service_type, const std::string& service_protocol, - platform::UdpSocket* socket); + UdpSocket* socket); SrvEvent MakeSrvEvent(const std::string& service_instance, const std::string& service_type, const std::string& service_protocol, const std::string& hostname, uint16_t port, - platform::UdpSocket* socket); + UdpSocket* socket); TxtEvent MakeTxtEvent(const std::string& service_instance, const std::string& service_type, const std::string& service_protocol, const std::vector<std::string>& txt_lines, - platform::UdpSocket* socket); + UdpSocket* socket); AEvent MakeAEvent(const std::string& hostname, IPAddress address, - platform::UdpSocket* socket); + UdpSocket* socket); AaaaEvent MakeAaaaEvent(const std::string& hostname, IPAddress address, - platform::UdpSocket* socket); + UdpSocket* socket); void AddEventsForNewService(FakeMdnsResponderAdapter* mdns_responder, const std::string& service_instance, @@ -49,14 +49,14 @@ void AddEventsForNewService(FakeMdnsResponderAdapter* mdns_responder, uint16_t port, const std::vector<std::string>& txt_lines, const IPAddress& address, - platform::UdpSocket* socket); + UdpSocket* socket); class FakeMdnsResponderAdapter final : public MdnsResponderAdapter { public: struct RegisteredInterface { - platform::InterfaceInfo interface_info; - platform::IPSubnet interface_address; - platform::UdpSocket* socket; + InterfaceInfo interface_info; + IPSubnet interface_address; + UdpSocket* socket; }; struct RegisteredService { @@ -99,10 +99,9 @@ class FakeMdnsResponderAdapter final : public MdnsResponderAdapter { bool running() const { return running_; } // UdpSocket::Client overrides. - void OnRead(platform::UdpSocket* socket, - ErrorOr<platform::UdpPacket> packet) override; - void OnSendError(platform::UdpSocket* socket, Error error) override; - void OnError(platform::UdpSocket* socket, Error error) override; + void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override; + void OnSendError(UdpSocket* socket, Error error) override; + void OnError(UdpSocket* socket, Error error) override; // MdnsResponderAdapter overrides. Error Init() override; @@ -112,12 +111,12 @@ class FakeMdnsResponderAdapter final : public MdnsResponderAdapter { // TODO(btolsch): Reject/OSP_CHECK events that don't match any registered // interface? - Error RegisterInterface(const platform::InterfaceInfo& interface_info, - const platform::IPSubnet& interface_address, - platform::UdpSocket* socket) override; - Error DeregisterInterface(platform::UdpSocket* socket) override; + Error RegisterInterface(const InterfaceInfo& interface_info, + const IPSubnet& interface_address, + UdpSocket* socket) override; + Error DeregisterInterface(UdpSocket* socket) override; - platform::Clock::duration RunTasks() override; + Clock::duration RunTasks() override; std::vector<PtrEvent> TakePtrResponses() override; std::vector<SrvEvent> TakeSrvResponses() override; @@ -125,30 +124,30 @@ class FakeMdnsResponderAdapter final : public MdnsResponderAdapter { std::vector<AEvent> TakeAResponses() override; std::vector<AaaaEvent> TakeAaaaResponses() override; - MdnsResponderErrorCode StartPtrQuery(platform::UdpSocket* socket, + MdnsResponderErrorCode StartPtrQuery(UdpSocket* socket, const DomainName& service_type) override; MdnsResponderErrorCode StartSrvQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) override; MdnsResponderErrorCode StartTxtQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) override; - MdnsResponderErrorCode StartAQuery(platform::UdpSocket* socket, + MdnsResponderErrorCode StartAQuery(UdpSocket* socket, const DomainName& domain_name) override; - MdnsResponderErrorCode StartAaaaQuery(platform::UdpSocket* socket, + MdnsResponderErrorCode StartAaaaQuery(UdpSocket* socket, const DomainName& domain_name) override; - MdnsResponderErrorCode StopPtrQuery(platform::UdpSocket* socket, + MdnsResponderErrorCode StopPtrQuery(UdpSocket* socket, const DomainName& service_type) override; MdnsResponderErrorCode StopSrvQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) override; MdnsResponderErrorCode StopTxtQuery( - platform::UdpSocket* socket, + UdpSocket* socket, const DomainName& service_instance) override; - MdnsResponderErrorCode StopAQuery(platform::UdpSocket* socket, + MdnsResponderErrorCode StopAQuery(UdpSocket* socket, const DomainName& domain_name) override; - MdnsResponderErrorCode StopAaaaQuery(platform::UdpSocket* socket, + MdnsResponderErrorCode StopAaaaQuery(UdpSocket* socket, const DomainName& domain_name) override; MdnsResponderErrorCode RegisterService( @@ -180,7 +179,7 @@ class FakeMdnsResponderAdapter final : public MdnsResponderAdapter { bool running_ = false; LifetimeObserver* observer_ = nullptr; - std::map<platform::UdpSocket*, InterfaceQueries> queries_; + std::map<UdpSocket*, InterfaceQueries> queries_; // NOTE: One of many simplifications here is that there is no cache. This // means that calling StartQuery, StopQuery, StartQuery will only return an // event the first time, unless the test also adds the event a second time. diff --git a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter_unittest.cc index f11ef33bbee..06cec89a5ad 100644 --- a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter_unittest.cc @@ -15,10 +15,10 @@ constexpr char kTestServiceInstance[] = "turtle"; constexpr char kTestServiceName[] = "_foo"; constexpr char kTestServiceProtocol[] = "_udp"; -platform::UdpSocket* const kDefaultSocket = - reinterpret_cast<platform::UdpSocket*>(static_cast<uintptr_t>(8)); -platform::UdpSocket* const kSecondSocket = - reinterpret_cast<platform::UdpSocket*>(static_cast<uintptr_t>(32)); +UdpSocket* const kDefaultSocket = + reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(8)); +UdpSocket* const kSecondSocket = + reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(32)); } // namespace @@ -260,18 +260,18 @@ TEST(FakeMdnsResponderAdapterTest, RegisterInterfaces) { ASSERT_TRUE(mdns_responder.running()); EXPECT_EQ(0u, mdns_responder.registered_interfaces().size()); - Error result = mdns_responder.RegisterInterface( - platform::InterfaceInfo{}, platform::IPSubnet{}, kDefaultSocket); + Error result = mdns_responder.RegisterInterface(InterfaceInfo{}, IPSubnet{}, + kDefaultSocket); EXPECT_TRUE(result.ok()); EXPECT_EQ(1u, mdns_responder.registered_interfaces().size()); - result = mdns_responder.RegisterInterface( - platform::InterfaceInfo{}, platform::IPSubnet{}, kDefaultSocket); + result = mdns_responder.RegisterInterface(InterfaceInfo{}, IPSubnet{}, + kDefaultSocket); EXPECT_FALSE(result.ok()); EXPECT_EQ(1u, mdns_responder.registered_interfaces().size()); - result = mdns_responder.RegisterInterface( - platform::InterfaceInfo{}, platform::IPSubnet{}, kSecondSocket); + result = mdns_responder.RegisterInterface(InterfaceInfo{}, IPSubnet{}, + kSecondSocket); EXPECT_TRUE(result.ok()); EXPECT_EQ(2u, mdns_responder.registered_interfaces().size()); @@ -286,8 +286,8 @@ TEST(FakeMdnsResponderAdapterTest, RegisterInterfaces) { ASSERT_FALSE(mdns_responder.running()); EXPECT_EQ(0u, mdns_responder.registered_interfaces().size()); - result = mdns_responder.RegisterInterface( - platform::InterfaceInfo{}, platform::IPSubnet{}, kDefaultSocket); + result = mdns_responder.RegisterInterface(InterfaceInfo{}, IPSubnet{}, + kDefaultSocket); EXPECT_FALSE(result.ok()); EXPECT_EQ(0u, mdns_responder.registered_interfaces().size()); } diff --git a/chromium/third_party/openscreen/src/osp/msgs/request_response_handler.h b/chromium/third_party/openscreen/src/osp/msgs/request_response_handler.h index 579739c7dad..82874420b12 100644 --- a/chromium/third_party/openscreen/src/osp/msgs/request_response_handler.h +++ b/chromium/third_party/openscreen/src/osp/msgs/request_response_handler.h @@ -160,7 +160,7 @@ class RequestResponseHandler : public MessageDemuxer::MessageCallback { msgs::Type message_type, const uint8_t* buffer, size_t buffer_size, - platform::Clock::time_point now) override { + Clock::time_point now) override { if (message_type != RequestT::kResponseType) { return 0; } diff --git a/chromium/third_party/openscreen/src/osp/public/client_config.h b/chromium/third_party/openscreen/src/osp/public/client_config.h index 732b52b6fd2..ae172377d61 100644 --- a/chromium/third_party/openscreen/src/osp/public/client_config.h +++ b/chromium/third_party/openscreen/src/osp/public/client_config.h @@ -19,8 +19,8 @@ struct ClientConfig { // The indexes of network interfaces that should be used by the Open Screen // Library. The indexes derive from the values of - // openscreen::platform::InterfaceInfo::index. - std::vector<platform::NetworkInterfaceIndex> interface_indexes; + // openscreen::InterfaceInfo::index. + std::vector<NetworkInterfaceIndex> interface_indexes; }; } // namespace osp diff --git a/chromium/third_party/openscreen/src/osp/public/mdns_service_listener_factory.h b/chromium/third_party/openscreen/src/osp/public/mdns_service_listener_factory.h index 73118bdaaa9..663d060fbbe 100644 --- a/chromium/third_party/openscreen/src/osp/public/mdns_service_listener_factory.h +++ b/chromium/third_party/openscreen/src/osp/public/mdns_service_listener_factory.h @@ -11,9 +11,7 @@ namespace openscreen { -namespace platform { class TaskRunner; -} // namespace platform namespace osp { @@ -27,7 +25,7 @@ class MdnsServiceListenerFactory { static std::unique_ptr<ServiceListener> Create( const MdnsServiceListenerConfig& config, ServiceListener::Observer* observer, - platform::TaskRunner* task_runner); + TaskRunner* task_runner); }; } // namespace osp diff --git a/chromium/third_party/openscreen/src/osp/public/mdns_service_publisher_factory.h b/chromium/third_party/openscreen/src/osp/public/mdns_service_publisher_factory.h index 4710d0fb437..075137a7243 100644 --- a/chromium/third_party/openscreen/src/osp/public/mdns_service_publisher_factory.h +++ b/chromium/third_party/openscreen/src/osp/public/mdns_service_publisher_factory.h @@ -11,9 +11,7 @@ namespace openscreen { -namespace platform { class TaskRunner; -} // namespace platform namespace osp { @@ -22,7 +20,7 @@ class MdnsServicePublisherFactory { static std::unique_ptr<ServicePublisher> Create( const ServicePublisher::Config& config, ServicePublisher::Observer* observer, - platform::TaskRunner* task_runner); + TaskRunner* task_runner); }; } // namespace osp diff --git a/chromium/third_party/openscreen/src/osp/public/message_demuxer.cc b/chromium/third_party/openscreen/src/osp/public/message_demuxer.cc index cabc910fcfd..d67f848a140 100644 --- a/chromium/third_party/openscreen/src/osp/public/message_demuxer.cc +++ b/chromium/third_party/openscreen/src/osp/public/message_demuxer.cc @@ -116,7 +116,7 @@ MessageDemuxer::MessageWatch& MessageDemuxer::MessageWatch::operator=( return *this; } -MessageDemuxer::MessageDemuxer(platform::ClockNowFunctionPtr now_function, +MessageDemuxer::MessageDemuxer(ClockNowFunctionPtr now_function, size_t buffer_limit = kDefaultBufferLimit) : now_function_(now_function), buffer_limit_(buffer_limit) { OSP_DCHECK(now_function_); diff --git a/chromium/third_party/openscreen/src/osp/public/message_demuxer.h b/chromium/third_party/openscreen/src/osp/public/message_demuxer.h index 5592ee8d430..284e43ba8fa 100644 --- a/chromium/third_party/openscreen/src/osp/public/message_demuxer.h +++ b/chromium/third_party/openscreen/src/osp/public/message_demuxer.h @@ -33,13 +33,12 @@ class MessageDemuxer { // error code of Error::Code::kCborIncompleteMessage. This way, // the MessageDemuxer knows to neither consume the data nor discard it as // bad. - virtual ErrorOr<size_t> OnStreamMessage( - uint64_t endpoint_id, - uint64_t connection_id, - msgs::Type message_type, - const uint8_t* buffer, - size_t buffer_size, - platform::Clock::time_point now) = 0; + virtual ErrorOr<size_t> OnStreamMessage(uint64_t endpoint_id, + uint64_t connection_id, + msgs::Type message_type, + const uint8_t* buffer, + size_t buffer_size, + Clock::time_point now) = 0; }; class MessageWatch { @@ -64,8 +63,7 @@ class MessageDemuxer { static constexpr size_t kDefaultBufferLimit = 1 << 16; - MessageDemuxer(platform::ClockNowFunctionPtr now_function, - size_t buffer_limit); + MessageDemuxer(ClockNowFunctionPtr now_function, size_t buffer_limit); ~MessageDemuxer(); // Starts watching for messages of type |message_type| from the endpoint @@ -110,7 +108,7 @@ class MessageDemuxer { std::map<msgs::Type, MessageCallback*>* message_callbacks, std::vector<uint8_t>* buffer); - const platform::ClockNowFunctionPtr now_function_; + const ClockNowFunctionPtr now_function_; const size_t buffer_limit_; std::map<uint64_t, std::map<msgs::Type, MessageCallback*>> message_callbacks_; std::map<msgs::Type, MessageCallback*> default_callbacks_; diff --git a/chromium/third_party/openscreen/src/osp/public/message_demuxer_unittest.cc b/chromium/third_party/openscreen/src/osp/public/message_demuxer_unittest.cc index 7e830815d7a..63a03735770 100644 --- a/chromium/third_party/openscreen/src/osp/public/message_demuxer_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/public/message_demuxer_unittest.cc @@ -48,14 +48,12 @@ class MessageDemuxerTest : public ::testing::Test { const uint64_t endpoint_id_ = 13; const uint64_t connection_id_ = 45; - platform::FakeClock fake_clock_{ - platform::Clock::time_point(std::chrono::milliseconds(1298424))}; + FakeClock fake_clock_{Clock::time_point(std::chrono::milliseconds(1298424))}; msgs::CborEncodeBuffer buffer_; msgs::PresentationConnectionOpenRequest request_{1, "fry-am-the-egg-man", "url"}; MockMessageCallback mock_callback_; - MessageDemuxer demuxer_{platform::FakeClock::now, - MessageDemuxer::kDefaultBufferLimit}; + MessageDemuxer demuxer_{FakeClock::now, MessageDemuxer::kDefaultBufferLimit}; }; } // namespace @@ -75,15 +73,14 @@ TEST_F(MessageDemuxerTest, WatchStartStop) { mock_callback_, OnStreamMessage(endpoint_id_, connection_id_, msgs::Type::kPresentationConnectionOpenRequest, _, _, _)) - .WillOnce( - Invoke([&decode_result, &received_request]( - uint64_t endpoint_id, uint64_t connection_id, - msgs::Type message_type, const uint8_t* buffer, - size_t buffer_size, platform::Clock::time_point now) { - decode_result = msgs::DecodePresentationConnectionOpenRequest( - buffer, buffer_size, &received_request); - return ConvertDecodeResult(decode_result); - })); + .WillOnce(Invoke([&decode_result, &received_request]( + uint64_t endpoint_id, uint64_t connection_id, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + decode_result = msgs::DecodePresentationConnectionOpenRequest( + buffer, buffer_size, &received_request); + return ConvertDecodeResult(decode_result); + })); demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(), buffer_.size()); ExpectDecodedRequest(decode_result, received_request); @@ -110,15 +107,14 @@ TEST_F(MessageDemuxerTest, BufferPartialMessage) { OnStreamMessage(endpoint_id_, connection_id_, msgs::Type::kPresentationConnectionOpenRequest, _, _, _)) .Times(2) - .WillRepeatedly( - Invoke([&decode_result, &received_request]( - uint64_t endpoint_id, uint64_t connection_id, - msgs::Type message_type, const uint8_t* buffer, - size_t buffer_size, platform::Clock::time_point now) { - decode_result = msgs::DecodePresentationConnectionOpenRequest( - buffer, buffer_size, &received_request); - return ConvertDecodeResult(decode_result); - })); + .WillRepeatedly(Invoke([&decode_result, &received_request]( + uint64_t endpoint_id, uint64_t connection_id, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + decode_result = msgs::DecodePresentationConnectionOpenRequest( + buffer, buffer_size, &received_request); + return ConvertDecodeResult(decode_result); + })); demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(), buffer_.size() - 3); demuxer_.OnStreamData(endpoint_id_, connection_id_, @@ -140,15 +136,14 @@ TEST_F(MessageDemuxerTest, DefaultWatch) { mock_callback_, OnStreamMessage(endpoint_id_, connection_id_, msgs::Type::kPresentationConnectionOpenRequest, _, _, _)) - .WillOnce( - Invoke([&decode_result, &received_request]( - uint64_t endpoint_id, uint64_t connection_id, - msgs::Type message_type, const uint8_t* buffer, - size_t buffer_size, platform::Clock::time_point now) { - decode_result = msgs::DecodePresentationConnectionOpenRequest( - buffer, buffer_size, &received_request); - return ConvertDecodeResult(decode_result); - })); + .WillOnce(Invoke([&decode_result, &received_request]( + uint64_t endpoint_id, uint64_t connection_id, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + decode_result = msgs::DecodePresentationConnectionOpenRequest( + buffer, buffer_size, &received_request); + return ConvertDecodeResult(decode_result); + })); demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(), buffer_.size()); ExpectDecodedRequest(decode_result, received_request); @@ -176,15 +171,14 @@ TEST_F(MessageDemuxerTest, DefaultWatchOverridden) { mock_callback_global, OnStreamMessage(endpoint_id_ + 1, 14, msgs::Type::kPresentationConnectionOpenRequest, _, _, _)) - .WillOnce( - Invoke([&decode_result, &received_request]( - uint64_t endpoint_id, uint64_t connection_id, - msgs::Type message_type, const uint8_t* buffer, - size_t buffer_size, platform::Clock::time_point now) { - decode_result = msgs::DecodePresentationConnectionOpenRequest( - buffer, buffer_size, &received_request); - return ConvertDecodeResult(decode_result); - })); + .WillOnce(Invoke([&decode_result, &received_request]( + uint64_t endpoint_id, uint64_t connection_id, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + decode_result = msgs::DecodePresentationConnectionOpenRequest( + buffer, buffer_size, &received_request); + return ConvertDecodeResult(decode_result); + })); demuxer_.OnStreamData(endpoint_id_ + 1, 14, buffer_.data(), buffer_.size()); ExpectDecodedRequest(decode_result, received_request); @@ -193,15 +187,14 @@ TEST_F(MessageDemuxerTest, DefaultWatchOverridden) { mock_callback_, OnStreamMessage(endpoint_id_, connection_id_, msgs::Type::kPresentationConnectionOpenRequest, _, _, _)) - .WillOnce( - Invoke([&decode_result, &received_request]( - uint64_t endpoint_id, uint64_t connection_id, - msgs::Type message_type, const uint8_t* buffer, - size_t buffer_size, platform::Clock::time_point now) { - decode_result = msgs::DecodePresentationConnectionOpenRequest( - buffer, buffer_size, &received_request); - return ConvertDecodeResult(decode_result); - })); + .WillOnce(Invoke([&decode_result, &received_request]( + uint64_t endpoint_id, uint64_t connection_id, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + decode_result = msgs::DecodePresentationConnectionOpenRequest( + buffer, buffer_size, &received_request); + return ConvertDecodeResult(decode_result); + })); demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(), buffer_.size()); ExpectDecodedRequest(decode_result, received_request); @@ -214,15 +207,14 @@ TEST_F(MessageDemuxerTest, WatchAfterData) { mock_callback_, OnStreamMessage(endpoint_id_, connection_id_, msgs::Type::kPresentationConnectionOpenRequest, _, _, _)) - .WillOnce( - Invoke([&decode_result, &received_request]( - uint64_t endpoint_id, uint64_t connection_id, - msgs::Type message_type, const uint8_t* buffer, - size_t buffer_size, platform::Clock::time_point now) { - decode_result = msgs::DecodePresentationConnectionOpenRequest( - buffer, buffer_size, &received_request); - return ConvertDecodeResult(decode_result); - })); + .WillOnce(Invoke([&decode_result, &received_request]( + uint64_t endpoint_id, uint64_t connection_id, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + decode_result = msgs::DecodePresentationConnectionOpenRequest( + buffer, buffer_size, &received_request); + return ConvertDecodeResult(decode_result); + })); MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType( endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest, &mock_callback_); @@ -245,27 +237,25 @@ TEST_F(MessageDemuxerTest, WatchAfterMultipleData) { mock_callback_, OnStreamMessage(endpoint_id_, connection_id_, msgs::Type::kPresentationConnectionOpenRequest, _, _, _)) - .WillOnce( - Invoke([&decode_result1, &received_request]( - uint64_t endpoint_id, uint64_t connection_id, - msgs::Type message_type, const uint8_t* buffer, - size_t buffer_size, platform::Clock::time_point now) { - decode_result1 = msgs::DecodePresentationConnectionOpenRequest( - buffer, buffer_size, &received_request); - return ConvertDecodeResult(decode_result1); - })); + .WillOnce(Invoke([&decode_result1, &received_request]( + uint64_t endpoint_id, uint64_t connection_id, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + decode_result1 = msgs::DecodePresentationConnectionOpenRequest( + buffer, buffer_size, &received_request); + return ConvertDecodeResult(decode_result1); + })); EXPECT_CALL(mock_init_callback, OnStreamMessage(endpoint_id_, connection_id_, msgs::Type::kPresentationStartRequest, _, _, _)) - .WillOnce( - Invoke([&decode_result2, &received_init_request]( - uint64_t endpoint_id, uint64_t connection_id, - msgs::Type message_type, const uint8_t* buffer, - size_t buffer_size, platform::Clock::time_point now) { - decode_result2 = msgs::DecodePresentationStartRequest( - buffer, buffer_size, &received_init_request); - return ConvertDecodeResult(decode_result2); - })); + .WillOnce(Invoke([&decode_result2, &received_init_request]( + uint64_t endpoint_id, uint64_t connection_id, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + decode_result2 = msgs::DecodePresentationStartRequest( + buffer, buffer_size, &received_init_request); + return ConvertDecodeResult(decode_result2); + })); MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType( endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest, &mock_callback_); @@ -296,15 +286,14 @@ TEST_F(MessageDemuxerTest, GlobalWatchAfterData) { mock_callback_, OnStreamMessage(endpoint_id_, connection_id_, msgs::Type::kPresentationConnectionOpenRequest, _, _, _)) - .WillOnce( - Invoke([&decode_result, &received_request]( - uint64_t endpoint_id, uint64_t connection_id, - msgs::Type message_type, const uint8_t* buffer, - size_t buffer_size, platform::Clock::time_point now) { - decode_result = msgs::DecodePresentationConnectionOpenRequest( - buffer, buffer_size, &received_request); - return ConvertDecodeResult(decode_result); - })); + .WillOnce(Invoke([&decode_result, &received_request]( + uint64_t endpoint_id, uint64_t connection_id, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + decode_result = msgs::DecodePresentationConnectionOpenRequest( + buffer, buffer_size, &received_request); + return ConvertDecodeResult(decode_result); + })); MessageDemuxer::MessageWatch watch = demuxer_.SetDefaultMessageTypeWatch( msgs::Type::kPresentationConnectionOpenRequest, &mock_callback_); ASSERT_TRUE(watch); @@ -314,7 +303,7 @@ TEST_F(MessageDemuxerTest, GlobalWatchAfterData) { } TEST_F(MessageDemuxerTest, BufferLimit) { - MessageDemuxer demuxer(platform::FakeClock::now, 10); + MessageDemuxer demuxer(FakeClock::now, 10); demuxer.OnStreamData(endpoint_id_, connection_id_, buffer_.data(), buffer_.size()); @@ -329,15 +318,14 @@ TEST_F(MessageDemuxerTest, BufferLimit) { mock_callback_, OnStreamMessage(endpoint_id_, connection_id_, msgs::Type::kPresentationConnectionOpenRequest, _, _, _)) - .WillOnce( - Invoke([&decode_result, &received_request]( - uint64_t endpoint_id, uint64_t connection_id, - msgs::Type message_type, const uint8_t* buffer, - size_t buffer_size, platform::Clock::time_point now) { - decode_result = msgs::DecodePresentationConnectionOpenRequest( - buffer, buffer_size, &received_request); - return ConvertDecodeResult(decode_result); - })); + .WillOnce(Invoke([&decode_result, &received_request]( + uint64_t endpoint_id, uint64_t connection_id, + msgs::Type message_type, const uint8_t* buffer, + size_t buffer_size, Clock::time_point now) { + decode_result = msgs::DecodePresentationConnectionOpenRequest( + buffer, buffer_size, &received_request); + return ConvertDecodeResult(decode_result); + })); demuxer.OnStreamData(endpoint_id_, connection_id_, buffer_.data(), buffer_.size()); ExpectDecodedRequest(decode_result, received_request); diff --git a/chromium/third_party/openscreen/src/osp/public/presentation/presentation_connection.h b/chromium/third_party/openscreen/src/osp/public/presentation/presentation_connection.h index 5382fec3eba..293e942ff4b 100644 --- a/chromium/third_party/openscreen/src/osp/public/presentation/presentation_connection.h +++ b/chromium/third_party/openscreen/src/osp/public/presentation/presentation_connection.h @@ -196,7 +196,7 @@ class ConnectionManager final : public MessageDemuxer::MessageCallback { msgs::Type message_type, const uint8_t* buffer, size_t buffer_size, - platform::Clock::time_point now) override; + Clock::time_point now) override; Connection* GetConnection(uint64_t connection_id); diff --git a/chromium/third_party/openscreen/src/osp/public/presentation/presentation_controller.h b/chromium/third_party/openscreen/src/osp/public/presentation/presentation_controller.h index 6d51329d1ba..d0a5ee8197c 100644 --- a/chromium/third_party/openscreen/src/osp/public/presentation/presentation_controller.h +++ b/chromium/third_party/openscreen/src/osp/public/presentation/presentation_controller.h @@ -96,7 +96,7 @@ class Controller final : public ServiceListener::Observer, Controller* controller_; }; - explicit Controller(platform::ClockNowFunctionPtr now_function); + explicit Controller(ClockNowFunctionPtr now_function); ~Controller(); // Requests receivers compatible with all urls in |urls| and registers diff --git a/chromium/third_party/openscreen/src/osp/public/presentation/presentation_receiver.h b/chromium/third_party/openscreen/src/osp/public/presentation/presentation_receiver.h index 3f87c0c5418..4eb4a04c26b 100644 --- a/chromium/third_party/openscreen/src/osp/public/presentation/presentation_receiver.h +++ b/chromium/third_party/openscreen/src/osp/public/presentation/presentation_receiver.h @@ -97,7 +97,7 @@ class Receiver final : public MessageDemuxer::MessageCallback, msgs::Type message_type, const uint8_t* buffer, size_t buffer_size, - platform::Clock::time_point now) override; + Clock::time_point now) override; private: struct QueuedResponse { diff --git a/chromium/third_party/openscreen/src/osp/public/protocol_connection_client_factory.h b/chromium/third_party/openscreen/src/osp/public/protocol_connection_client_factory.h index 8672c16b7eb..e30d6216634 100644 --- a/chromium/third_party/openscreen/src/osp/public/protocol_connection_client_factory.h +++ b/chromium/third_party/openscreen/src/osp/public/protocol_connection_client_factory.h @@ -11,9 +11,7 @@ namespace openscreen { -namespace platform { class TaskRunner; -} // namespace platform namespace osp { @@ -22,7 +20,7 @@ class ProtocolConnectionClientFactory { static std::unique_ptr<ProtocolConnectionClient> Create( MessageDemuxer* demuxer, ProtocolConnectionServiceObserver* observer, - platform::TaskRunner* task_runner); + TaskRunner* task_runner); }; } // namespace osp diff --git a/chromium/third_party/openscreen/src/osp/public/protocol_connection_server_factory.h b/chromium/third_party/openscreen/src/osp/public/protocol_connection_server_factory.h index 0f11b8a2752..0e55eeddc5e 100644 --- a/chromium/third_party/openscreen/src/osp/public/protocol_connection_server_factory.h +++ b/chromium/third_party/openscreen/src/osp/public/protocol_connection_server_factory.h @@ -12,9 +12,7 @@ namespace openscreen { -namespace platform { class TaskRunner; -} namespace osp { @@ -24,7 +22,7 @@ class ProtocolConnectionServerFactory { const ServerConfig& config, MessageDemuxer* demuxer, ProtocolConnectionServer::Observer* observer, - platform::TaskRunner* task_runner); + TaskRunner* task_runner); }; } // namespace osp diff --git a/chromium/third_party/openscreen/src/osp/public/server_config.h b/chromium/third_party/openscreen/src/osp/public/server_config.h index 579fe36166e..e4f95d8a6c1 100644 --- a/chromium/third_party/openscreen/src/osp/public/server_config.h +++ b/chromium/third_party/openscreen/src/osp/public/server_config.h @@ -20,8 +20,8 @@ struct ServerConfig { // The indexes of network interfaces that should be used by the Open Screen // Library. The indexes derive from the values of - // openscreen::platform::InterfaceInfo::index. - std::vector<platform::NetworkInterfaceIndex> interface_indexes; + // openscreen::InterfaceInfo::index. + std::vector<NetworkInterfaceIndex> interface_indexes; // The list of connection endpoints that are advertised for Open Screen // protocol connections. These must be reachable via one interface in diff --git a/chromium/third_party/openscreen/src/osp/public/service_info.cc b/chromium/third_party/openscreen/src/osp/public/service_info.cc index ebd41b2779f..77943ca46e1 100644 --- a/chromium/third_party/openscreen/src/osp/public/service_info.cc +++ b/chromium/third_party/openscreen/src/osp/public/service_info.cc @@ -23,11 +23,10 @@ bool ServiceInfo::operator!=(const ServiceInfo& other) const { return !(*this == other); } -bool ServiceInfo::Update( - std::string new_friendly_name, - platform::NetworkInterfaceIndex new_network_interface_index, - const IPEndpoint& new_v4_endpoint, - const IPEndpoint& new_v6_endpoint) { +bool ServiceInfo::Update(std::string new_friendly_name, + NetworkInterfaceIndex new_network_interface_index, + const IPEndpoint& new_v4_endpoint, + const IPEndpoint& new_v6_endpoint) { OSP_DCHECK(!new_v4_endpoint.address || IPAddress::Version::kV4 == new_v4_endpoint.address.version()); OSP_DCHECK(!new_v6_endpoint.address || diff --git a/chromium/third_party/openscreen/src/osp/public/service_info.h b/chromium/third_party/openscreen/src/osp/public/service_info.h index df9d8e3638d..73e43eb6303 100644 --- a/chromium/third_party/openscreen/src/osp/public/service_info.h +++ b/chromium/third_party/openscreen/src/osp/public/service_info.h @@ -29,7 +29,7 @@ struct ServiceInfo { bool operator!=(const ServiceInfo& other) const; bool Update(std::string friendly_name, - platform::NetworkInterfaceIndex network_interface_index, + NetworkInterfaceIndex network_interface_index, const IPEndpoint& v4_endpoint, const IPEndpoint& v6_endpoint); @@ -40,8 +40,7 @@ struct ServiceInfo { std::string friendly_name; // The index of the network interface that the screen was discovered on. - platform::NetworkInterfaceIndex network_interface_index = - platform::kInvalidNetworkInterfaceIndex; + NetworkInterfaceIndex network_interface_index = kInvalidNetworkInterfaceIndex; // The network endpoints to create a new connection to the Open Screen // service. diff --git a/chromium/third_party/openscreen/src/osp/public/service_listener.h b/chromium/third_party/openscreen/src/osp/public/service_listener.h index 5d81536d429..2a11e44c42b 100644 --- a/chromium/third_party/openscreen/src/osp/public/service_listener.h +++ b/chromium/third_party/openscreen/src/osp/public/service_listener.h @@ -5,7 +5,6 @@ #ifndef OSP_PUBLIC_SERVICE_LISTENER_H_ #define OSP_PUBLIC_SERVICE_LISTENER_H_ -#include <atomic> #include <cstdint> #include <string> #include <vector> @@ -150,7 +149,7 @@ class ServiceListener { protected: ServiceListener(); - std::atomic<State> state_; + State state_; ServiceListenerError last_error_; std::vector<Observer*> observers_; diff --git a/chromium/third_party/openscreen/src/osp/public/service_publisher.h b/chromium/third_party/openscreen/src/osp/public/service_publisher.h index a5fac6a23fe..b31f59fc066 100644 --- a/chromium/third_party/openscreen/src/osp/public/service_publisher.h +++ b/chromium/third_party/openscreen/src/osp/public/service_publisher.h @@ -5,7 +5,6 @@ #ifndef OSP_PUBLIC_SERVICE_PUBLISHER_H_ #define OSP_PUBLIC_SERVICE_PUBLISHER_H_ -#include <atomic> #include <cstdint> #include <string> #include <vector> @@ -108,7 +107,7 @@ class ServicePublisher { // By default, all enabled Ethernet and WiFi interfaces are used. // This configuration must be identical to the interfaces configured // in the ScreenConnectionServer. - std::vector<platform::NetworkInterfaceIndex> network_interface_indices; + std::vector<NetworkInterfaceIndex> network_interface_indices; }; virtual ~ServicePublisher(); @@ -145,7 +144,7 @@ class ServicePublisher { protected: explicit ServicePublisher(Observer* observer); - std::atomic<State> state_; + State state_; ServicePublisherError last_error_; Observer* observer_; diff --git a/chromium/third_party/openscreen/src/osp/public/testing/message_demuxer_test_support.h b/chromium/third_party/openscreen/src/osp/public/testing/message_demuxer_test_support.h index 7ba1e36b4ff..cde458c22f2 100644 --- a/chromium/third_party/openscreen/src/osp/public/testing/message_demuxer_test_support.h +++ b/chromium/third_party/openscreen/src/osp/public/testing/message_demuxer_test_support.h @@ -22,7 +22,7 @@ class MockMessageCallback final : public MessageDemuxer::MessageCallback { msgs::Type message_type, const uint8_t* buffer, size_t buffer_size, - platform::Clock::time_point now)); + Clock::time_point now)); }; } // namespace osp diff --git a/chromium/third_party/openscreen/src/platform/BUILD.gn b/chromium/third_party/openscreen/src/platform/BUILD.gn index bd24e0b0090..9545232cf87 100644 --- a/chromium/third_party/openscreen/src/platform/BUILD.gn +++ b/chromium/third_party/openscreen/src/platform/BUILD.gn @@ -4,26 +4,11 @@ import("//build_overrides/build.gni") -source_set("platform") { +# Source files that depend on nothing (all your base/ are belong to us). +source_set("base") { defines = [] sources = [ - "api/logging.h", - "api/network_interface.h", - "api/scoped_wake_lock.cc", - "api/scoped_wake_lock.h", - "api/socket.h", - "api/task_runner.h", - "api/time.h", - "api/tls_connection.cc", - "api/tls_connection.h", - "api/tls_connection_factory.cc", - "api/tls_connection_factory.h", - "api/trace_logging_platform.cc", - "api/trace_logging_platform.h", - "api/udp_read_callback.h", - "api/udp_socket.cc", - "api/udp_socket.h", "base/error.cc", "base/error.h", "base/interface_info.cc", @@ -46,20 +31,49 @@ source_set("platform") { "base/udp_packet.h", ] + public_configs = [ "../build:openscreen_include_dirs" ] +} + +# Public API source files. These may depend on nothing except :base. +source_set("api") { + defines = [] + + sources = [ + "api/export.h", + "api/logging.h", + "api/network_interface.h", + "api/scoped_wake_lock.cc", + "api/scoped_wake_lock.h", + "api/task_runner.h", + "api/time.h", + "api/tls_connection.cc", + "api/tls_connection.h", + "api/tls_connection_factory.cc", + "api/tls_connection_factory.h", + "api/trace_logging_platform.cc", + "api/trace_logging_platform.h", + "api/udp_socket.cc", + "api/udp_socket.h", + ] + public_deps = [ - "../third_party/abseil", - "../third_party/boringssl", - "../util", + ":base", ] +} - public_configs = [ "../build:openscreen_include_dirs" ] +# The following target is only activated in standalone builds (see :platform). +if (!build_with_chromium) { + source_set("standalone_impl") { + defines = [] - if (!build_with_chromium) { - sources += [ + sources = [ "impl/logging.h", + "impl/network_interface.cc", + "impl/network_interface.h", "impl/socket_handle.h", "impl/socket_handle_waiter.cc", "impl/socket_handle_waiter.h", + "impl/stream_socket.h", "impl/task_runner.cc", "impl/task_runner.h", "impl/text_trace_logging_platform.cc", @@ -67,11 +81,14 @@ source_set("platform") { "impl/time.cc", "impl/tls_write_buffer.cc", "impl/tls_write_buffer.h", - "impl/weak_ptr.h", ] if (is_linux) { - sources += [ "impl/network_interface_linux.cc" ] + sources += [ + "impl/network_interface_linux.cc", + "impl/scoped_wake_lock_linux.cc", + "impl/scoped_wake_lock_linux.h", + ] } else if (is_mac) { defines += [ # Required, to use the new IPv6 Sockets options introduced by RFC 3542. @@ -102,7 +119,6 @@ source_set("platform") { "impl/socket_handle_posix.h", "impl/socket_handle_waiter_posix.cc", "impl/socket_handle_waiter_posix.h", - "impl/stream_socket.h", "impl/stream_socket_posix.cc", "impl/stream_socket_posix.h", "impl/timeval_posix.cc", @@ -119,9 +135,30 @@ source_set("platform") { "impl/udp_socket_reader_posix.h", ] } + + deps = [ + ":api", + "../third_party/abseil", + "../third_party/boringssl", + "../util", + ] } } +# The main target, which either assumes an embedder will link-in the platform +# API implementation elsewhere, or links-in the :standalone_impl in the build. +source_set("platform") { + public_deps = [ + ":api", + ] + if (!build_with_chromium) { + deps = [ + ":standalone_impl", + ] + } +} + +# Test helpers, referenced in other Open Screen BUILD.gn test targets. source_set("test") { testonly = true sources = [ @@ -132,12 +169,15 @@ source_set("test") { "test/fake_udp_socket.cc", "test/fake_udp_socket.h", "test/mock_tls_connection.h", + "test/mock_udp_socket.h", "test/trace_logging_helpers.h", ] deps = [ ":platform", + "../third_party/abseil", "../third_party/googletest:gmock", + "../util", ] } @@ -145,6 +185,7 @@ source_set("unittests") { testonly = true sources = [ + "api/socket_integration_unittest.cc", "api/time_unittest.cc", "base/error_unittest.cc", "base/ip_address_unittest.cc", @@ -155,12 +196,8 @@ source_set("unittests") { # Exclude them if an embedder is providing the implementation. if (!build_with_chromium) { sources += [ - # TODO(jophba): move over to general sources when UDP socket create - # is implemented in Chromium, as part of the NetworkRunner work. - "api/socket_integration_unittest.cc", "impl/task_runner_unittest.cc", "impl/time_unittest.cc", - "impl/weak_ptr_unittest.cc", ] if (is_posix) { @@ -178,7 +215,11 @@ source_set("unittests") { deps = [ ":platform", + ":test", + "../third_party/abseil", + "../third_party/boringssl", "../third_party/googletest:gmock", "../third_party/googletest:gtest", + "../util", ] } diff --git a/chromium/third_party/openscreen/src/platform/api/DEPS b/chromium/third_party/openscreen/src/platform/api/DEPS index f0fc67f29bb..f833653aeb9 100644 --- a/chromium/third_party/openscreen/src/platform/api/DEPS +++ b/chromium/third_party/openscreen/src/platform/api/DEPS @@ -3,7 +3,20 @@ # found in the LICENSE file. include_rules = [ + # Platform API code should depend on no outside code/libraries, other than + # the standard toolchain libraries (C, STL) and platform/base. + '-absl', + '-platform', + '+platform/api', '+platform/base', '-util', '-third_party', ] + +specific_include_rules = { + ".*_unittest\.cc": [ + '+platform/test', + '+util', + '+third_party', + ], +} diff --git a/chromium/third_party/openscreen/src/platform/api/logging.h b/chromium/third_party/openscreen/src/platform/api/logging.h index 9ff79d1a75a..6080b77c8b2 100644 --- a/chromium/third_party/openscreen/src/platform/api/logging.h +++ b/chromium/third_party/openscreen/src/platform/api/logging.h @@ -5,10 +5,9 @@ #ifndef PLATFORM_API_LOGGING_H_ #define PLATFORM_API_LOGGING_H_ -#include "absl/strings/string_view.h" +#include <sstream> namespace openscreen { -namespace platform { enum class LogLevel { // Very detailed information, often used for evaluating performance or @@ -36,17 +35,20 @@ enum class LogLevel { // Returns true if |level| is at or above the level where the embedder will // record/emit log entries from the code in |file|. -bool IsLoggingOn(LogLevel level, absl::string_view file); +bool IsLoggingOn(LogLevel level, const char* file); // Record a log entry, consisting of its logging level, location and message. // The embedder may filter-out entries according to its own policy, but this // function will not be called if IsLoggingOn(level, file) returns false. // Whenever |level| is kFatal, Open Screen will call Break() immediately after // this returns. +// +// |message| is passed as a string stream to avoid unnecessary string copies. +// Embedders can call its rdbuf() or str() methods to access the log message. void LogWithLevel(LogLevel level, - absl::string_view file, + const char* file, int line, - absl::string_view msg); + std::stringstream message); // Breaks into the debugger, if one is present. Otherwise, aborts the current // process (i.e., this function should not return). In production builds, an @@ -55,7 +57,6 @@ void LogWithLevel(LogLevel level, // aborting the process. void Break(); -} // namespace platform } // namespace openscreen #endif // PLATFORM_API_LOGGING_H_ diff --git a/chromium/third_party/openscreen/src/platform/api/network_interface.h b/chromium/third_party/openscreen/src/platform/api/network_interface.h index f0e53e17b11..28ebd2fed11 100644 --- a/chromium/third_party/openscreen/src/platform/api/network_interface.h +++ b/chromium/third_party/openscreen/src/platform/api/network_interface.h @@ -10,7 +10,6 @@ #include "platform/base/interface_info.h" namespace openscreen { -namespace platform { // Returns an InterfaceInfo for each currently active network interface on the // system. No two entries in this vector can have the same NetworkInterfaceIndex @@ -22,7 +21,6 @@ namespace platform { // discovery) are not being used. std::vector<InterfaceInfo> GetNetworkInterfaces(); -} // namespace platform } // namespace openscreen #endif // PLATFORM_API_NETWORK_INTERFACE_H_ diff --git a/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.cc b/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.cc index ad8442e091f..73a946e1e9a 100644 --- a/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.cc +++ b/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.cc @@ -5,10 +5,8 @@ #include "platform/api/scoped_wake_lock.h" namespace openscreen { -namespace platform { ScopedWakeLock::ScopedWakeLock() = default; ScopedWakeLock::~ScopedWakeLock() = default; -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.h b/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.h index 64d69ecb867..2843be4eea5 100644 --- a/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.h +++ b/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.h @@ -8,7 +8,6 @@ #include <memory> namespace openscreen { -namespace platform { // Ensures that the device does not got to sleep. This is used, for example, // while Open Screen is communicating with peers over the network for things @@ -33,7 +32,6 @@ class ScopedWakeLock { virtual ~ScopedWakeLock(); }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_API_SCOPED_WAKE_LOCK_H_ diff --git a/chromium/third_party/openscreen/src/platform/api/socket_integration_unittest.cc b/chromium/third_party/openscreen/src/platform/api/socket_integration_unittest.cc index fe15377ba56..3704dbcbee3 100644 --- a/chromium/third_party/openscreen/src/platform/api/socket_integration_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/api/socket_integration_unittest.cc @@ -9,7 +9,6 @@ #include "platform/test/fake_udp_socket.h" namespace openscreen { -namespace platform { using testing::_; @@ -21,7 +20,7 @@ TEST(SocketIntegrationTest, ResolvesLocalEndpoint_IPv4) { FakeClock clock(Clock::now()); FakeTaskRunner task_runner(&clock); FakeUdpSocket::MockClient client; - ErrorOr<UdpSocketUniquePtr> create_result = UdpSocket::Create( + ErrorOr<std::unique_ptr<UdpSocket>> create_result = UdpSocket::Create( &task_runner, &client, IPEndpoint{IPAddress(kIpV4AddrAny), 0}); ASSERT_TRUE(create_result) << create_result.error(); const auto socket = std::move(create_result.value()); @@ -35,11 +34,11 @@ TEST(SocketIntegrationTest, ResolvesLocalEndpoint_IPv4) { // successfully Bind(), and that the operating system will return the // auto-assigned socket name (i.e., the local endpoint's port will not be zero). TEST(SocketIntegrationTest, ResolvesLocalEndpoint_IPv6) { - const uint8_t kIpV6AddrAny[16] = {}; + const uint16_t kIpV6AddrAny[8] = {}; FakeClock clock(Clock::now()); FakeTaskRunner task_runner(&clock); FakeUdpSocket::MockClient client; - ErrorOr<UdpSocketUniquePtr> create_result = UdpSocket::Create( + ErrorOr<std::unique_ptr<UdpSocket>> create_result = UdpSocket::Create( &task_runner, &client, IPEndpoint{IPAddress(kIpV6AddrAny), 0}); ASSERT_TRUE(create_result) << create_result.error(); const auto socket = std::move(create_result.value()); @@ -49,5 +48,4 @@ TEST(SocketIntegrationTest, ResolvesLocalEndpoint_IPv6) { EXPECT_NE(local_endpoint.port, 0) << local_endpoint; } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/api/task_runner.h b/chromium/third_party/openscreen/src/platform/api/task_runner.h index c94777fd0a6..e114db68860 100644 --- a/chromium/third_party/openscreen/src/platform/api/task_runner.h +++ b/chromium/third_party/openscreen/src/platform/api/task_runner.h @@ -5,21 +5,21 @@ #ifndef PLATFORM_API_TASK_RUNNER_H_ #define PLATFORM_API_TASK_RUNNER_H_ -#include <future> +#include <future> // NOLINT +#include <utility> -#include "absl/types/optional.h" #include "platform/api/time.h" namespace openscreen { -namespace platform { // A thread-safe API surface that allows for posting tasks. The underlying // implementation may be single or multi-threaded, and all complication should -// be handled by the implementation class. It is the expectation of this API -// that the underlying impl gives the following guarantees: +// be handled by the implementation class. The implementation must guarantee: // (1) Tasks shall not overlap in time/CPU. // (2) Tasks shall run sequentially, e.g. posting task A then B implies // that A shall run before B. +// (3) If task A is posted before task B, then any mutation in A happens-before +// B runs (even if A and B run on different threads). class TaskRunner { public: using Task = std::packaged_task<void() noexcept>; @@ -50,10 +50,9 @@ class TaskRunner { // Return true if the calling thread is the thread that task runner is using // to run tasks, false otherwise. - virtual bool IsRunningOnTaskRunner() { return true; } + virtual bool IsRunningOnTaskRunner() = 0; }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_API_TASK_RUNNER_H_ diff --git a/chromium/third_party/openscreen/src/platform/api/time.h b/chromium/third_party/openscreen/src/platform/api/time.h index af0dda7573d..dcf24accc88 100644 --- a/chromium/third_party/openscreen/src/platform/api/time.h +++ b/chromium/third_party/openscreen/src/platform/api/time.h @@ -10,7 +10,6 @@ #include "platform/base/trivial_clock_traits.h" namespace openscreen { -namespace platform { // The "reasonably high-resolution" source of monotonic time from the embedder, // exhibiting the traits described in TrivialClockTraits. This class is not @@ -33,7 +32,6 @@ class Clock : public TrivialClockTraits { // time." std::chrono::seconds GetWallTimeSinceUnixEpoch() noexcept; -} // namespace platform } // namespace openscreen #endif // PLATFORM_API_TIME_H_ diff --git a/chromium/third_party/openscreen/src/platform/api/time_unittest.cc b/chromium/third_party/openscreen/src/platform/api/time_unittest.cc index 9a4f40cbd10..2bef8250f04 100644 --- a/chromium/third_party/openscreen/src/platform/api/time_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/api/time_unittest.cc @@ -12,7 +12,6 @@ using std::chrono::microseconds; using std::chrono::milliseconds; namespace openscreen { -namespace platform { namespace { // Tests that the clock always seems to tick forward. If this test is broken, or @@ -56,5 +55,4 @@ TEST(TimeTest, PlatformClockHasSufficientResolution) { } } // namespace -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/api/tls_connection.cc b/chromium/third_party/openscreen/src/platform/api/tls_connection.cc index fc3a8c8ebef..9668c114512 100644 --- a/chromium/third_party/openscreen/src/platform/api/tls_connection.cc +++ b/chromium/third_party/openscreen/src/platform/api/tls_connection.cc @@ -5,10 +5,8 @@ #include "platform/api/tls_connection.h" namespace openscreen { -namespace platform { TlsConnection::TlsConnection() = default; TlsConnection::~TlsConnection() = default; -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/api/tls_connection.h b/chromium/third_party/openscreen/src/platform/api/tls_connection.h index 7c2777fb204..4d409cb54f9 100644 --- a/chromium/third_party/openscreen/src/platform/api/tls_connection.h +++ b/chromium/third_party/openscreen/src/platform/api/tls_connection.h @@ -12,20 +12,12 @@ #include "platform/base/ip_address.h" namespace openscreen { -namespace platform { class TlsConnection { public: // Client callbacks are run via the TaskRunner used by TlsConnectionFactory. class Client { public: - // Called when |connection| writing is blocked and unblocked, respectively. - // Note that implementations should do best effort to buffer packets even in - // blocked state, and should call OnError if we actually overflow the - // buffer. - virtual void OnWriteBlocked(TlsConnection* connection) = 0; - virtual void OnWriteUnblocked(TlsConnection* connection) = 0; - // Called when |connection| experiences an error, such as a read error. virtual void OnError(TlsConnection* connection, Error error) = 0; @@ -45,8 +37,8 @@ class TlsConnection { // the Client. virtual void SetClient(Client* client) = 0; - // Sends a message. - virtual void Write(const void* data, size_t len) = 0; + // Sends a message. Returns true iff the message will be sent. + [[nodiscard]] virtual bool Send(const void* data, size_t len) = 0; // Get the local address. virtual IPEndpoint GetLocalEndpoint() const = 0; @@ -58,7 +50,6 @@ class TlsConnection { TlsConnection(); }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_API_TLS_CONNECTION_H_ diff --git a/chromium/third_party/openscreen/src/platform/api/tls_connection_factory.cc b/chromium/third_party/openscreen/src/platform/api/tls_connection_factory.cc index a3bfb287962..e64078f1727 100644 --- a/chromium/third_party/openscreen/src/platform/api/tls_connection_factory.cc +++ b/chromium/third_party/openscreen/src/platform/api/tls_connection_factory.cc @@ -5,10 +5,8 @@ #include "platform/api/tls_connection_factory.h" namespace openscreen { -namespace platform { TlsConnectionFactory::TlsConnectionFactory() = default; TlsConnectionFactory::~TlsConnectionFactory() = default; -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/api/tls_connection_factory.h b/chromium/third_party/openscreen/src/platform/api/tls_connection_factory.h index eed6893f404..80dc8ac6367 100644 --- a/chromium/third_party/openscreen/src/platform/api/tls_connection_factory.h +++ b/chromium/third_party/openscreen/src/platform/api/tls_connection_factory.h @@ -13,7 +13,6 @@ #include "platform/base/ip_address.h" namespace openscreen { -namespace platform { class TaskRunner; class TlsConnection; @@ -30,7 +29,7 @@ class TlsConnectionFactory { public: // Provides a new |connection| that resulted from listening on the local // socket. |der_x509_peer_cert| is the DER-encoded X509 certificate from the - // peer. + // peer if present, or empty if the peer didn't provide one. virtual void OnAccepted(TlsConnectionFactory* factory, std::vector<uint8_t> der_x509_peer_cert, std::unique_ptr<TlsConnection> connection) = 0; @@ -75,7 +74,6 @@ class TlsConnectionFactory { TlsConnectionFactory(); }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_API_TLS_CONNECTION_FACTORY_H_ diff --git a/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.cc b/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.cc index 2725b404958..35be2bf329b 100644 --- a/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.cc +++ b/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.cc @@ -5,9 +5,7 @@ #include "platform/api/trace_logging_platform.h" namespace openscreen { -namespace platform { TraceLoggingPlatform::~TraceLoggingPlatform() = default; -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.h b/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.h index 76c7e7a974d..882c1a7db6c 100644 --- a/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.h +++ b/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.h @@ -11,12 +11,13 @@ #include "platform/base/trace_logging_types.h" namespace openscreen { -namespace platform { // Optional platform API to support logging trace events from Open Screen. To // use this, implement the TraceLoggingPlatform interface and call // StartTracing() and StopTracing() to turn tracing on/off (see // platform/base/trace_logging_activation.h). +// +// All methods must be thread-safe and re-entrant. class TraceLoggingPlatform { public: virtual ~TraceLoggingPlatform(); @@ -48,7 +49,6 @@ class TraceLoggingPlatform { Error::Code error) = 0; }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_API_TRACE_LOGGING_PLATFORM_H_ diff --git a/chromium/third_party/openscreen/src/platform/api/udp_socket.cc b/chromium/third_party/openscreen/src/platform/api/udp_socket.cc index 60262cab8c4..2da7a73f429 100644 --- a/chromium/third_party/openscreen/src/platform/api/udp_socket.cc +++ b/chromium/third_party/openscreen/src/platform/api/udp_socket.cc @@ -5,10 +5,8 @@ #include "platform/api/udp_socket.h" namespace openscreen { -namespace platform { UdpSocket::UdpSocket() = default; UdpSocket::~UdpSocket() = default; -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/api/udp_socket.h b/chromium/third_party/openscreen/src/platform/api/udp_socket.h index e76b9eb8586..e668db12a50 100644 --- a/chromium/third_party/openscreen/src/platform/api/udp_socket.h +++ b/chromium/third_party/openscreen/src/platform/api/udp_socket.h @@ -5,11 +5,10 @@ #ifndef PLATFORM_API_UDP_SOCKET_H_ #define PLATFORM_API_UDP_SOCKET_H_ -#include <atomic> -#include <cstdint> -#include <functional> +#include <stddef.h> // size_t +#include <stdint.h> // uint8_t + #include <memory> -#include <mutex> #include "platform/api/network_interface.h" #include "platform/base/error.h" @@ -17,12 +16,8 @@ #include "platform/base/udp_packet.h" namespace openscreen { -namespace platform { class TaskRunner; -class UdpSocket; - -using UdpSocketUniquePtr = std::unique_ptr<UdpSocket>; // An open UDP socket for sending/receiving datagrams to/from either specific // endpoints or over IP multicast. @@ -75,9 +70,10 @@ class UdpSocket { // will be queued on the provided |task_runner|. For this reason, the provided // TaskRunner and Client must exist for the duration of the created socket's // lifetime. - static ErrorOr<UdpSocketUniquePtr> Create(TaskRunner* task_runner, - Client* client, - const IPEndpoint& local_endpoint); + static ErrorOr<std::unique_ptr<UdpSocket>> Create( + TaskRunner* task_runner, + Client* client, + const IPEndpoint& local_endpoint); virtual ~UdpSocket(); @@ -119,7 +115,6 @@ class UdpSocket { UdpSocket(); }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_API_UDP_SOCKET_H_ diff --git a/chromium/third_party/openscreen/src/platform/base/DEPS b/chromium/third_party/openscreen/src/platform/base/DEPS index 46dad90c1db..f4c412d5eec 100644 --- a/chromium/third_party/openscreen/src/platform/base/DEPS +++ b/chromium/third_party/openscreen/src/platform/base/DEPS @@ -3,13 +3,18 @@ # found in the LICENSE file. include_rules = [ - '-platform' + # Platform base code should depend on no outside code/libraries, other than + # the standard toolchain libraries (C, STL). + '-absl', + '-platform', + '+platform/base', '-util', '-third_party', ] specific_include_rules = { ".*_unittest\.cc": [ + '+platform/test', '+util', '+third_party', ], diff --git a/chromium/third_party/openscreen/src/platform/base/error.cc b/chromium/third_party/openscreen/src/platform/base/error.cc index fd300c1abb3..ea58c91ba49 100644 --- a/chromium/third_party/openscreen/src/platform/base/error.cc +++ b/chromium/third_party/openscreen/src/platform/base/error.cc @@ -4,6 +4,8 @@ #include "platform/base/error.h" +#include <sstream> + namespace openscreen { Error::Error() = default; @@ -120,6 +122,8 @@ std::ostream& operator<<(std::ostream& os, const Error::Code& code) { return os << "Failure: FatalSSLError"; case Error::Code::kRSAKeyGenerationFailure: return os << "Failure: RSAKeyGenerationFailure"; + case Error::Code::kRSAKeyParseError: + return os << "Failure: RSAKeyParseError"; case Error::Code::kEVPInitializationError: return os << "Failure: EVPInitializationError"; case Error::Code::kCertificateCreationError: @@ -166,6 +170,8 @@ std::ostream& operator<<(std::ostream& os, const Error::Code& code) { return os << "Failure: ItemNotFound"; case Error::Code::kOperationInvalid: return os << "Failure: OperationInvalid"; + case Error::Code::kOperationInProgress: + return os << "Failure: OperationInProgress"; case Error::Code::kOperationCancelled: return os << "Failure: OperationCancelled"; case Error::Code::kCastV2PeerCertEmpty: @@ -220,12 +226,24 @@ std::ostream& operator<<(std::ostream& os, const Error::Code& code) { return os << "Failure: kCastV2PingTimeout"; case Error::Code::kCastV2ChannelPolicyMismatch: return os << "Failure: kCastV2ChannelPolicyMismatch"; + case Error::Code::kCreateSignatureFailed: + return os << "Failure: kCreateSignatureFailed"; + case Error::Code::kUpdateReceivedRecordFailure: + return os << "Failure: kUpdateReceivedRecordFailure"; + case Error::Code::kRecordPublicationError: + return os << "Failure: kRecordPublicationError"; } // Unused 'return' to get around failure on GCC. return os; } +std::string Error::ToString() const { + std::stringstream ss; + ss << *this; + return ss.str(); +} + std::ostream& operator<<(std::ostream& out, const Error& error) { out << error.code() << " = \"" << error.message() << "\""; return out; diff --git a/chromium/third_party/openscreen/src/platform/base/error.h b/chromium/third_party/openscreen/src/platform/base/error.h index aedbf29a2a0..aa79a81567c 100644 --- a/chromium/third_party/openscreen/src/platform/base/error.h +++ b/chromium/third_party/openscreen/src/platform/base/error.h @@ -5,11 +5,11 @@ #ifndef PLATFORM_BASE_ERROR_H_ #define PLATFORM_BASE_ERROR_H_ +#include <cassert> #include <ostream> #include <string> #include <utility> -#include "absl/types/variant.h" #include "platform/base/macros.h" namespace openscreen { @@ -81,6 +81,7 @@ class Error { // Was unable to generate an RSA key. kRSAKeyGenerationFailure, + kRSAKeyParseError, // Was unable to initialize an EVP_PKEY type. kEVPInitializationError, @@ -155,6 +156,12 @@ class Error { kCastV2PingTimeout, kCastV2ChannelPolicyMismatch, + kCreateSignatureFailed, + + // Discovery errors. + kUpdateReceivedRecordFailure, + kRecordPublicationError, + // Generic errors. kUnknownError, kNotImplemented, @@ -166,6 +173,7 @@ class Error { kItemAlreadyExists, kItemNotFound, kOperationInvalid, + kOperationInProgress, kOperationCancelled, }; @@ -186,9 +194,12 @@ class Error { Code code() const { return code_; } const std::string& message() const { return message_; } + std::string& message() { return message_; } static const Error& None(); + std::string ToString() const; + private: Code code_ = Code::kNone; std::string message_; @@ -222,35 +233,89 @@ class ErrorOr { return error; } - ErrorOr(ErrorOr&& other) = default; - ErrorOr& operator=(ErrorOr&& other) = default; + ErrorOr(const ValueType& value) : value_(value), is_value_(true) {} // NOLINT + ErrorOr(ValueType&& value) noexcept // NOLINT + : value_(std::move(value)), is_value_(true) {} - ErrorOr(const ValueType& value) : variant_{value} {} // NOLINT - ErrorOr(ValueType&& value) noexcept // NOLINT - : variant_{std::move(value)} {} - - ErrorOr(Error error) : variant_{std::move(error)} {} // NOLINT - ErrorOr(Error::Code code) : variant_{code} {} // NOLINT + ErrorOr(const Error& error) : error_(error), is_value_(false) { // NOLINT + assert(error_.code() != Error::Code::kNone); + } + ErrorOr(Error&& error) noexcept // NOLINT + : error_(std::move(error)), is_value_(false) { + assert(error_.code() != Error::Code::kNone); + } + ErrorOr(Error::Code code) : error_(code), is_value_(false) { // NOLINT + assert(error_.code() != Error::Code::kNone); + } ErrorOr(Error::Code code, std::string message) - : variant_{Error{code, std::move(message)}} {} + : error_(code, std::move(message)), is_value_(false) { + assert(error_.code() != Error::Code::kNone); + } - ~ErrorOr() = default; + ErrorOr(ErrorOr&& other) noexcept : is_value_(other.is_value_) { + // NB: Both |value_| and |error_| are uninitialized memory at this point! + // Unlike the other constructors, the compiler will not auto-generate + // constructor calls for either union member because neither appeared in + // this constructor's initializer list. + if (other.is_value_) { + new (&value_) ValueType(std::move(other.value_)); + } else { + new (&error_) Error(std::move(other.error_)); + } + } - bool is_error() const { return absl::holds_alternative<Error>(variant_); } - bool is_value() const { return !is_error(); } + ErrorOr& operator=(ErrorOr&& other) noexcept { + this->~ErrorOr<ValueType>(); + new (this) ErrorOr<ValueType>(std::move(other)); + return *this; + } + + ~ErrorOr() { + // NB: |value_| or |error_| must be explicitly destroyed since the compiler + // will not auto-generate the destructor calls for union members. + if (is_value_) { + value_.~ValueType(); + } else { + error_.~Error(); + } + } + + bool is_error() const { return !is_value_; } + bool is_value() const { return is_value_; } // Unlike Error, we CAN provide an operator bool here, since it is // more obvious to callers that ErrorOr<Foo> will be true if it's Foo. - operator bool() const { return is_value(); } + operator bool() const { return is_value_; } - const Error& error() const { return absl::get<Error>(variant_); } - Error& error() { return absl::get<Error>(variant_); } + const Error& error() const { + assert(!is_value_); + return error_; + } + Error& error() { + assert(!is_value_); + return error_; + } - const ValueType& value() const { return absl::get<ValueType>(variant_); } - ValueType& value() { return absl::get<ValueType>(variant_); } + const ValueType& value() const { + assert(is_value_); + return value_; + } + ValueType& value() { + assert(is_value_); + return value_; + } private: - absl::variant<Error, ValueType> variant_; + // Only one of these is an active member, determined by |is_value_|. Since + // they are union'ed, they must be explicitly constructed and destroyed. + union { + ValueType value_; + Error error_; + }; + + // If true, |value_| is initialized and active. Otherwise, |error_| is + // initialized and active. + const bool is_value_; }; } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/base/interface_info.cc b/chromium/third_party/openscreen/src/platform/base/interface_info.cc index b8baff97672..2ada91be943 100644 --- a/chromium/third_party/openscreen/src/platform/base/interface_info.cc +++ b/chromium/third_party/openscreen/src/platform/base/interface_info.cc @@ -4,8 +4,9 @@ #include "platform/base/interface_info.h" +#include <algorithm> + namespace openscreen { -namespace platform { InterfaceInfo::InterfaceInfo() = default; InterfaceInfo::InterfaceInfo(NetworkInterfaceIndex index, @@ -27,6 +28,24 @@ IPSubnet::IPSubnet(IPAddress address, uint8_t prefix_length) : address(std::move(address)), prefix_length(prefix_length) {} IPSubnet::~IPSubnet() = default; +IPAddress InterfaceInfo::GetIpAddressV4() const { + for (const auto& address : addresses) { + if (address.address.IsV4()) { + return address.address; + } + } + return IPAddress{}; +} + +IPAddress InterfaceInfo::GetIpAddressV6() const { + for (const auto& address : addresses) { + if (address.address.IsV6()) { + return address.address; + } + } + return IPAddress{}; +} + std::ostream& operator<<(std::ostream& out, const IPSubnet& subnet) { if (subnet.address.IsV6()) { out << '['; @@ -38,23 +57,30 @@ std::ostream& operator<<(std::ostream& out, const IPSubnet& subnet) { return out << '/' << std::dec << static_cast<int>(subnet.prefix_length); } -std::ostream& operator<<(std::ostream& out, const InterfaceInfo& info) { - std::string media_type; - switch (info.type) { +std::ostream& operator<<(std::ostream& out, InterfaceInfo::Type type) { + switch (type) { case InterfaceInfo::Type::kEthernet: - media_type = "Ethernet"; + out << "Ethernet"; break; case InterfaceInfo::Type::kWifi: - media_type = "Wifi"; + out << "Wifi"; + break; + case InterfaceInfo::Type::kLoopback: + out << "Loopback"; break; case InterfaceInfo::Type::kOther: - media_type = "Other"; + out << "Other"; break; } + + return out; +} + +std::ostream& operator<<(std::ostream& out, const InterfaceInfo& info) { out << '{' << info.index << " (a.k.a. " << info.name - << "); media_type=" << media_type << "; MAC=" << std::hex + << "); media_type=" << info.type << "; MAC=" << std::hex << static_cast<int>(info.hardware_address[0]); - for (size_t i = 1; i < sizeof(info.hardware_address); ++i) { + for (size_t i = 1; i < info.hardware_address.size(); ++i) { out << ':' << static_cast<int>(info.hardware_address[i]); } for (const IPSubnet& ip : info.addresses) { @@ -63,5 +89,4 @@ std::ostream& operator<<(std::ostream& out, const InterfaceInfo& info) { return out << '}'; } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/base/interface_info.h b/chromium/third_party/openscreen/src/platform/base/interface_info.h index af0c35b6caa..81686063ba7 100644 --- a/chromium/third_party/openscreen/src/platform/base/interface_info.h +++ b/chromium/third_party/openscreen/src/platform/base/interface_info.h @@ -13,7 +13,6 @@ #include "platform/base/ip_address.h" namespace openscreen { -namespace platform { // Unique identifier, usually provided by the operating system, for identifying // a specific network interface. This value is used with UdpSocket to join @@ -40,18 +39,14 @@ struct IPSubnet { }; struct InterfaceInfo { - enum class Type { - kEthernet = 0, - kWifi, - kOther, - }; + enum class Type : uint32_t { kEthernet = 0, kWifi, kLoopback, kOther }; // Interface index, typically as specified by the operating system, // identifying this interface on the host machine. NetworkInterfaceIndex index = kInvalidNetworkInterfaceIndex; // MAC address of the interface. All 0s if unavailable. - uint8_t hardware_address[6] = {}; + std::array<uint8_t, 6> hardware_address = {}; // Interface name (e.g. eth0) if available. std::string name; @@ -62,6 +57,12 @@ struct InterfaceInfo { // All IP addresses associated with the interface. std::vector<IPSubnet> addresses; + // Returns an IPAddress of the given type associated with this network + // interface, or the false IPAddress if the associated address family is not + // supported on this interface. + IPAddress GetIpAddressV4() const; + IPAddress GetIpAddressV6() const; + InterfaceInfo(); InterfaceInfo(NetworkInterfaceIndex index, const uint8_t hardware_address[6], @@ -72,10 +73,10 @@ struct InterfaceInfo { }; // Human-readable output (e.g., for logging). +std::ostream& operator<<(std::ostream& out, InterfaceInfo::Type type); std::ostream& operator<<(std::ostream& out, const IPSubnet& subnet); std::ostream& operator<<(std::ostream& out, const InterfaceInfo& info); -} // namespace platform } // namespace openscreen #endif // PLATFORM_BASE_INTERFACE_INFO_H_ diff --git a/chromium/third_party/openscreen/src/platform/base/ip_address.cc b/chromium/third_party/openscreen/src/platform/base/ip_address.cc index 808b40be7e6..6fee6a3c9c1 100644 --- a/chromium/third_party/openscreen/src/platform/base/ip_address.cc +++ b/chromium/third_party/openscreen/src/platform/base/ip_address.cc @@ -4,20 +4,24 @@ #include "platform/base/ip_address.h" +#include <algorithm> #include <cassert> +#include <cctype> +#include <cinttypes> +#include <cstdio> #include <cstring> #include <iomanip> - -#include "absl/types/optional.h" +#include <iterator> +#include <sstream> +#include <utility> namespace openscreen { // static -ErrorOr<IPAddress> IPAddress::Parse(const std::string& s) { - ErrorOr<IPAddress> v4 = ParseV4(s); +const IPAddress IPAddress::kV4LoopbackAddress{127, 0, 0, 1}; - return v4 ? std::move(v4) : ParseV6(s); -} // namespace openscreen +// static +const IPAddress IPAddress::kV6LoopbackAddress{0, 0, 0, 0, 0, 0, 0, 1}; IPAddress::IPAddress() : version_(Version::kV4), bytes_({}) {} IPAddress::IPAddress(const std::array<uint8_t, 4>& bytes) @@ -35,31 +39,55 @@ IPAddress::IPAddress(Version version, const uint8_t* b) : version_(version) { } IPAddress::IPAddress(uint8_t b1, uint8_t b2, uint8_t b3, uint8_t b4) : version_(Version::kV4), bytes_{{b1, b2, b3, b4}} {} -IPAddress::IPAddress(const std::array<uint8_t, 16>& bytes) - : version_(Version::kV6), bytes_(bytes) {} -IPAddress::IPAddress(const uint8_t (&b)[16]) - : version_(Version::kV6), - bytes_{{b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], - b[11], b[12], b[13], b[14], b[15]}} {} -IPAddress::IPAddress(uint8_t b1, - uint8_t b2, - uint8_t b3, - uint8_t b4, - uint8_t b5, - uint8_t b6, - uint8_t b7, - uint8_t b8, - uint8_t b9, - uint8_t b10, - uint8_t b11, - uint8_t b12, - uint8_t b13, - uint8_t b14, - uint8_t b15, - uint8_t b16) + +IPAddress::IPAddress(const std::array<uint16_t, 8>& hextets) + : IPAddress(hextets[0], + hextets[1], + hextets[2], + hextets[3], + hextets[4], + hextets[5], + hextets[6], + hextets[7]) {} + +IPAddress::IPAddress(const uint16_t (&hextets)[8]) + : IPAddress(hextets[0], + hextets[1], + hextets[2], + hextets[3], + hextets[4], + hextets[5], + hextets[6], + hextets[7]) {} + +IPAddress::IPAddress(uint16_t h0, + uint16_t h1, + uint16_t h2, + uint16_t h3, + uint16_t h4, + uint16_t h5, + uint16_t h6, + uint16_t h7) : version_(Version::kV6), - bytes_{{b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15, - b16}} {} + bytes_{{ + static_cast<uint8_t>(h0 >> 8), + static_cast<uint8_t>(h0), + static_cast<uint8_t>(h1 >> 8), + static_cast<uint8_t>(h1), + static_cast<uint8_t>(h2 >> 8), + static_cast<uint8_t>(h2), + static_cast<uint8_t>(h3 >> 8), + static_cast<uint8_t>(h3), + static_cast<uint8_t>(h4 >> 8), + static_cast<uint8_t>(h4), + static_cast<uint8_t>(h5 >> 8), + static_cast<uint8_t>(h5), + static_cast<uint8_t>(h6 >> 8), + static_cast<uint8_t>(h6), + static_cast<uint8_t>(h7 >> 8), + static_cast<uint8_t>(h7), + }} {} + IPAddress::IPAddress(const IPAddress& o) noexcept = default; IPAddress::IPAddress(IPAddress&& o) noexcept = default; IPAddress& IPAddress::operator=(const IPAddress& o) noexcept = default; @@ -101,115 +129,136 @@ void IPAddress::CopyToV6(uint8_t x[16]) const { std::memcpy(x, bytes_.data(), 16); } -// static -ErrorOr<IPAddress> IPAddress::ParseV4(const std::string& s) { - if (s.size() > 0 && s[0] == '.') +namespace { + +ErrorOr<IPAddress> ParseV4(const std::string& s) { + int octets[4]; + int chars_scanned; + // Note: sscanf()'s parsing for %d allows leading whitespace; so the invalid + // presence of whitespace must be explicitly checked too. + if (std::any_of(s.begin(), s.end(), [](char c) { return std::isspace(c); }) || + sscanf(s.c_str(), "%3d.%3d.%3d.%3d%n", &octets[0], &octets[1], &octets[2], + &octets[3], &chars_scanned) != 4 || + chars_scanned != static_cast<int>(s.size()) || + std::any_of(std::begin(octets), std::end(octets), + [](int octet) { return octet < 0 || octet > 255; })) { return Error::Code::kInvalidIPV4Address; + } + return IPAddress(octets[0], octets[1], octets[2], octets[3]); +} - IPAddress address; - uint16_t next_octet = 0; - int i = 0; - bool previous_dot = false; - for (auto c : s) { - if (c == '.') { - if (previous_dot) { - return Error::Code::kInvalidIPV4Address; - } - address.bytes_[i++] = static_cast<uint8_t>(next_octet); - next_octet = 0; - previous_dot = true; - if (i > 3) - return Error::Code::kInvalidIPV4Address; - - continue; - } - previous_dot = false; - if (!std::isdigit(c)) - return Error::Code::kInvalidIPV4Address; - - next_octet = next_octet * 10 + (c - '0'); - if (next_octet > 255) - return Error::Code::kInvalidIPV4Address; +// Returns the zero-expansion of a double-colon in |s| if |s| is a +// well-formatted IPv6 address. If |s| is ill-formatted, returns *any* string +// that is ill-formatted. +std::string ExpandIPv6DoubleColon(const std::string& s) { + constexpr char kDoubleColon[] = "::"; + const size_t double_colon_position = s.find(kDoubleColon); + if (double_colon_position == std::string::npos) { + return s; // Nothing to expand. + } + if (double_colon_position != s.rfind(kDoubleColon)) { + return {}; // More than one occurrence of double colons is illegal. } - if (previous_dot) - return Error::Code::kInvalidIPV4Address; - if (i != 3) - return Error::Code::kInvalidIPV4Address; + std::ostringstream expanded; + const int num_single_colons = std::count(s.begin(), s.end(), ':') - 2; + int num_zero_groups_to_insert = 8 - num_single_colons; + if (double_colon_position != 0) { + // abcd:0123:4567::f000:1 + // ^^^^^^^^^^^^^^^ + expanded << s.substr(0, double_colon_position + 1); + --num_zero_groups_to_insert; + } + if (double_colon_position != (s.size() - 2)) { + --num_zero_groups_to_insert; + } + while (--num_zero_groups_to_insert > 0) { + expanded << "0:"; + } + expanded << '0'; + if (double_colon_position != (s.size() - 2)) { + // abcd:0123:4567::f000:1 + // ^^^^^^^ + expanded << s.substr(double_colon_position + 1); + } + return expanded.str(); +} - address.bytes_[i] = static_cast<uint8_t>(next_octet); - address.version_ = Version::kV4; - return address; +ErrorOr<IPAddress> ParseV6(const std::string& s) { + const std::string scan_input = ExpandIPv6DoubleColon(s); + uint16_t hextets[8]; + int chars_scanned; + // Note: sscanf()'s parsing for %x allows leading whitespace; so the invalid + // presence of whitespace must be explicitly checked too. + if (std::any_of(s.begin(), s.end(), [](char c) { return std::isspace(c); }) || + sscanf(scan_input.c_str(), + "%4" SCNx16 ":%4" SCNx16 ":%4" SCNx16 ":%4" SCNx16 ":%4" SCNx16 + ":%4" SCNx16 ":%4" SCNx16 ":%4" SCNx16 "%n", + &hextets[0], &hextets[1], &hextets[2], &hextets[3], &hextets[4], + &hextets[5], &hextets[6], &hextets[7], &chars_scanned) != 8 || + chars_scanned != static_cast<int>(scan_input.size())) { + return Error::Code::kInvalidIPV6Address; + } + return IPAddress(hextets); } +} // namespace + // static -ErrorOr<IPAddress> IPAddress::ParseV6(const std::string& s) { - if (s.size() > 1 && s[0] == ':' && s[1] != ':') - return Error::Code::kInvalidIPV6Address; +ErrorOr<IPAddress> IPAddress::Parse(const std::string& s) { + ErrorOr<IPAddress> v4 = ParseV4(s); - uint16_t next_value = 0; - uint8_t values[16]; - int i = 0; - int num_previous_colons = 0; - absl::optional<int> double_colon_index = absl::nullopt; - for (auto c : s) { - if (c == ':') { - ++num_previous_colons; - if (num_previous_colons == 2) { - if (double_colon_index) { - return Error::Code::kInvalidIPV6Address; - } - double_colon_index = i; - } else if (i >= 15 || num_previous_colons > 2) { - return Error::Code::kInvalidIPV6Address; - } else { - values[i++] = static_cast<uint8_t>(next_value >> 8); - values[i++] = static_cast<uint8_t>(next_value & 0xff); - next_value = 0; - } - } else { - num_previous_colons = 0; - uint8_t x = 0; - if (c >= '0' && c <= '9') { - x = c - '0'; - } else if (c >= 'a' && c <= 'f') { - x = c - 'a' + 10; - } else if (c >= 'A' && c <= 'F') { - x = c - 'A' + 10; - } else { - return Error::Code::kInvalidIPV6Address; - } - if (next_value & 0xf000) { - return Error::Code::kInvalidIPV6Address; - } else { - next_value = static_cast<uint16_t>(next_value * 16 + x); - } - } - } - if (num_previous_colons == 1) - return Error::Code::kInvalidIPV6Address; + return v4 ? std::move(v4) : ParseV6(s); +} - if (i >= 15) - return Error::Code::kInvalidIPV6Address; +IPEndpoint::operator bool() const { + return address || port; +} - values[i++] = static_cast<uint8_t>(next_value >> 8); - values[i] = static_cast<uint8_t>(next_value & 0xff); - if (!((i == 15 && !double_colon_index) || (i < 14 && double_colon_index))) { - return Error::Code::kInvalidIPV6Address; +// static +ErrorOr<IPEndpoint> IPEndpoint::Parse(const std::string& s) { + // Look for the colon that separates the IP address from the port number. Note + // that this check also guards against the case where |s| is the empty string. + const auto colon_pos = s.rfind(':'); + if (colon_pos == std::string::npos) { + return Error(Error::Code::kParseError, "missing colon separator"); + } + // The colon cannot be the first nor the last character in |s| because that + // would mean there is no address part or port part. + if (colon_pos == 0) { + return Error(Error::Code::kParseError, "missing address before colon"); + } + if (colon_pos == (s.size() - 1)) { + return Error(Error::Code::kParseError, "missing port after colon"); } - IPAddress address; - for (int j = 15; j >= 0;) { - if (double_colon_index && (i == double_colon_index)) { - address.bytes_[j--] = values[i--]; - while (j > i) - address.bytes_[j--] = 0; - } else { - address.bytes_[j--] = values[i--]; - } + ErrorOr<IPAddress> address(Error::Code::kParseError); + if (s[0] == '[' && s[colon_pos - 1] == ']') { + // [abcd:beef:1:1::2600]:8080 + // ^^^^^^^^^^^^^^^^^^^^^ + address = ParseV6(s.substr(1, colon_pos - 2)); + } else { + // 127.0.0.1:22 + // ^^^^^^^^^ + address = ParseV4(s.substr(0, colon_pos)); + } + if (address.is_error()) { + return Error(Error::Code::kParseError, "invalid address part"); + } + + const char* const port_part = s.c_str() + colon_pos + 1; + int port, chars_scanned; + // Note: sscanf()'s parsing for %d allows leading whitespace. Thus, if the + // first char is not whitespace, a successful sscanf() parse here can only + // mean numerical chars contributed to the parsed integer. + if (std::isspace(port_part[0]) || + sscanf(port_part, "%d%n", &port, &chars_scanned) != 1 || + port_part[chars_scanned] != '\0' || port < 0 || + port > std::numeric_limits<uint16_t>::max()) { + return Error(Error::Code::kParseError, "invalid port part"); } - address.version_ = Version::kV6; - return address; + + return IPEndpoint{address.value(), static_cast<uint16_t>(port)}; } bool operator==(const IPEndpoint& a, const IPEndpoint& b) { @@ -220,19 +269,23 @@ bool operator!=(const IPEndpoint& a, const IPEndpoint& b) { return !(a == b); } -bool IPEndpointComparator::operator()(const IPEndpoint& a, - const IPEndpoint& b) const { - if (a.address.version() != b.address.version()) - return a.address.version() < b.address.version(); - if (a.address.IsV4()) { - int ret = memcmp(a.address.bytes_.data(), b.address.bytes_.data(), 4); - if (ret != 0) - return ret < 0; +bool IPAddress::operator<(const IPAddress& other) const { + if (version() != other.version()) { + return version() < other.version(); + } + + if (IsV4()) { + return memcmp(bytes_.data(), other.bytes_.data(), 4) < 0; } else { - int ret = memcmp(a.address.bytes_.data(), b.address.bytes_.data(), 16); - if (ret != 0) - return ret < 0; + return memcmp(bytes_.data(), other.bytes_.data(), 16) < 0; } +} + +bool operator<(const IPEndpoint& a, const IPEndpoint& b) { + if (a.address != b.address) { + return a.address < b.address; + } + return a.port < b.port; } @@ -275,4 +328,10 @@ std::ostream& operator<<(std::ostream& out, const IPEndpoint& endpoint) { return out << ':' << std::dec << static_cast<int>(endpoint.port); } +std::string IPEndpoint::ToString() const { + std::ostringstream name; + name << this; + return name.str(); +} + } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/base/ip_address.h b/chromium/third_party/openscreen/src/platform/base/ip_address.h index c4dff5b0c5d..c37054d9e83 100644 --- a/chromium/third_party/openscreen/src/platform/base/ip_address.h +++ b/chromium/third_party/openscreen/src/platform/base/ip_address.h @@ -22,37 +22,35 @@ class IPAddress { kV6, }; + static const IPAddress kV4LoopbackAddress; + static const IPAddress kV6LoopbackAddress; + static constexpr size_t kV4Size = 4; static constexpr size_t kV6Size = 16; - // Parses a text representation of an IPv4 address (e.g. "192.168.0.1") or an - // IPv6 address (e.g. "abcd::1234") and puts the result into |address|. - static ErrorOr<IPAddress> Parse(const std::string& s); - IPAddress(); + + // |bytes| contains 4 octets for IPv4, or 8 hextets (16 bytes of big-endian + // shorts) for IPv6. + IPAddress(Version version, const uint8_t* bytes); + + // IPv4 constructors (IPAddress from 4 octets). explicit IPAddress(const std::array<uint8_t, 4>& bytes); explicit IPAddress(const uint8_t (&b)[4]); - explicit IPAddress(Version version, const uint8_t* b); IPAddress(uint8_t b1, uint8_t b2, uint8_t b3, uint8_t b4); - explicit IPAddress(const std::array<uint8_t, 16>& bytes); - explicit IPAddress(const uint8_t (&b)[16]); - IPAddress(uint8_t b1, - uint8_t b2, - uint8_t b3, - uint8_t b4, - uint8_t b5, - uint8_t b6, - uint8_t b7, - uint8_t b8, - uint8_t b9, - uint8_t b10, - uint8_t b11, - uint8_t b12, - uint8_t b13, - uint8_t b14, - uint8_t b15, - uint8_t b16); + // IPv6 constructors (IPAddress from 8 hextets). + explicit IPAddress(const std::array<uint16_t, 8>& hextets); + explicit IPAddress(const uint16_t (&hextets)[8]); + IPAddress(uint16_t h1, + uint16_t h2, + uint16_t h3, + uint16_t h4, + uint16_t h5, + uint16_t h6, + uint16_t h7, + uint16_t h8); + IPAddress(const IPAddress& o) noexcept; IPAddress(IPAddress&& o) noexcept; ~IPAddress() = default; @@ -62,6 +60,11 @@ class IPAddress { bool operator==(const IPAddress& o) const; bool operator!=(const IPAddress& o) const; + + bool operator<(const IPAddress& other) const; + bool operator>(const IPAddress& other) const { return other < *this; } + bool operator<=(const IPAddress& other) const { return !(other < *this); } + bool operator>=(const IPAddress& other) const { return !(*this < other); } explicit operator bool() const; Version version() const { return version_; } @@ -78,12 +81,11 @@ class IPAddress { // in order to avoid making multiple copies. const uint8_t* bytes() const { return bytes_.data(); } - private: - static ErrorOr<IPAddress> ParseV4(const std::string& s); - static ErrorOr<IPAddress> ParseV6(const std::string& s); - - friend class IPEndpointComparator; + // Parses a text representation of an IPv4 address (e.g. "192.168.0.1") or an + // IPv6 address (e.g. "abcd::1234"). + static ErrorOr<IPAddress> Parse(const std::string& s); + private: Version version_; std::array<uint8_t, 16> bytes_; }; @@ -91,16 +93,30 @@ class IPAddress { struct IPEndpoint { public: IPAddress address; - uint16_t port; + uint16_t port = 0; + + explicit operator bool() const; + + // Parses a text representation of an IPv4/IPv6 address and port (e.g. + // "192.168.0.1:8080" or "[abcd::1234]:8080"). + static ErrorOr<IPEndpoint> Parse(const std::string& s); + + std::string ToString() const; }; bool operator==(const IPEndpoint& a, const IPEndpoint& b); bool operator!=(const IPEndpoint& a, const IPEndpoint& b); -class IPEndpointComparator { - public: - bool operator()(const IPEndpoint& a, const IPEndpoint& b) const; -}; +bool operator<(const IPEndpoint& a, const IPEndpoint& b); +inline bool operator>(const IPEndpoint& a, const IPEndpoint& b) { + return b < a; +} +inline bool operator<=(const IPEndpoint& a, const IPEndpoint& b) { + return !(b > a); +} +inline bool operator>=(const IPEndpoint& a, const IPEndpoint& b) { + return !(a > b); +} // Outputs a string of the form: // 123.234.34.56 diff --git a/chromium/third_party/openscreen/src/platform/base/ip_address_unittest.cc b/chromium/third_party/openscreen/src/platform/base/ip_address_unittest.cc index 1bbbe8ec0c7..03a3c7537e5 100644 --- a/chromium/third_party/openscreen/src/platform/base/ip_address_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/base/ip_address_unittest.cc @@ -103,14 +103,16 @@ TEST(IPAddressTest, V4ParseFailures) { TEST(IPAddressTest, V6Constructors) { uint8_t bytes[16] = {}; - IPAddress address1(std::array<uint8_t, 16>{ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}}); + IPAddress address1(std::array<uint16_t, 8>{ + {0x0102, 0x0304, 0x0506, 0x0708, 0x090a, 0x0b0c, 0x0d0e, 0x0f10}}); address1.CopyToV6(bytes); EXPECT_THAT(bytes, ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})); const uint8_t x[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - IPAddress address2(x); + const uint16_t hextets[] = {0x0102, 0x0304, 0x0506, 0x0708, + 0x090a, 0x0b0c, 0x0d0e, 0x0f10}; + IPAddress address2(hextets); address2.CopyToV6(bytes); EXPECT_THAT(bytes, ElementsAreArray(x)); @@ -118,7 +120,8 @@ TEST(IPAddressTest, V6Constructors) { address3.CopyToV6(bytes); EXPECT_THAT(bytes, ElementsAreArray(x)); - IPAddress address4(16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1); + IPAddress address4(0x100f, 0x0e0d, 0x0c0b, 0x0a09, 0x0807, 0x0605, 0x0403, + 0x0201); address4.CopyToV6(bytes); EXPECT_THAT(bytes, ElementsAreArray({16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1})); @@ -135,11 +138,11 @@ TEST(IPAddressTest, V6ComparisonAndBoolean) { EXPECT_FALSE(address1); uint8_t x[] = {16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}; - IPAddress address2(x); + IPAddress address2(IPAddress::Version::kV6, x); EXPECT_NE(address1, address2); EXPECT_TRUE(address2); - IPAddress address3(x); + IPAddress address3(IPAddress::Version::kV6, x); EXPECT_EQ(address2, address3); EXPECT_TRUE(address3); @@ -225,6 +228,12 @@ TEST(IPAddressTest, V6ParseFailures) { << "too few values should fail to parse"; EXPECT_FALSE(IPAddress::Parse("a:b:c:d:e:f:0:1:2:3:4:5:6:7:8:9:a")) << "too many values should fail to parse"; + EXPECT_FALSE(IPAddress::Parse("1:2:3:4:5:6:7::8")) + << "too many values around double-colon should fail to parse"; + EXPECT_FALSE(IPAddress::Parse("1:2:3:4:5:6:7:8::")) + << "too many values before double-colon should fail to parse"; + EXPECT_FALSE(IPAddress::Parse("::1:2:3:4:5:6:7:8")) + << "too many values after double-colon should fail to parse"; EXPECT_FALSE(IPAddress::Parse("abcd1::dbca")) << "value > 0xffff should fail to parse"; EXPECT_FALSE(IPAddress::Parse("::abcd::dbca")) @@ -248,4 +257,64 @@ TEST(IPAddressTest, V6ParseThreeDigitValue) { 0x01, 0x23})); } +TEST(IPAddressTest, IPEndpointBoolOperator) { + IPEndpoint endpoint; + if (endpoint) { + FAIL(); + } + + endpoint = IPEndpoint{{192, 168, 0, 1}, 80}; + if (!endpoint) { + FAIL(); + } + + endpoint = IPEndpoint{{192, 168, 0, 1}, 0}; + if (!endpoint) { + FAIL(); + } + + endpoint = IPEndpoint{{}, 80}; + if (!endpoint) { + FAIL(); + } +} + +TEST(IPAddressTest, IPEndpointParse) { + IPEndpoint expected{IPAddress(std::array<uint8_t, 4>{{1, 2, 3, 4}}), 5678}; + ErrorOr<IPEndpoint> result = IPEndpoint::Parse("1.2.3.4:5678"); + ASSERT_TRUE(result.is_value()) << result.error(); + EXPECT_EQ(expected, result.value()); + + expected = IPEndpoint{ + IPAddress(std::array<uint16_t, 8>{{0xabcd, 0, 0, 0, 0, 0, 0, 1}}), 99}; + result = IPEndpoint::Parse("[abcd::1]:99"); + ASSERT_TRUE(result.is_value()) << result.error(); + EXPECT_EQ(expected, result.value()); + + expected = IPEndpoint{ + IPAddress(std::array<uint16_t, 8>{{0, 0, 0, 0, 0, 0, 0, 0}}), 5791}; + result = IPEndpoint::Parse("[::]:5791"); + ASSERT_TRUE(result.is_value()) << result.error(); + EXPECT_EQ(expected, result.value()); + + EXPECT_FALSE(IPEndpoint::Parse("")); // Empty string. + EXPECT_FALSE(IPEndpoint::Parse("beef")); // Random word. + EXPECT_FALSE(IPEndpoint::Parse("localhost:99")); // We don't do DNS. + EXPECT_FALSE(IPEndpoint::Parse(":80")); // Missing address. + EXPECT_FALSE(IPEndpoint::Parse("[]:22")); // Missing address. + EXPECT_FALSE(IPEndpoint::Parse("1.2.3.4")); // Missing port after IPv4. + EXPECT_FALSE(IPEndpoint::Parse("[abcd::1]")); // Missing port after IPv6. + EXPECT_FALSE(IPEndpoint::Parse("abcd::1:8080")); // Missing square brackets. + + // No extra whitespace is allowed. + EXPECT_FALSE(IPEndpoint::Parse(" 1.2.3.4:5678")); + EXPECT_FALSE(IPEndpoint::Parse("1.2.3.4 :5678")); + EXPECT_FALSE(IPEndpoint::Parse("1.2.3.4: 5678")); + EXPECT_FALSE(IPEndpoint::Parse("1.2.3.4:5678 ")); + EXPECT_FALSE(IPEndpoint::Parse(" [abcd::1]:99")); + EXPECT_FALSE(IPEndpoint::Parse("[abcd::1] :99")); + EXPECT_FALSE(IPEndpoint::Parse("[abcd::1]: 99")); + EXPECT_FALSE(IPEndpoint::Parse("[abcd::1]:99 ")); +} + } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/base/location.cc b/chromium/third_party/openscreen/src/platform/base/location.cc index f9ec010815a..ad5a3abd178 100644 --- a/chromium/third_party/openscreen/src/platform/base/location.cc +++ b/chromium/third_party/openscreen/src/platform/base/location.cc @@ -4,7 +4,8 @@ #include "platform/base/location.h" -#include "absl/strings/str_cat.h" +#include <sstream> + #include "platform/base/macros.h" namespace openscreen { @@ -24,7 +25,9 @@ std::string Location::ToString() const { return "pc:NULL"; } - return absl::StrCat("pc:0x", absl::Hex(program_counter_)); + std::ostringstream oss; + oss << "pc:" << program_counter_; + return oss.str(); } #if defined(__GNUC__) diff --git a/chromium/third_party/openscreen/src/platform/base/socket_state.h b/chromium/third_party/openscreen/src/platform/base/socket_state.h index bd82dd292d7..7672c4f5086 100644 --- a/chromium/third_party/openscreen/src/platform/base/socket_state.h +++ b/chromium/third_party/openscreen/src/platform/base/socket_state.h @@ -10,7 +10,6 @@ #include <string> namespace openscreen { -namespace platform { // SocketState should be used by TCP and TLS sockets for indicating // current state. NOTE: socket state transitions should only happen in @@ -30,7 +29,6 @@ enum class SocketState { kClosed }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_BASE_SOCKET_STATE_H_ diff --git a/chromium/third_party/openscreen/src/platform/base/tls_connect_options.h b/chromium/third_party/openscreen/src/platform/base/tls_connect_options.h index 12a0cd269b0..dcaebe0ae1d 100644 --- a/chromium/third_party/openscreen/src/platform/base/tls_connect_options.h +++ b/chromium/third_party/openscreen/src/platform/base/tls_connect_options.h @@ -8,7 +8,6 @@ #include "platform/base/macros.h" namespace openscreen { -namespace platform { struct TlsConnectOptions { // This option allows TLS connections to devices without @@ -17,7 +16,6 @@ struct TlsConnectOptions { bool unsafely_skip_certificate_validation; }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_BASE_TLS_CONNECT_OPTIONS_H_ diff --git a/chromium/third_party/openscreen/src/platform/base/tls_credentials.cc b/chromium/third_party/openscreen/src/platform/base/tls_credentials.cc index 9493f597f86..7f7af41e37e 100644 --- a/chromium/third_party/openscreen/src/platform/base/tls_credentials.cc +++ b/chromium/third_party/openscreen/src/platform/base/tls_credentials.cc @@ -5,7 +5,6 @@ #include "platform/base/tls_credentials.h" namespace openscreen { -namespace platform { TlsCredentials::TlsCredentials() = default; @@ -18,5 +17,4 @@ TlsCredentials::TlsCredentials(std::vector<uint8_t> der_rsa_private_key, TlsCredentials::~TlsCredentials() = default; -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/base/tls_credentials.h b/chromium/third_party/openscreen/src/platform/base/tls_credentials.h index 8c5162dc10f..285e27057cc 100644 --- a/chromium/third_party/openscreen/src/platform/base/tls_credentials.h +++ b/chromium/third_party/openscreen/src/platform/base/tls_credentials.h @@ -10,7 +10,6 @@ #include <vector> namespace openscreen { -namespace platform { struct TlsCredentials { TlsCredentials(); @@ -29,7 +28,6 @@ struct TlsCredentials { std::vector<uint8_t> der_x509_cert; }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_BASE_TLS_CREDENTIALS_H_ diff --git a/chromium/third_party/openscreen/src/platform/base/tls_listen_options.h b/chromium/third_party/openscreen/src/platform/base/tls_listen_options.h index 32cab160ba2..da47b661b58 100644 --- a/chromium/third_party/openscreen/src/platform/base/tls_listen_options.h +++ b/chromium/third_party/openscreen/src/platform/base/tls_listen_options.h @@ -10,13 +10,11 @@ #include "platform/base/macros.h" namespace openscreen { -namespace platform { struct TlsListenOptions { uint32_t backlog_size; }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_BASE_TLS_LISTEN_OPTIONS_H_ diff --git a/chromium/third_party/openscreen/src/platform/base/trace_logging_activation.cc b/chromium/third_party/openscreen/src/platform/base/trace_logging_activation.cc index dd7ac6c3443..16a893a0dd2 100644 --- a/chromium/third_party/openscreen/src/platform/base/trace_logging_activation.cc +++ b/chromium/third_party/openscreen/src/platform/base/trace_logging_activation.cc @@ -4,29 +4,69 @@ #include "platform/base/trace_logging_activation.h" +#include <atomic> #include <cassert> +#include <thread> namespace openscreen { -namespace platform { namespace { -TraceLoggingPlatform* g_current_destination = nullptr; -} // namespace -TraceLoggingPlatform* GetTracingDestination() { - return g_current_destination; +// If tracing is active, this is a valid pointer to an object that implements +// the TraceLoggingPlatform interface. If tracing is not active, this is +// nullptr. +std::atomic<TraceLoggingPlatform*> g_current_destination{}; + +// The count of threads currently calling into the current TraceLoggingPlatform. +std::atomic<int> g_use_count{}; + +inline TraceLoggingPlatform* PinCurrentDestination() { + // NOTE: It's important to increment the global use count *before* loading the + // pointer, to ensure the referent is pinned-down (i.e., any thread executing + // StopTracing() stays blocked) until CurrentTracingDestination's destructor + // calls UnpinCurrentDestination(). + g_use_count.fetch_add(1); + return g_current_destination.load(std::memory_order_relaxed); } +inline void UnpinCurrentDestination() { + g_use_count.fetch_sub(1); +} + +} // namespace + void StartTracing(TraceLoggingPlatform* destination) { - // TODO(crbug.com/openscreen/85): Need to revisit this to ensure thread-safety - // around the sequencing of starting and stopping tracing. - assert(!g_current_destination); - g_current_destination = destination; + assert(destination); + auto* const old_destination = g_current_destination.exchange(destination); + (void)old_destination; // Prevent "unused variable" compiler warnings. + assert(old_destination == nullptr || old_destination == destination); } void StopTracing() { - g_current_destination = nullptr; + auto* const old_destination = g_current_destination.exchange(nullptr); + if (!old_destination) { + return; // Already stopped. + } + + // Block the current thread until the global use count goes to zero. At that + // point, there can no longer be any dangling references. Theoretically, this + // loop may never terminate; but in practice, that should never happen. If it + // did happen, that would mean one or more CPU cores are continuously spending + // most of their time executing the TraceLoggingPlatform methods, yet those + // methods are supposed to be super-cheap and take near-zero time to execute! + int iters = 0; + while (g_use_count.load(std::memory_order_relaxed) != 0) { + assert(iters < 1024); + std::this_thread::yield(); + ++iters; + } +} + +CurrentTracingDestination::CurrentTracingDestination() + : destination_(PinCurrentDestination()) {} + +CurrentTracingDestination::~CurrentTracingDestination() { + UnpinCurrentDestination(); } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/base/trace_logging_activation.h b/chromium/third_party/openscreen/src/platform/base/trace_logging_activation.h index 05be734387d..d53d9498fdc 100644 --- a/chromium/third_party/openscreen/src/platform/base/trace_logging_activation.h +++ b/chromium/third_party/openscreen/src/platform/base/trace_logging_activation.h @@ -6,20 +6,49 @@ #define PLATFORM_BASE_TRACE_LOGGING_ACTIVATION_H_ namespace openscreen { -namespace platform { class TraceLoggingPlatform; // Start or Stop trace logging. It is illegal to call StartTracing() a second // time without having called StopTracing() to stop the prior tracing session. +// +// Note that StopTracing() may block until all threads have returned from any +// in-progress calls into the TraceLoggingPlatform's methods. void StartTracing(TraceLoggingPlatform* destination); void StopTracing(); -// If tracing is active, returns the current destination. Otherwise, returns -// nullptr. -TraceLoggingPlatform* GetTracingDestination(); +// An immutable, non-copyable and non-movable smart pointer that references the +// current trace logging destination. If tracing was active when this class was +// intantiated, the pointer is valid for the life of the instance, and can be +// used to directly invoke the methods of the TraceLoggingPlatform API. If +// tracing was not active when this class was intantiated, the pointer is null +// for the life of the instance and must not be dereferenced. +// +// An instance should be short-lived, as a platform's call to StopTracing() will +// be blocked until there are no instances remaining. +// +// NOTE: This is generally not used directly, but instead via the +// util/trace_logging macros. +class CurrentTracingDestination { + public: + CurrentTracingDestination(); + ~CurrentTracingDestination(); + + explicit operator bool() const noexcept { return !!destination_; } + TraceLoggingPlatform* operator->() const noexcept { return destination_; } + + private: + CurrentTracingDestination(const CurrentTracingDestination&) = delete; + CurrentTracingDestination(CurrentTracingDestination&&) = delete; + CurrentTracingDestination& operator=(const CurrentTracingDestination&) = + delete; + CurrentTracingDestination& operator=(CurrentTracingDestination&&) = delete; + + // The destination at the time this class was constructed, and is valid for + // the lifetime of this class. This is nullptr if tracing was inactive. + TraceLoggingPlatform* const destination_; +}; -} // namespace platform } // namespace openscreen #endif // PLATFORM_BASE_TRACE_LOGGING_ACTIVATION_H_ diff --git a/chromium/third_party/openscreen/src/platform/base/trace_logging_types.h b/chromium/third_party/openscreen/src/platform/base/trace_logging_types.h index ce33edb42da..0ee9dfb18a8 100644 --- a/chromium/third_party/openscreen/src/platform/base/trace_logging_types.h +++ b/chromium/third_party/openscreen/src/platform/base/trace_logging_types.h @@ -5,10 +5,11 @@ #ifndef PLATFORM_BASE_TRACE_LOGGING_TYPES_H_ #define PLATFORM_BASE_TRACE_LOGGING_TYPES_H_ -#include "absl/types/optional.h" +#include <stdint.h> + +#include <limits> namespace openscreen { -namespace platform { // Define TraceId type here since other TraceLogging files import it. using TraceId = uint64_t; @@ -51,18 +52,17 @@ inline bool operator!=(const TraceIdHierarchy& lhs, // BitFlags to represent the supported tracing categories. // NOTE: These are currently placeholder values and later changes should feel // free to edit them. -// TODO(rwkeane): Rename SSL to either Ssl or kSsl struct TraceCategory { enum Value : uint64_t { - Any = std::numeric_limits<uint64_t>::max(), - mDNS = 0x01 << 0, - Quic = 0x01 << 1, - SSL = 0x01 << 2, - Presentation = 0x01 << 3, + kAny = std::numeric_limits<uint64_t>::max(), + kMdns = 0x01 << 0, + kQuic = 0x01 << 1, + kSsl = 0x01 << 2, + kPresentation = 0x01 << 3, + kStandaloneReceiver = 0x01 << 4 }; }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_BASE_TRACE_LOGGING_TYPES_H_ diff --git a/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.cc b/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.cc index 3be1c783268..96f1fa3dde0 100644 --- a/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.cc +++ b/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.cc @@ -7,15 +7,32 @@ namespace openscreen { std::ostream& operator<<(std::ostream& os, - const platform::TrivialClockTraits::duration& d) { - constexpr char kUnits[] = "\u03BCs"; // Greek Mu + "s" + const TrivialClockTraits::duration& d) { + constexpr char kUnits[] = " \u03BCs"; // Greek Mu + "s" return os << d.count() << kUnits; } std::ostream& operator<<(std::ostream& os, - const platform::TrivialClockTraits::time_point& tp) { - constexpr char kUnits[] = "\u03BCs-ticks"; // Greek Mu + "s-ticks" + const TrivialClockTraits::time_point& tp) { + constexpr char kUnits[] = " \u03BCs-ticks"; // Greek Mu + "s-ticks" return os << tp.time_since_epoch().count() << kUnits; } +std::ostream& operator<<(std::ostream& out, const std::chrono::hours& hrs) { + return (out << hrs.count() << " hours"); +} + +std::ostream& operator<<(std::ostream& out, const std::chrono::minutes& mins) { + return (out << mins.count() << " minutes"); +} + +std::ostream& operator<<(std::ostream& out, const std::chrono::seconds& secs) { + return (out << secs.count() << " seconds"); +} + +std::ostream& operator<<(std::ostream& out, + const std::chrono::milliseconds& millis) { + return (out << millis.count() << " ms"); +} + } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.h b/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.h index b5392c43400..426a2b8a9ce 100644 --- a/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.h +++ b/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.h @@ -9,7 +9,6 @@ #include <ostream> namespace openscreen { -namespace platform { // The Open Screen monotonic clock traits description, providing all the C++14 // requirements of a TrivialClock, for use with STL <chrono>. @@ -44,16 +43,22 @@ class TrivialClockTraits { // &Clock::now versus something else for testing). using ClockNowFunctionPtr = TrivialClockTraits::time_point (*)(); -} // namespace platform - // Logging convenience for durations. Outputs a string of the form "123µs". std::ostream& operator<<(std::ostream& os, - const platform::TrivialClockTraits::duration& d); + const TrivialClockTraits::duration& d); // Logging convenience for time points. Outputs a string of the form // "123µs-ticks". std::ostream& operator<<(std::ostream& os, - const platform::TrivialClockTraits::time_point& tp); + const TrivialClockTraits::time_point& tp); + +// Logging (and gtest pretty-printing) for several commonly-used chrono types. +std::ostream& operator<<(std::ostream& out, const std::chrono::hours&); +std::ostream& operator<<(std::ostream& out, const std::chrono::minutes&); +std::ostream& operator<<(std::ostream& out, const std::chrono::seconds&); +std::ostream& operator<<(std::ostream& out, const std::chrono::milliseconds&); +// Note: The ostream output operator for std::chrono::microseconds is handled by +// the one for TrivialClockTraits::duration above since they are the same type. } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/base/udp_packet.cc b/chromium/third_party/openscreen/src/platform/base/udp_packet.cc index d8ad75e421f..6cbb0da2f83 100644 --- a/chromium/third_party/openscreen/src/platform/base/udp_packet.cc +++ b/chromium/third_party/openscreen/src/platform/base/udp_packet.cc @@ -7,14 +7,9 @@ #include <cassert> namespace openscreen { -namespace platform { UdpPacket::UdpPacket() : std::vector<uint8_t>() {} -UdpPacket::UdpPacket(size_type size) : std::vector<uint8_t>(size) { - assert(size <= kUdpMaxPacketSize); -} - UdpPacket::UdpPacket(size_type size, uint8_t fill_value) : std::vector<uint8_t>(size, fill_value) { assert(size <= kUdpMaxPacketSize); @@ -31,5 +26,4 @@ UdpPacket::~UdpPacket() = default; UdpPacket& UdpPacket::operator=(UdpPacket&& other) = default; -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/base/udp_packet.h b/chromium/third_party/openscreen/src/platform/base/udp_packet.h index 479cee594bf..a8fcce045a1 100644 --- a/chromium/third_party/openscreen/src/platform/base/udp_packet.h +++ b/chromium/third_party/openscreen/src/platform/base/udp_packet.h @@ -7,14 +7,12 @@ #include <stdint.h> -#include <algorithm> #include <utility> #include <vector> #include "platform/base/ip_address.h" namespace openscreen { -namespace platform { class UdpSocket; @@ -25,13 +23,9 @@ class UdpPacket : public std::vector<uint8_t> { public: // C++14 vector constructors, sans Allocator foo, and no copy ctor. UdpPacket(); - explicit UdpPacket(size_type size); - explicit UdpPacket(size_type size, uint8_t fill_value); + explicit UdpPacket(size_type size, uint8_t fill_value = {}); template <typename InputIt> - UdpPacket(InputIt first, InputIt last) - : UdpPacket(std::distance(first, last)) { - std::copy(first, last, begin()); - } + UdpPacket(InputIt first, InputIt last) : std::vector<uint8_t>(first, last) {} UdpPacket(UdpPacket&& other); UdpPacket(std::initializer_list<uint8_t> init); @@ -60,7 +54,6 @@ class UdpPacket : public std::vector<uint8_t> { OSP_DISALLOW_COPY_AND_ASSIGN(UdpPacket); }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_BASE_UDP_PACKET_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/logging.h b/chromium/third_party/openscreen/src/platform/impl/logging.h index 033691f82b8..85d6abcd319 100644 --- a/chromium/third_party/openscreen/src/platform/impl/logging.h +++ b/chromium/third_party/openscreen/src/platform/impl/logging.h @@ -8,7 +8,6 @@ #include "util/logging.h" namespace openscreen { -namespace platform { // Direct all logging output to a named FIFO having the given |filename|. If the // file does not exist, an attempt is made to auto-create it. If unsuccessful, @@ -20,7 +19,6 @@ void SetLogFifoOrDie(const char* filename); // default. void SetLogLevel(LogLevel level); -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_LOGGING_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/logging_posix.cc b/chromium/third_party/openscreen/src/platform/impl/logging_posix.cc index ab921168468..412876ce9bc 100644 --- a/chromium/third_party/openscreen/src/platform/impl/logging_posix.cc +++ b/chromium/third_party/openscreen/src/platform/impl/logging_posix.cc @@ -16,7 +16,6 @@ #include "util/trace_logging.h" namespace openscreen { -namespace platform { namespace { int g_log_fd = STDERR_FILENO; @@ -80,22 +79,22 @@ void SetLogLevel(LogLevel level) { g_log_level = level; } -bool IsLoggingOn(LogLevel level, absl::string_view file) { +bool IsLoggingOn(LogLevel level, const char* file) { // Possible future enhancement: Use glob patterns passed on the command-line // to use a different logging level for certain files, like in Chromium. return level >= g_log_level; } void LogWithLevel(LogLevel level, - absl::string_view file, + const char* file, int line, - absl::string_view msg) { + std::stringstream message) { if (level < g_log_level) return; std::stringstream ss; ss << "[" << level << ":" << file << "(" << line << "):T" << std::hex - << TRACE_CURRENT_ID << "] " << msg << '\n'; + << TRACE_CURRENT_ID << "] " << message.rdbuf() << '\n'; const auto ss_str = ss.str(); const auto bytes_written = write(g_log_fd, ss_str.c_str(), ss_str.size()); OSP_DCHECK(bytes_written); @@ -109,5 +108,4 @@ void Break() { #endif } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/network_interface.cc b/chromium/third_party/openscreen/src/platform/impl/network_interface.cc new file mode 100644 index 00000000000..17e240f2e45 --- /dev/null +++ b/chromium/third_party/openscreen/src/platform/impl/network_interface.cc @@ -0,0 +1,44 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "platform/impl/network_interface.h" + +namespace openscreen { + +std::vector<InterfaceInfo> GetNetworkInterfaces() { + std::vector<InterfaceInfo> interfaces = GetAllInterfaces(); + + const auto new_end = std::remove_if( + interfaces.begin(), interfaces.end(), [](const InterfaceInfo& info) { + return info.type != InterfaceInfo::Type::kEthernet && + info.type != InterfaceInfo::Type::kWifi && + info.type != InterfaceInfo::Type::kOther; + }); + interfaces.erase(new_end, interfaces.end()); + + return interfaces; +} + +// Returns an InterfaceInfo associated with the system's loopback interface. +absl::optional<InterfaceInfo> GetLoopbackInterfaceForTesting() { + const std::vector<InterfaceInfo> interfaces = GetAllInterfaces(); + auto it = std::find_if( + interfaces.begin(), interfaces.end(), [](const InterfaceInfo& info) { + return info.type == InterfaceInfo::Type::kLoopback && + std::find_if( + info.addresses.begin(), info.addresses.end(), + [](const IPSubnet& subnet) { + return subnet.address == IPAddress::kV4LoopbackAddress || + subnet.address == IPAddress::kV6LoopbackAddress; + }) != info.addresses.end(); + }); + + if (it == interfaces.end()) { + return absl::nullopt; + } else { + return *it; + } +} + +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/network_interface.h b/chromium/third_party/openscreen/src/platform/impl/network_interface.h new file mode 100644 index 00000000000..8682cffb21a --- /dev/null +++ b/chromium/third_party/openscreen/src/platform/impl/network_interface.h @@ -0,0 +1,26 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef PLATFORM_IMPL_NETWORK_INTERFACE_H_ +#define PLATFORM_IMPL_NETWORK_INTERFACE_H_ + +#include <vector> + +#include "absl/types/optional.h" +#include "platform/base/interface_info.h" + +namespace openscreen { + +// The below functions are responsible for returning the network interfaces +// provided of the current machine. GetAllInterfaces() returns all interfaces, +// real or virtual. GetLoopbackInterfaceForTesting() returns one such interface +// which is associated with the machine's loopback interface, while +// GetNetworkInterfaces() returns all non-loopback interfaces. +std::vector<InterfaceInfo> GetAllInterfaces(); +absl::optional<InterfaceInfo> GetLoopbackInterfaceForTesting(); +std::vector<InterfaceInfo> GetNetworkInterfaces(); + +} // namespace openscreen + +#endif // PLATFORM_IMPL_NETWORK_INTERFACE_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/network_interface_linux.cc b/chromium/third_party/openscreen/src/platform/impl/network_interface_linux.cc index 5f432265662..d58e5d9ce66 100644 --- a/chromium/third_party/openscreen/src/platform/impl/network_interface_linux.cc +++ b/chromium/third_party/openscreen/src/platform/impl/network_interface_linux.cc @@ -25,11 +25,11 @@ #include "absl/types/optional.h" #include "platform/api/network_interface.h" #include "platform/base/ip_address.h" +#include "platform/impl/network_interface.h" #include "platform/impl/scoped_pipe.h" #include "util/logging.h" namespace openscreen { -namespace platform { namespace { constexpr int kNetlinkRecvmsgBufSize = 8192; @@ -74,8 +74,9 @@ InterfaceInfo::Type GetInterfaceType(const std::string& ifname) { wr.ifr_name[IFNAMSIZ - 1] = 0; strncpy(ifr.ifr_name, ifname.c_str(), IFNAMSIZ - 1); ifr.ifr_data = &ecmd; - if (ioctl(s.get(), SIOCETHTOOL, &ifr) != -1) + if (ioctl(s.get(), SIOCETHTOOL, &ifr) != -1) { return InterfaceInfo::Type::kEthernet; + } return InterfaceInfo::Type::kOther; } @@ -86,6 +87,7 @@ InterfaceInfo::Type GetInterfaceType(const std::string& ifname) { // pointed to by |rta|. void GetInterfaceAttributes(struct rtattr* rta, unsigned int attrlen, + bool is_loopback, InterfaceInfo* info) { for (; RTA_OK(rta, attrlen); rta = RTA_NEXT(rta, attrlen)) { if (rta->rta_type == IFLA_IFNAME) { @@ -93,12 +95,16 @@ void GetInterfaceAttributes(struct rtattr* rta, GetInterfaceName(reinterpret_cast<const char*>(RTA_DATA(rta))); } else if (rta->rta_type == IFLA_ADDRESS) { OSP_CHECK_EQ(sizeof(info->hardware_address), RTA_PAYLOAD(rta)); - std::memcpy(info->hardware_address, RTA_DATA(rta), + std::memcpy(info->hardware_address.data(), RTA_DATA(rta), sizeof(info->hardware_address)); } } - info->type = GetInterfaceType(info->name); + if (is_loopback) { + info->type = InterfaceInfo::Type::kLoopback; + } else { + info->type = GetInterfaceType(info->name); + } } // Reads the IPv4 or IPv6 address that comes from an RTM_NEWADDR message and @@ -220,16 +226,17 @@ std::vector<InterfaceInfo> GetLinkInfo() { struct ifinfomsg* interface_info = static_cast<struct ifinfomsg*>(NLMSG_DATA(netlink_header)); - // Only process non-loopback interfaces which are active (up). - if ((interface_info->ifi_flags & IFF_LOOPBACK) || - ((interface_info->ifi_flags & IFF_UP) == 0)) { + // Only process interfaces which are active (up). + if (!(interface_info->ifi_flags & IFF_UP)) { continue; } + info_list.emplace_back(); InterfaceInfo& info = info_list.back(); info.index = interface_info->ifi_index; GetInterfaceAttributes(IFLA_RTA(interface_info), - IFLA_PAYLOAD(netlink_header), &info); + IFLA_PAYLOAD(netlink_header), + interface_info->ifi_flags & IFF_LOOPBACK, &info); } } } @@ -351,11 +358,10 @@ void PopulateSubnetsOrClearList(std::vector<InterfaceInfo>* info_list) { } // namespace -std::vector<InterfaceInfo> GetNetworkInterfaces() { +std::vector<InterfaceInfo> GetAllInterfaces() { std::vector<InterfaceInfo> interfaces = GetLinkInfo(); PopulateSubnetsOrClearList(&interfaces); return interfaces; } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/network_interface_mac.cc b/chromium/third_party/openscreen/src/platform/impl/network_interface_mac.cc index 138d4918b57..da219bd7c32 100644 --- a/chromium/third_party/openscreen/src/platform/impl/network_interface_mac.cc +++ b/chromium/third_party/openscreen/src/platform/impl/network_interface_mac.cc @@ -19,11 +19,11 @@ #include "platform/api/network_interface.h" #include "platform/base/ip_address.h" +#include "platform/impl/network_interface.h" #include "platform/impl/scoped_pipe.h" #include "util/logging.h" namespace openscreen { -namespace platform { namespace { @@ -68,13 +68,12 @@ std::vector<InterfaceInfo> ProcessInterfacesList(ifaddrs* interfaces) { // Socket used for querying interface media types. const ScopedFd ioctl_socket(socket(AF_INET6, SOCK_DGRAM, 0)); - // Walk the |interfaces| linked list, creating the hierarchial structure. + // Walk the |interfaces| linked list, creating the hierarchical structure. std::vector<InterfaceInfo> results; for (ifaddrs* cur = interfaces; cur; cur = cur->ifa_next) { - // Skip: 1) loopback interfaces, 2) interfaces that are down, 3) interfaces - // with no address configured. - if ((IFF_LOOPBACK & cur->ifa_flags) || !(IFF_RUNNING & cur->ifa_flags) || - !cur->ifa_addr) { + // Skip: 1) interfaces that are down, 2) interfaces with no address + // configured. + if (!(IFF_RUNNING & cur->ifa_flags) || !cur->ifa_addr) { continue; } @@ -107,6 +106,9 @@ std::vector<InterfaceInfo> ProcessInterfacesList(ifaddrs* interfaces) { if (ifmr.ifm_current & IFM_ETHER) { type = InterfaceInfo::Type::kEthernet; } + if (cur->ifa_flags & IFF_LOOPBACK) { + type = InterfaceInfo::Type::kLoopback; + } // Start with an unknown hardware ethernet address, which should be // updated as the linked list is walked. @@ -163,7 +165,7 @@ std::vector<InterfaceInfo> ProcessInterfacesList(ifaddrs* interfaces) { } // namespace -std::vector<InterfaceInfo> GetNetworkInterfaces() { +std::vector<InterfaceInfo> GetAllInterfaces() { std::vector<InterfaceInfo> results; ifaddrs* interfaces; if (getifaddrs(&interfaces) == 0) { @@ -173,5 +175,4 @@ std::vector<InterfaceInfo> GetNetworkInterfaces() { return results; } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/platform_client_posix.cc b/chromium/third_party/openscreen/src/platform/impl/platform_client_posix.cc index f31f19aad79..d1415258bd7 100644 --- a/chromium/third_party/openscreen/src/platform/impl/platform_client_posix.cc +++ b/chromium/third_party/openscreen/src/platform/impl/platform_client_posix.cc @@ -4,48 +4,23 @@ #include "platform/impl/platform_client_posix.h" -#include <mutex> +#include <functional> +#include <vector> #include "platform/impl/udp_socket_reader_posix.h" namespace openscreen { -namespace platform { // static PlatformClientPosix* PlatformClientPosix::instance_ = nullptr; -PlatformClientPosix::PlatformClientPosix( - Clock::duration networking_operation_timeout, - Clock::duration networking_loop_interval) - : networking_loop_(networking_operations(), - networking_operation_timeout, - networking_loop_interval), - owned_task_runner_(Clock::now), - networking_loop_thread_(&OperationLoop::RunUntilStopped, - &networking_loop_), - task_runner_thread_(&TaskRunnerImpl::RunUntilStopped, - &owned_task_runner_.value()) {} - -PlatformClientPosix::PlatformClientPosix( - Clock::duration networking_operation_timeout, - Clock::duration networking_loop_interval, - std::unique_ptr<TaskRunner> task_runner) - : networking_loop_(networking_operations(), - networking_operation_timeout, - networking_loop_interval), - caller_provided_task_runner_(std::move(task_runner)), - networking_loop_thread_(&OperationLoop::RunUntilStopped, - &networking_loop_) {} - -PlatformClientPosix::~PlatformClientPosix() { - networking_loop_.RequestStopSoon(); - networking_loop_thread_.join(); - if (owned_task_runner_.has_value()) { - owned_task_runner_.value().RequestStopSoon(); - } - if (task_runner_thread_.joinable()) { - task_runner_thread_.join(); - } +// static +void PlatformClientPosix::Create(Clock::duration networking_operation_timeout, + Clock::duration networking_loop_interval, + std::unique_ptr<TaskRunnerImpl> task_runner) { + SetInstance(new PlatformClientPosix(networking_operation_timeout, + networking_loop_interval, + std::move(task_runner))); } // static @@ -56,21 +31,19 @@ void PlatformClientPosix::Create(Clock::duration networking_operation_timeout, } // static -void PlatformClientPosix::SetInstance(PlatformClientPosix* instance) { - OSP_DCHECK(!instance_); - instance_ = instance; -} - -// static void PlatformClientPosix::ShutDown() { OSP_DCHECK(instance_); delete instance_; instance_ = nullptr; } -TaskRunner* PlatformClientPosix::GetTaskRunner() { - return owned_task_runner_.has_value() ? &owned_task_runner_.value() - : caller_provided_task_runner_.get(); +TlsDataRouterPosix* PlatformClientPosix::tls_data_router() { + std::call_once(tls_data_router_initialization_, [this]() { + tls_data_router_ = + std::make_unique<TlsDataRouterPosix>(socket_handle_waiter()); + tls_data_router_created_.store(true); + }); + return tls_data_router_.get(); } UdpSocketReaderPosix* PlatformClientPosix::udp_socket_reader() { @@ -81,18 +54,56 @@ UdpSocketReaderPosix* PlatformClientPosix::udp_socket_reader() { return udp_socket_reader_.get(); } -TlsDataRouterPosix* PlatformClientPosix::tls_data_router() { - std::call_once(tls_data_router_initialization_, [this]() { - tls_data_router_ = - std::make_unique<TlsDataRouterPosix>(socket_handle_waiter()); - tls_data_router_created_.store(true); - }); - return tls_data_router_.get(); +TaskRunner* PlatformClientPosix::GetTaskRunner() { + return task_runner_.get(); } +PlatformClientPosix::~PlatformClientPosix() { + OSP_DVLOG << "Shutting down the Task Runner..."; + task_runner_->RequestStopSoon(); + if (task_runner_thread_ && task_runner_thread_->joinable()) { + task_runner_thread_->join(); + OSP_DVLOG << "\tTask Runner shutdown complete!"; + } + + OSP_DVLOG << "Shutting down network operations..."; + networking_loop_.RequestStopSoon(); + networking_loop_thread_.join(); + OSP_DVLOG << "\tNetwork operation shutdown complete!"; +} + +// static +void PlatformClientPosix::SetInstance(PlatformClientPosix* instance) { + OSP_DCHECK(!instance_); + instance_ = instance; +} + +PlatformClientPosix::PlatformClientPosix( + Clock::duration networking_operation_timeout, + Clock::duration networking_loop_interval) + : networking_loop_(networking_operations(), + networking_operation_timeout, + networking_loop_interval), + task_runner_(new TaskRunnerImpl(Clock::now)), + networking_loop_thread_(&OperationLoop::RunUntilStopped, + &networking_loop_), + task_runner_thread_( + std::thread(&TaskRunnerImpl::RunUntilStopped, task_runner_.get())) {} + +PlatformClientPosix::PlatformClientPosix( + Clock::duration networking_operation_timeout, + Clock::duration networking_loop_interval, + std::unique_ptr<TaskRunnerImpl> task_runner) + : networking_loop_(networking_operations(), + networking_operation_timeout, + networking_loop_interval), + task_runner_(std::move(task_runner)), + networking_loop_thread_(&OperationLoop::RunUntilStopped, + &networking_loop_) {} + SocketHandleWaiterPosix* PlatformClientPosix::socket_handle_waiter() { std::call_once(waiter_initialization_, [this]() { - waiter_ = std::make_unique<SocketHandleWaiterPosix>(); + waiter_ = std::make_unique<SocketHandleWaiterPosix>(&Clock::now); waiter_created_.store(true); }); return waiter_.get(); @@ -107,23 +118,11 @@ void PlatformClientPosix::PerformSocketHandleWaiterActions( socket_handle_waiter()->ProcessHandles(timeout); } -void PlatformClientPosix::PerformTlsDataRouterActions(Clock::duration timeout) { - if (!tls_data_router_created_.load()) { - return; - } - - tls_data_router()->PerformNetworkingOperations(timeout); -} - std::vector<std::function<void(Clock::duration)>> PlatformClientPosix::networking_operations() { return {[this](Clock::duration timeout) { - PerformSocketHandleWaiterActions(timeout); - }, - [this](Clock::duration timeout) { - PerformTlsDataRouterActions(timeout); - }}; + PerformSocketHandleWaiterActions(timeout); + }}; } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/platform_client_posix.h b/chromium/third_party/openscreen/src/platform/impl/platform_client_posix.h index 57b0cf75c02..6bda424b168 100644 --- a/chromium/third_party/openscreen/src/platform/impl/platform_client_posix.h +++ b/chromium/third_party/openscreen/src/platform/impl/platform_client_posix.h @@ -10,6 +10,7 @@ #include <mutex> #include <thread> +#include "absl/types/optional.h" #include "platform/api/time.h" #include "platform/base/macros.h" #include "platform/impl/socket_handle_waiter_posix.h" @@ -18,43 +19,57 @@ #include "util/operation_loop.h" namespace openscreen { -namespace platform { class UdpSocketReaderPosix; -// A PlatformClientPosix is an access point for all singletons in a standalone -// application. The static SetInstance method is to be called before library use -// begins, and the ShutDown() method should be called to deallocate the platform -// library's global singletons (for example to save memory when libcast is not -// in use). +// Creates and provides access to singletons used by the default platform +// implementation. An instance must be created before an application uses any +// public modules in the Open Screen Library. +// +// ShutDown() should be called to destroy the PlatformClientPosix's singletons +// and TaskRunner to save resources when library APIs are not in use. +// ShutDown() calls TaskRunner::RunUntilStopped() to run any pending cleanup +// tasks. +// +// Create and ShutDown must be called in the same sequence. +// +// FIXME: Remove Create and Shutdown and use the ctor/dtor directly. class PlatformClientPosix { public: - // This method is expected to be called before the library is used. - // The networking_loop_interval parameter here represents the minimum amount - // of time that should pass between iterations of the loop used to handle - // networking operations. Higher values will result in less time being spent - // on these operations, but also potentially less performant networking - // operations. The networking_operation_timeout parameter refers to how much - // time may be spent on a single networking operation type. - // NOTE: This method is NOT thread safe and should only be called from the - // embedder thread. + // Initializes the platform implementation. + // + // |networking_loop_interval| sets the minimum amount of time that should pass + // between iterations of the loop used to handle networking operations. Higher + // values will result in less time being spent on these operations, but also + // potentially less performant networking operations. + // + // |networking_operation_timeout| sets how much time may be spent on a + // single networking operation type. + // + // |task_runner| is a client-provided TaskRunner implementation. + static void Create(Clock::duration networking_operation_timeout, + Clock::duration networking_loop_interval, + std::unique_ptr<TaskRunnerImpl> task_runner); + + // Initializes the platform implementation and creates a new TaskRunner (which + // starts a new thread). static void Create(Clock::duration networking_operation_timeout, Clock::duration networking_loop_interval); // Shuts down and deletes the PlatformClient instance currently stored as a // singleton. This method is expected to be called before program exit. After // calling this method, if the client wishes to continue using the platform - // library, a new singleton must be created. - // NOTE: This method is NOT thread safe and should only be called from the - // embedder thread. + // library, Create() must be called again. static void ShutDown(); static PlatformClientPosix* GetInstance() { return instance_; } // This method is thread-safe. + // FIXME: Rename to GetTlsDataRouter() TlsDataRouterPosix* tls_data_router(); // This method is thread-safe. + // FIXME: Rename to GetUdpSocketReader() UdpSocketReaderPosix* udp_socket_reader(); // Returns the TaskRunner associated with this PlatformClient. @@ -62,29 +77,19 @@ class PlatformClientPosix { TaskRunner* GetTaskRunner(); protected: - // The TaskRunner parameter here is a user-provided TaskRunner to be used - // instead of the one normally created within PlatformClientPosix. Ownership - // of the TaskRunner is transferred to this class. - PlatformClientPosix(Clock::duration networking_operation_timeout, - Clock::duration networking_loop_interval, - std::unique_ptr<TaskRunner> task_runner); - - // This method is expected to be called in order to set the singleton instance - // (typically from the Create() method). It should only be called from the - // embedder thread. Client should be a new instance create via 'new' and - // ownership of this instance will be transferred to this class. - // NOTE: This method is NOT thread safe and should only be called from the - // embedder thread. - static void SetInstance(PlatformClientPosix* client); - // Called by ShutDown(). ~PlatformClientPosix(); + static void SetInstance(PlatformClientPosix* client); + private: - // Called by Create(). PlatformClientPosix(Clock::duration networking_operation_timeout, Clock::duration networking_loop_interval); + PlatformClientPosix(Clock::duration networking_operation_timeout, + Clock::duration networking_loop_interval, + std::unique_ptr<TaskRunnerImpl> task_runner); + // This method is thread-safe. SocketHandleWaiterPosix* socket_handle_waiter(); @@ -97,8 +102,8 @@ class PlatformClientPosix { // Instance objects with threads are created at object-creation time. // NOTE: Delayed instantiation of networking_loop_ may be useful in future. OperationLoop networking_loop_; - absl::optional<TaskRunnerImpl> owned_task_runner_; - std::unique_ptr<TaskRunner> caller_provided_task_runner_; + + std::unique_ptr<TaskRunnerImpl> task_runner_; // Track whether the associated instance variable has been created yet. std::atomic_bool waiter_created_{false}; @@ -118,14 +123,13 @@ class PlatformClientPosix { // Threads for running TaskRunner and OperationLoop instances. // NOTE: These must be declared last to avoid nondterministic failures. std::thread networking_loop_thread_; - std::thread task_runner_thread_; + absl::optional<std::thread> task_runner_thread_; static PlatformClientPosix* instance_; OSP_DISALLOW_COPY_AND_ASSIGN(PlatformClientPosix); }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_PLATFORM_CLIENT_POSIX_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.cc b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.cc new file mode 100644 index 00000000000..c99a2dd928b --- /dev/null +++ b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.cc @@ -0,0 +1,56 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "platform/impl/scoped_wake_lock_linux.h" + +#include "platform/api/task_runner.h" +#include "platform/impl/platform_client_posix.h" +#include "util/logging.h" + +namespace openscreen { + +int ScopedWakeLockLinux::reference_count_ = 0; + +std::unique_ptr<ScopedWakeLock> ScopedWakeLock::Create() { + return std::make_unique<ScopedWakeLockLinux>(); +} + +namespace { + +TaskRunner* GetTaskRunner() { + auto* const platform_client = PlatformClientPosix::GetInstance(); + OSP_DCHECK(platform_client); + auto* const task_runner = platform_client->GetTaskRunner(); + OSP_DCHECK(task_runner); + return task_runner; +} + +} // namespace + +ScopedWakeLockLinux::ScopedWakeLockLinux() : ScopedWakeLock() { + OSP_DCHECK(GetTaskRunner()->IsRunningOnTaskRunner()); + if (reference_count_++ == 0) { + AcquireWakeLock(); + } +} + +ScopedWakeLockLinux::~ScopedWakeLockLinux() { + GetTaskRunner()->PostTask([] { + if (--reference_count_ == 0) { + ReleaseWakeLock(); + } + }); +} + +// static +void ScopedWakeLockLinux::AcquireWakeLock() { + OSP_UNIMPLEMENTED(); +} + +// static +void ScopedWakeLockLinux::ReleaseWakeLock() { + OSP_UNIMPLEMENTED(); +} + +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.h b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.h new file mode 100644 index 00000000000..81117619ea4 --- /dev/null +++ b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.h @@ -0,0 +1,27 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef PLATFORM_IMPL_SCOPED_WAKE_LOCK_LINUX_H_ +#define PLATFORM_IMPL_SCOPED_WAKE_LOCK_LINUX_H_ + +#include "platform/api/scoped_wake_lock.h" + +namespace openscreen { + +class ScopedWakeLockLinux : public ScopedWakeLock { + public: + ScopedWakeLockLinux(); + ~ScopedWakeLockLinux() override; + + private: + // TODO(jophba): implement linux wake lock. + static void AcquireWakeLock(); + static void ReleaseWakeLock(); + + static int reference_count_; +}; + +} // namespace openscreen + +#endif // PLATFORM_IMPL_SCOPED_WAKE_LOCK_LINUX_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.cc b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.cc index 70b9e2ddc8d..3cd7f11ae8c 100644 --- a/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.cc +++ b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.cc @@ -11,7 +11,6 @@ #include "util/logging.h" namespace openscreen { -namespace platform { ScopedWakeLockMac::LockState ScopedWakeLockMac::lock_state_{}; @@ -70,5 +69,4 @@ void ScopedWakeLockMac::ReleaseWakeLock() { OSP_DCHECK_EQ(result, kIOReturnSuccess); } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.h b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.h index 4e7d2031266..1d1a2f416d9 100644 --- a/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.h +++ b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.h @@ -10,7 +10,6 @@ #include "platform/api/scoped_wake_lock.h" namespace openscreen { -namespace platform { class ScopedWakeLockMac : public ScopedWakeLock { public: @@ -29,7 +28,6 @@ class ScopedWakeLockMac : public ScopedWakeLock { static LockState lock_state_; }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_SCOPED_WAKE_LOCK_MAC_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_address_posix.cc b/chromium/third_party/openscreen/src/platform/impl/socket_address_posix.cc index d91975c6214..b45b4055873 100644 --- a/chromium/third_party/openscreen/src/platform/impl/socket_address_posix.cc +++ b/chromium/third_party/openscreen/src/platform/impl/socket_address_posix.cc @@ -4,24 +4,20 @@ #include "platform/impl/socket_address_posix.h" +#include <cstring> #include <vector> #include "util/logging.h" namespace openscreen { -namespace platform { SocketAddressPosix::SocketAddressPosix(const struct sockaddr& address) { if (address.sa_family == AF_INET) { memcpy(&internal_address_, &address, sizeof(struct sockaddr_in)); - endpoint_.address = IPAddress(IPAddress::Version::kV4, - reinterpret_cast<const uint8_t*>( - &internal_address_.v4.sin_addr.s_addr)); - endpoint_.port = ntohs(internal_address_.v4.sin_port); + RecomputeEndpoint(IPAddress::Version::kV4); } else if (address.sa_family == AF_INET6) { memcpy(&internal_address_, &address, sizeof(struct sockaddr_in6)); - endpoint_.address = IPAddress(internal_address_.v6.sin6_addr.s6_addr); - endpoint_.port = ntohs(internal_address_.v6.sin6_port); + RecomputeEndpoint(IPAddress::Version::kV6); } else { OSP_NOTREACHED() << "Unknown address type"; } @@ -55,7 +51,7 @@ struct sockaddr* SocketAddressPosix::address() { default: OSP_NOTREACHED(); return nullptr; - }; + } } const struct sockaddr* SocketAddressPosix::address() const { @@ -81,5 +77,25 @@ socklen_t SocketAddressPosix::size() const { return 0; } } -} // namespace platform + +void SocketAddressPosix::RecomputeEndpoint() { + RecomputeEndpoint(endpoint_.address.version()); +} + +void SocketAddressPosix::RecomputeEndpoint(IPAddress::Version version) { + switch (version) { + case IPAddress::Version::kV4: + endpoint_.address = IPAddress(IPAddress::Version::kV4, + reinterpret_cast<const uint8_t*>( + &internal_address_.v4.sin_addr.s_addr)); + endpoint_.port = ntohs(internal_address_.v4.sin_port); + break; + case IPAddress::Version::kV6: + endpoint_.address = IPAddress(IPAddress::Version::kV6, + internal_address_.v6.sin6_addr.s6_addr); + endpoint_.port = ntohs(internal_address_.v6.sin6_port); + break; + } +} + } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_address_posix.h b/chromium/third_party/openscreen/src/platform/impl/socket_address_posix.h index 17f0e82a891..07e49940d86 100644 --- a/chromium/third_party/openscreen/src/platform/impl/socket_address_posix.h +++ b/chromium/third_party/openscreen/src/platform/impl/socket_address_posix.h @@ -17,7 +17,6 @@ #include "platform/base/ip_address.h" namespace openscreen { -namespace platform { class SocketAddressPosix { public: @@ -35,7 +34,13 @@ class SocketAddressPosix { IPAddress::Version version() const { return endpoint_.address.version(); } IPEndpoint endpoint() const { return endpoint_; } + // Recomputes |endpoint_| if |internal_address_| is written to directly, e.g. + // by a system call. + void RecomputeEndpoint(); + private: + void RecomputeEndpoint(IPAddress::Version version); + // The way the sockaddr_* family works in POSIX is pretty unintuitive. The // sockaddr_in and sockaddr_in6 structs can be reinterpreted as type // sockaddr, however they don't have a common parent--the types are unrelated. @@ -50,7 +55,6 @@ class SocketAddressPosix { IPEndpoint endpoint_; }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_SOCKET_ADDRESS_POSIX_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_address_posix_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/socket_address_posix_unittest.cc index 7d79c0c8f04..e8b81c0962a 100644 --- a/chromium/third_party/openscreen/src/platform/impl/socket_address_posix_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/impl/socket_address_posix_unittest.cc @@ -8,7 +8,6 @@ #include "gtest/gtest.h" namespace openscreen { -namespace platform { TEST(SocketAddressPosixTest, IPv4SocketAddressConvertsSuccessfully) { const SocketAddressPosix address(IPEndpoint{{1, 2, 3, 4}, 80}); @@ -26,8 +25,8 @@ TEST(SocketAddressPosixTest, IPv4SocketAddressConvertsSuccessfully) { } TEST(SocketAddressPosixTest, IPv6SocketAddressConvertsSuccessfully) { - const SocketAddressPosix address( - IPEndpoint{{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, 80}); + const SocketAddressPosix address(IPEndpoint{ + {0x0102, 0x0304, 0x0506, 0x0708, 0x090a, 0x0b0c, 0x0d0e, 0x0f10}, 80}); const sockaddr_in6* v6_address = reinterpret_cast<const sockaddr_in6*>(address.address()); @@ -85,9 +84,8 @@ TEST(SocketAddressPosixTest, IPv6ConvertsSuccessfully) { EXPECT_THAT(v6_address->sin6_addr.s6_addr, testing::ElementsAreArray(kExpectedAddress)); IPEndpoint expected_address{ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, 80}; + {0x0102, 0x0304, 0x0506, 0x0708, 0x090a, 0x0b0c, 0x0d0e, 0x0f10}, 80}; EXPECT_EQ(address_posix.endpoint(), expected_address); } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_handle.h b/chromium/third_party/openscreen/src/platform/impl/socket_handle.h index 6a775f1a1bf..fca8192e941 100644 --- a/chromium/third_party/openscreen/src/platform/impl/socket_handle.h +++ b/chromium/third_party/openscreen/src/platform/impl/socket_handle.h @@ -8,7 +8,6 @@ #include <cstdlib> namespace openscreen { -namespace platform { // A SocketHandle is the handle used to access a Socket by the underlying // platform. @@ -23,7 +22,6 @@ inline bool operator!=(const SocketHandle& lhs, const SocketHandle& rhs) { return !(lhs == rhs); } -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_SOCKET_HANDLE_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.cc b/chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.cc index 7d99ca32ab3..dfc9c5bb34c 100644 --- a/chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.cc +++ b/chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.cc @@ -8,7 +8,6 @@ #include <functional> namespace openscreen { -namespace platform { SocketHandle::SocketHandle(int descriptor) : fd(descriptor) {} @@ -20,5 +19,4 @@ size_t SocketHandleHash::operator()(const SocketHandle& handle) const { return std::hash<int>()(handle.fd); } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.h b/chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.h index 59ff670d913..13f977c33d8 100644 --- a/chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.h +++ b/chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.h @@ -8,14 +8,12 @@ #include "platform/impl/socket_handle.h" namespace openscreen { -namespace platform { struct SocketHandle { explicit SocketHandle(int descriptor); int fd; }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_SOCKET_HANDLE_POSIX_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.cc b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.cc index e3731b34850..1560aa6c6dd 100644 --- a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.cc +++ b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.cc @@ -11,7 +11,9 @@ #include "util/logging.h" namespace openscreen { -namespace platform { + +SocketHandleWaiter::SocketHandleWaiter(ClockNowFunctionPtr now_function) + : now_function_(now_function) {} void SocketHandleWaiter::Subscribe(Subscriber* subscriber, SocketHandleRef handle) { @@ -51,6 +53,7 @@ void SocketHandleWaiter::OnHandleDeletion(Subscriber* subscriber, if (!disable_locking_for_testing) { handles_being_deleted_.push_back(handle); + OSP_DVLOG << "Starting to block for handle deletion"; // This code will allow us to block completion of the socket destructor // (and subsequent invalidation of pointers to this socket) until we no // longer are waiting on a SELECT(...) call to it, since we only signal @@ -60,25 +63,36 @@ void SocketHandleWaiter::OnHandleDeletion(Subscriber* subscriber, handles_being_deleted_.end(), handle) == handles_being_deleted_.end(); }); + OSP_DVLOG << "\tDone blocking for handle deletion!"; } } } void SocketHandleWaiter::ProcessReadyHandles( - const std::vector<SocketHandleRef>& handles) { - std::lock_guard<std::mutex> lock(mutex_); - for (const SocketHandleRef& handle : handles) { - auto iterator = handle_mappings_.find(handle); - if (iterator == handle_mappings_.end()) { - // This is OK: SocketHandle was deleted in the meantime. - continue; + const std::vector<HandleWithSubscriber>& handles, + Clock::duration timeout) { + Clock::time_point start_time = now_function_(); + bool processed_one = false; + // TODO(btolsch): Track explicit or implicit time since last handled on each + // watched handle so we can sort by it here for better fairness. + for (const HandleWithSubscriber& handle : handles) { + Clock::time_point current_time = now_function_(); + if (processed_one && (current_time - start_time) > timeout) { + return; } - iterator->second->ProcessReadyHandle(handle); + processed_one = true; + handle.subscriber->ProcessReadyHandle(handle.handle); + + current_time = now_function_(); + if ((current_time - start_time) > timeout) { + return; + } } } Error SocketHandleWaiter::ProcessHandles(Clock::duration timeout) { + Clock::time_point start_time = now_function_(); std::vector<SocketHandleRef> handles; { std::lock_guard<std::mutex> lock(mutex_); @@ -90,22 +104,37 @@ Error SocketHandleWaiter::ProcessHandles(Clock::duration timeout) { } } + Clock::time_point current_time = now_function_(); + Clock::duration remaining_timeout = timeout - (current_time - start_time); ErrorOr<std::vector<SocketHandleRef>> changed_handles = - AwaitSocketsReadable(handles, timeout); + AwaitSocketsReadable(handles, remaining_timeout); + std::vector<HandleWithSubscriber> ready_handles; { std::lock_guard<std::mutex> lock(mutex_); handles_being_deleted_.clear(); handle_deletion_block_.notify_all(); - } + if (changed_handles) { + auto& ch = changed_handles.value(); + ready_handles.reserve(ch.size()); + for (const auto& handle : ch) { + auto mapping_it = handle_mappings_.find(handle); + if (mapping_it != handle_mappings_.end()) { + ready_handles.push_back( + HandleWithSubscriber{handle, mapping_it->second}); + } + } + } - if (changed_handles.is_error()) { - return changed_handles.error(); - } + if (changed_handles.is_error()) { + return changed_handles.error(); + } - ProcessReadyHandles(changed_handles.value()); + current_time = now_function_(); + remaining_timeout = timeout - (current_time - start_time); + ProcessReadyHandles(ready_handles, remaining_timeout); + } return Error::None(); } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.h b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.h index 1b97abd4f7e..1ad81526058 100644 --- a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.h +++ b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.h @@ -18,7 +18,6 @@ #include "platform/impl/socket_handle.h" namespace openscreen { -namespace platform { // The class responsible for calling platform-level method to watch UDP sockets // for available read data. Reading from these sockets is handled at a higher @@ -36,7 +35,7 @@ class SocketHandleWaiter { virtual void ProcessReadyHandle(SocketHandleRef handle) = 0; }; - SocketHandleWaiter() = default; + explicit SocketHandleWaiter(ClockNowFunctionPtr now_function); virtual ~SocketHandleWaiter() = default; // Start notifying |subscriber| whenever |handle| has an event. May be called @@ -73,8 +72,15 @@ class SocketHandleWaiter { const Clock::duration& timeout) = 0; private: - // Call the subscriber associated with each changed handle. - void ProcessReadyHandles(const std::vector<SocketHandleRef>& handles); + struct HandleWithSubscriber { + SocketHandleRef handle; + Subscriber* subscriber; + }; + + // Call the subscriber associated with each changed handle. Handles are only + // processed until |timeout| is exceeded. Must be called with |mutex_| held. + void ProcessReadyHandles(const std::vector<HandleWithSubscriber>& handles, + Clock::duration timeout); // Guards against concurrent access to all other class data members. std::mutex mutex_; @@ -90,9 +96,10 @@ class SocketHandleWaiter { // that is watching them. std::unordered_map<SocketHandleRef, Subscriber*, SocketHandleHash> handle_mappings_; + + const ClockNowFunctionPtr now_function_; }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_SOCKET_HANDLE_WAITER_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.cc b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.cc index 0cfca99bdf6..38762bf618d 100644 --- a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.cc +++ b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.cc @@ -16,9 +16,10 @@ #include "util/logging.h" namespace openscreen { -namespace platform { -SocketHandleWaiterPosix::SocketHandleWaiterPosix() = default; +SocketHandleWaiterPosix::SocketHandleWaiterPosix( + ClockNowFunctionPtr now_function) + : SocketHandleWaiter(now_function) {} SocketHandleWaiterPosix::~SocketHandleWaiterPosix() = default; @@ -27,9 +28,14 @@ SocketHandleWaiterPosix::AwaitSocketsReadable( const std::vector<SocketHandleRef>& socket_handles, const Clock::duration& timeout) { int max_fd = -1; - FD_ZERO(&read_handles_); + fd_set read_handles; + fd_set write_handles; + + FD_ZERO(&read_handles); + FD_ZERO(&write_handles); for (const SocketHandle& handle : socket_handles) { - FD_SET(handle.fd, &read_handles_); + FD_SET(handle.fd, &read_handles); + FD_SET(handle.fd, &write_handles); max_fd = std::max(max_fd, handle.fd); } if (max_fd < 0) { @@ -37,10 +43,13 @@ SocketHandleWaiterPosix::AwaitSocketsReadable( } struct timeval tv = ToTimeval(timeout); - // This value is set to 'max_fd + 1' by convention. For more information, see: + // This value is set to 'max_fd + 1' by convention. Also, select() is + // level-triggered so incomplete reads/writes by the caller are fine and will + // be picked up again on the next select() call. For more information, see: // http://man7.org/linux/man-pages/man2/select.2.html int max_fd_to_watch = max_fd + 1; - const int rv = select(max_fd_to_watch, &read_handles_, nullptr, nullptr, &tv); + const int rv = + select(max_fd_to_watch, &read_handles, &write_handles, nullptr, &tv); if (rv == -1) { // This is the case when an error condition is hit within the select(...) // command. @@ -52,7 +61,9 @@ SocketHandleWaiterPosix::AwaitSocketsReadable( std::vector<SocketHandleRef> changed_handles; for (const SocketHandleRef& handle : socket_handles) { - if (FD_ISSET(handle.get().fd, &read_handles_)) { + if (FD_ISSET(handle.get().fd, &read_handles) || + FD_ISSET(handle.get().fd, &write_handles)) { + // TODO(btolsch): Distinguish between read and write. changed_handles.push_back(handle); } } @@ -74,5 +85,4 @@ void SocketHandleWaiterPosix::RequestStopSoon() { is_running_.store(false); } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.h b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.h index 2a8b868300a..e2b78e5d8e0 100644 --- a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.h +++ b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.h @@ -13,13 +13,12 @@ #include "platform/impl/socket_handle_waiter.h" namespace openscreen { -namespace platform { class SocketHandleWaiterPosix : public SocketHandleWaiter { public: using SocketHandleRef = SocketHandleWaiter::SocketHandleRef; - SocketHandleWaiterPosix(); + explicit SocketHandleWaiterPosix(ClockNowFunctionPtr now_function); ~SocketHandleWaiterPosix() override; // Runs the Wait function in a loop until the below RequestStopSoon function @@ -35,13 +34,10 @@ class SocketHandleWaiterPosix : public SocketHandleWaiter { const Clock::duration& timeout) override; private: - fd_set read_handles_; - // Atomic so that we can perform atomic exchanges. std::atomic_bool is_running_; }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_SOCKET_HANDLE_WAITER_POSIX_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix_unittest.cc index fcc249e53ee..bcd647070bc 100644 --- a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix_unittest.cc @@ -14,9 +14,9 @@ #include "gtest/gtest.h" #include "platform/impl/socket_handle_posix.h" #include "platform/impl/timeval_posix.h" +#include "platform/test/fake_clock.h" namespace openscreen { -namespace platform { namespace { using namespace ::testing; @@ -32,10 +32,14 @@ class TestingSocketHandleWaiter : public SocketHandleWaiter { public: using SocketHandleRef = SocketHandleWaiter::SocketHandleRef; + TestingSocketHandleWaiter() : SocketHandleWaiter(&FakeClock::now) {} + MOCK_METHOD2( AwaitSocketsReadable, ErrorOr<std::vector<SocketHandleRef>>(const std::vector<SocketHandleRef>&, const Clock::duration&)); + + FakeClock fake_clock{Clock::time_point{Clock::duration{1234567}}}; }; } // namespace @@ -88,5 +92,4 @@ TEST(SocketHandleWaiterTest, WatchedSocketsReturnedToCorrectSubscribers) { waiter.ProcessHandles(Clock::duration{0}); } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/stream_socket.h b/chromium/third_party/openscreen/src/platform/impl/stream_socket.h index 7278e05b8dc..b4536e274bd 100644 --- a/chromium/third_party/openscreen/src/platform/impl/stream_socket.h +++ b/chromium/third_party/openscreen/src/platform/impl/stream_socket.h @@ -17,7 +17,6 @@ #include "platform/impl/socket_handle.h" namespace openscreen { -namespace platform { // StreamSocket is an incomplete abstraction of synchronous platform methods for // creating, initializing, and closing stream sockets. Callers can use this @@ -68,7 +67,6 @@ class StreamSocket { virtual IPAddress::Version version() const = 0; }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_STREAM_SOCKET_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.cc b/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.cc index b60e82ae0d9..c753344dd9b 100644 --- a/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.cc +++ b/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.cc @@ -7,12 +7,12 @@ #include <fcntl.h> #include <netinet/in.h> #include <netinet/ip.h> +#include <string.h> #include <sys/socket.h> #include <sys/types.h> #include <unistd.h> namespace openscreen { -namespace platform { namespace { constexpr int kDefaultMaxBacklogSize = 64; @@ -38,10 +38,15 @@ StreamSocketPosix::StreamSocketPosix(const IPEndpoint& local_endpoint) local_address_(local_endpoint) {} StreamSocketPosix::StreamSocketPosix(SocketAddressPosix local_address, + IPEndpoint remote_address, int file_descriptor) : handle_(file_descriptor), version_(local_address.version()), - local_address_(local_address) {} + local_address_(local_address), + remote_address_(remote_address), + state_(SocketState::kConnected) { + EnsureInitialized(); +} StreamSocketPosix::~StreamSocketPosix() { if (state_ == SocketState::kConnected) { @@ -74,11 +79,14 @@ ErrorOr<std::unique_ptr<StreamSocket>> StreamSocketPosix::Accept() { const int new_file_descriptor = accept(handle_.fd, new_remote_address.address(), &remote_address_size); if (new_file_descriptor == kUnsetHandleFd) { - return CloseOnError(Error::Code::kSocketAcceptFailure); + return CloseOnError( + Error(Error::Code::kSocketAcceptFailure, strerror(errno))); } + new_remote_address.RecomputeEndpoint(); return ErrorOr<std::unique_ptr<StreamSocket>>( - std::make_unique<StreamSocketPosix>(new_remote_address, + std::make_unique<StreamSocketPosix>(local_address_.value(), + new_remote_address.endpoint(), new_file_descriptor)); } @@ -97,7 +105,8 @@ Error StreamSocketPosix::Bind() { if (bind(handle_.fd, local_address_.value().address(), local_address_.value().size()) != 0) { - return CloseOnError(Error::Code::kSocketBindFailure); + return CloseOnError( + Error(Error::Code::kSocketBindFailure, strerror(errno))); } is_bound_ = true; @@ -134,8 +143,10 @@ Error StreamSocketPosix::Connect(const IPEndpoint& remote_endpoint) { } SocketAddressPosix address(remote_endpoint); - if (connect(handle_.fd, address.address(), address.size()) != 0) { - return CloseOnError(Error::Code::kSocketConnectFailure); + int ret = connect(handle_.fd, address.address(), address.size()); + if (ret != 0 && errno != EINPROGRESS) { + return CloseOnError( + Error(Error::Code::kSocketConnectFailure, strerror(errno))); } if (!is_bound_) { @@ -143,13 +154,14 @@ Error StreamSocketPosix::Connect(const IPEndpoint& remote_endpoint) { return CloseOnError(Error::Code::kSocketInvalidState); } - struct sockaddr address; + struct sockaddr_in6 address; socklen_t size = sizeof(address); - if (getsockname(handle_.fd, &address, &size) != 0) { + if (getsockname(handle_.fd, reinterpret_cast<struct sockaddr*>(&address), + &size) != 0) { return CloseOnError(Error::Code::kSocketConnectFailure); } - local_address_.emplace(address); + local_address_.emplace(reinterpret_cast<struct sockaddr&>(address)); is_bound_ = true; } @@ -168,7 +180,8 @@ Error StreamSocketPosix::Listen(int max_backlog_size) { } if (listen(handle_.fd, max_backlog_size) != 0) { - return CloseOnError(Error::Code::kSocketListenFailure); + return CloseOnError( + Error(Error::Code::kSocketListenFailure, strerror(errno))); } return Error::None(); @@ -201,7 +214,7 @@ bool StreamSocketPosix::EnsureInitialized() { return Initialize() == Error::None(); } - return false; + return handle_.fd != kUnsetHandleFd && is_initialized_; } Error StreamSocketPosix::Initialize() { @@ -209,40 +222,43 @@ Error StreamSocketPosix::Initialize() { return Error::Code::kItemAlreadyExists; } - int domain; - switch (version_) { - case IPAddress::Version::kV4: - domain = AF_INET; - break; - case IPAddress::Version::kV6: - domain = AF_INET6; - break; - } + int fd = handle_.fd; + if (fd == kUnsetHandleFd) { + int domain; + switch (version_) { + case IPAddress::Version::kV4: + domain = AF_INET; + break; + case IPAddress::Version::kV6: + domain = AF_INET6; + break; + } - const int file_descriptor = socket(domain, SOCK_STREAM, 0); - if (file_descriptor == kUnsetHandleFd) { - last_error_code_ = Error::Code::kSocketInvalidState; - return Error::Code::kSocketInvalidState; + fd = socket(domain, SOCK_STREAM, 0); + if (fd == kUnsetHandleFd) { + last_error_code_ = Error::Code::kSocketInvalidState; + return Error::Code::kSocketInvalidState; + } } - const int current_flags = fcntl(file_descriptor, F_GETFL, 0); - if (fcntl(file_descriptor, F_SETFL, current_flags | O_NONBLOCK) == -1) { - close(file_descriptor); + const int current_flags = fcntl(fd, F_GETFL, 0); + if (fcntl(fd, F_SETFL, current_flags | O_NONBLOCK) == -1) { + close(fd); last_error_code_ = Error::Code::kSocketInvalidState; return Error::Code::kSocketInvalidState; } - handle_.fd = file_descriptor; + handle_.fd = fd; is_initialized_ = true; // last_error_code_ should still be Error::None(). return Error::None(); } -Error StreamSocketPosix::CloseOnError(Error::Code error_code) { - last_error_code_ = error_code; +Error StreamSocketPosix::CloseOnError(Error error) { + last_error_code_ = error.code(); Close(); state_ = SocketState::kClosed; - return error_code; + return error; } // If is_open is false, the socket has either not been initialized @@ -251,5 +267,4 @@ Error StreamSocketPosix::ReportSocketClosedError() { last_error_code_ = Error::Code::kSocketClosedFailure; return Error::Code::kSocketClosedFailure; } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.h b/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.h index a990ca2b2d6..93f81744329 100644 --- a/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.h +++ b/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.h @@ -15,16 +15,17 @@ #include "platform/impl/socket_address_posix.h" #include "platform/impl/socket_handle_posix.h" #include "platform/impl/stream_socket.h" -#include "platform/impl/weak_ptr.h" +#include "util/weak_ptr.h" namespace openscreen { -namespace platform { class StreamSocketPosix : public StreamSocket { public: StreamSocketPosix(IPAddress::Version version); StreamSocketPosix(const IPEndpoint& local_endpoint); - StreamSocketPosix(SocketAddressPosix local_address, int file_descriptor); + StreamSocketPosix(SocketAddressPosix local_address, + IPEndpoint remote_address, + int file_descriptor); // StreamSocketPosix is non-copyable, due to directly managing the file // descriptor. @@ -58,7 +59,7 @@ class StreamSocketPosix : public StreamSocket { bool EnsureInitialized(); Error Initialize(); - Error CloseOnError(Error::Code error_code); + Error CloseOnError(Error error); Error ReportSocketClosedError(); constexpr static int kUnsetHandleFd = -1; @@ -81,7 +82,6 @@ class StreamSocketPosix : public StreamSocket { WeakPtrFactory<StreamSocketPosix> weak_factory_{this}; }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_STREAM_SOCKET_POSIX_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/task_runner.cc b/chromium/third_party/openscreen/src/platform/impl/task_runner.cc index 8ff92b34b2e..ac4a49a86f1 100644 --- a/chromium/third_party/openscreen/src/platform/impl/task_runner.cc +++ b/chromium/third_party/openscreen/src/platform/impl/task_runner.cc @@ -4,14 +4,33 @@ #include "platform/impl/task_runner.h" +#include <csignal> #include <thread> #include "util/logging.h" namespace openscreen { -namespace platform { -TaskRunnerImpl::TaskRunnerImpl(platform::ClockNowFunctionPtr now_function, +namespace { + +// This is mutated by the signal handler installed by RunUntilSignaled(), and is +// checked by RunUntilStopped(). +// +// Per the C++14 spec, passing visible changes to memory between a signal +// handler and a program thread must be done through a volatile variable. +volatile enum { + kNotRunning, + kNotSignaled, + kSignaled +} g_signal_state = kNotRunning; + +void OnReceivedSignal(int signal) { + g_signal_state = kSignaled; +} + +} // namespace + +TaskRunnerImpl::TaskRunnerImpl(ClockNowFunctionPtr now_function, TaskWaiter* event_waiter, Clock::duration waiter_timeout) : now_function_(now_function), @@ -37,8 +56,12 @@ void TaskRunnerImpl::PostPackagedTask(Task task) { void TaskRunnerImpl::PostPackagedTaskWithDelay(Task task, Clock::duration delay) { std::lock_guard<std::mutex> lock(task_mutex_); - delayed_tasks_.emplace( - std::make_pair(now_function_() + delay, std::move(task))); + if (delay <= Clock::duration::zero()) { + tasks_.emplace_back(std::move(task)); + } else { + delayed_tasks_.emplace( + std::make_pair(now_function_() + delay, std::move(task))); + } if (task_waiter_) { task_waiter_->OnTaskPosted(); } else { @@ -56,12 +79,16 @@ void TaskRunnerImpl::RunUntilStopped() { is_running_ = true; // Main loop: Run until the |is_running_| flag is set back to false by the - // "quit task" posted by RequestStopSoon(). + // "quit task" posted by RequestStopSoon(), or the process received a + // termination signal. while (is_running_) { ScheduleDelayedTasks(); if (GrabMoreRunnableTasks()) { RunRunnableTasks(); } + if (g_signal_state == kSignaled) { + is_running_ = false; + } } // Flushing phase: Ensure all immediately-runnable tasks are run before @@ -80,6 +107,20 @@ void TaskRunnerImpl::RunUntilStopped() { task_runner_thread_id_ = std::thread::id(); } +void TaskRunnerImpl::RunUntilSignaled() { + OSP_CHECK_EQ(g_signal_state, kNotRunning) + << __func__ << " may not be invoked concurrently."; + g_signal_state = kNotSignaled; + const auto old_sigint_handler = std::signal(SIGINT, &OnReceivedSignal); + const auto old_sigterm_handler = std::signal(SIGTERM, &OnReceivedSignal); + + RunUntilStopped(); + + std::signal(SIGINT, old_sigint_handler); + std::signal(SIGTERM, old_sigterm_handler); + g_signal_state = kNotRunning; +} + void TaskRunnerImpl::RequestStopSoon() { PostTask([this]() { is_running_ = false; }); } @@ -143,5 +184,4 @@ bool TaskRunnerImpl::GrabMoreRunnableTasks() { return false; } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/task_runner.h b/chromium/third_party/openscreen/src/platform/impl/task_runner.h index 35da181b9ae..629e23cb7b2 100644 --- a/chromium/third_party/openscreen/src/platform/impl/task_runner.h +++ b/chromium/third_party/openscreen/src/platform/impl/task_runner.h @@ -22,7 +22,6 @@ #include "util/trace_logging.h" namespace openscreen { -namespace platform { class TaskRunnerImpl final : public TaskRunner { public: @@ -49,7 +48,7 @@ class TaskRunnerImpl final : public TaskRunner { }; explicit TaskRunnerImpl( - platform::ClockNowFunctionPtr now_function, + ClockNowFunctionPtr now_function, TaskWaiter* event_waiter = nullptr, Clock::duration waiter_timeout = std::chrono::milliseconds(100)); @@ -64,6 +63,11 @@ class TaskRunnerImpl final : public TaskRunner { // called. void RunUntilStopped(); + // Blocks the current thread, executing tasks from the queue with the desired + // timing; and does not return until some time after the current process is + // signaled with SIGINT or SIGTERM, or after RequestStopSoon() is called. + void RunUntilSignaled(); + // Thread-safe method for requesting the TaskRunner to stop running after all // non-delayed tasks in the queue have run. This behavior allows final // clean-up tasks to be executed before the TaskRunner stops. @@ -83,7 +87,7 @@ class TaskRunnerImpl final : public TaskRunner { // used. This simplifies switching between 'Task' and 'TaskWithMetadata' // based on the compilation flag. TaskWithMetadata(Task task) - : task_(std::move(task)), trace_ids_(TRACE_HIERARCHY){}; + : task_(std::move(task)), trace_ids_(TRACE_HIERARCHY) {} void operator()() { TRACE_SET_HIERARCHY(trace_ids_); @@ -111,7 +115,7 @@ class TaskRunnerImpl final : public TaskRunner { // transferred. bool GrabMoreRunnableTasks(); - const platform::ClockNowFunctionPtr now_function_; + const ClockNowFunctionPtr now_function_; // Flag that indicates whether the task runner loop should continue. This is // only meant to be read/written on the thread executing RunUntilStopped(). @@ -141,7 +145,6 @@ class TaskRunnerImpl final : public TaskRunner { OSP_DISALLOW_COPY_AND_ASSIGN(TaskRunnerImpl); }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_TASK_RUNNER_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/task_runner_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/task_runner_unittest.cc index ac4331964fc..d7535cf34b2 100644 --- a/chromium/third_party/openscreen/src/platform/impl/task_runner_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/impl/task_runner_unittest.cc @@ -13,7 +13,6 @@ #include "platform/test/fake_clock.h" namespace openscreen { -namespace platform { namespace { using namespace ::testing; @@ -31,7 +30,7 @@ void WaitUntilCondition(std::function<bool()> predicate) { class FakeTaskWaiter final : public TaskRunnerImpl::TaskWaiter { public: - explicit FakeTaskWaiter(platform::ClockNowFunctionPtr now_function) + explicit FakeTaskWaiter(ClockNowFunctionPtr now_function) : now_function_(now_function) {} ~FakeTaskWaiter() override = default; @@ -60,7 +59,7 @@ class FakeTaskWaiter final : public TaskRunnerImpl::TaskWaiter { } private: - const platform::ClockNowFunctionPtr now_function_; + const ClockNowFunctionPtr now_function_; TaskRunnerImpl* task_runner_; std::atomic<bool> has_event_{false}; std::atomic<bool> waiting_{false}; @@ -69,7 +68,7 @@ class FakeTaskWaiter final : public TaskRunnerImpl::TaskWaiter { class TaskRunnerWithWaiterFactory { public: static std::unique_ptr<TaskRunnerImpl> Create( - platform::ClockNowFunctionPtr now_function) { + ClockNowFunctionPtr now_function) { fake_waiter = std::make_unique<FakeTaskWaiter>(now_function); auto runner = std::make_unique<TaskRunnerImpl>( now_function, fake_waiter.get(), std::chrono::hours(1)); @@ -86,7 +85,7 @@ std::unique_ptr<FakeTaskWaiter> TaskRunnerWithWaiterFactory::fake_waiter; } // anonymous namespace TEST(TaskRunnerImplTest, TaskRunnerExecutesTaskAndStops) { - FakeClock fake_clock{platform::Clock::time_point(milliseconds(1337))}; + FakeClock fake_clock{Clock::time_point(milliseconds(1337))}; TaskRunnerImpl runner(&fake_clock.now); std::string ran_tasks = ""; @@ -98,7 +97,7 @@ TEST(TaskRunnerImplTest, TaskRunnerExecutesTaskAndStops) { } TEST(TaskRunnerImplTest, TaskRunnerRunsDelayedTasksInOrder) { - FakeClock fake_clock{platform::Clock::time_point(milliseconds(1337))}; + FakeClock fake_clock{Clock::time_point(milliseconds(1337))}; TaskRunnerImpl runner(&fake_clock.now); std::thread t([&runner] { runner.RunUntilStopped(); }); @@ -126,7 +125,7 @@ TEST(TaskRunnerImplTest, TaskRunnerRunsDelayedTasksInOrder) { } TEST(TaskRunnerImplTest, SingleThreadedTaskRunnerRunsSequentially) { - FakeClock fake_clock{platform::Clock::time_point(milliseconds(1337))}; + FakeClock fake_clock{Clock::time_point(milliseconds(1337))}; TaskRunnerImpl runner(&fake_clock.now); std::string ran_tasks; @@ -149,7 +148,7 @@ TEST(TaskRunnerImplTest, SingleThreadedTaskRunnerRunsSequentially) { } TEST(TaskRunnerImplTest, RunsAllImmediateTasksBeforeStopping) { - FakeClock fake_clock{platform::Clock::time_point(milliseconds(1337))}; + FakeClock fake_clock{Clock::time_point(milliseconds(1337))}; TaskRunnerImpl runner(&fake_clock.now); std::string result; @@ -181,7 +180,7 @@ TEST(TaskRunnerImplTest, RunsAllImmediateTasksBeforeStopping) { } TEST(TaskRunnerImplTest, TaskRunnerIsStableWithLotsOfTasks) { - FakeClock fake_clock{platform::Clock::time_point(milliseconds(1337))}; + FakeClock fake_clock{Clock::time_point(milliseconds(1337))}; TaskRunnerImpl runner(&fake_clock.now); const int kNumberOfTasks = 500; @@ -200,7 +199,7 @@ TEST(TaskRunnerImplTest, TaskRunnerIsStableWithLotsOfTasks) { } TEST(TaskRunnerImplTest, TaskRunnerDelayedTasksDontBlockImmediateTasks) { - TaskRunnerImpl runner(platform::Clock::now); + TaskRunnerImpl runner(Clock::now); std::string ran_tasks; const auto task = [&ran_tasks] { ran_tasks += "1"; }; @@ -282,5 +281,4 @@ class RepeatedClass { std::atomic<int> execution_count{0}; }; -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.cc b/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.cc index 573369a68c6..79597e1d655 100644 --- a/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.cc +++ b/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.cc @@ -9,7 +9,6 @@ #include "util/logging.h" namespace openscreen { -namespace platform { bool TextTraceLoggingPlatform::IsTraceLoggingEnabled( TraceCategory::Value category) { @@ -19,12 +18,10 @@ bool TextTraceLoggingPlatform::IsTraceLoggingEnabled( } TextTraceLoggingPlatform::TextTraceLoggingPlatform() { - OSP_DCHECK(!GetTracingDestination()); StartTracing(this); } TextTraceLoggingPlatform::~TextTraceLoggingPlatform() { - OSP_DCHECK_EQ(GetTracingDestination(), this); StopTracing(); } @@ -69,5 +66,4 @@ void TextTraceLoggingPlatform::LogAsyncEnd(const uint32_t line, << timestamp << ") " << error; } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.h b/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.h index 3d2a93c6d4f..c9155b6b580 100644 --- a/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.h +++ b/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.h @@ -8,7 +8,6 @@ #include "platform/api/trace_logging_platform.h" namespace openscreen { -namespace platform { class TextTraceLoggingPlatform : public TraceLoggingPlatform { public: @@ -38,7 +37,6 @@ class TextTraceLoggingPlatform : public TraceLoggingPlatform { Error::Code error) override; }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_TEXT_TRACE_LOGGING_PLATFORM_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/time.cc b/chromium/third_party/openscreen/src/platform/impl/time.cc index 5604b84a2f1..6ad93145ec8 100644 --- a/chromium/third_party/openscreen/src/platform/impl/time.cc +++ b/chromium/third_party/openscreen/src/platform/impl/time.cc @@ -17,7 +17,6 @@ using std::chrono::steady_clock; using std::chrono::system_clock; namespace openscreen { -namespace platform { Clock::time_point Clock::now() noexcept { constexpr bool can_use_steady_clock = @@ -66,5 +65,4 @@ std::chrono::seconds GetWallTimeSinceUnixEpoch() noexcept { return std::chrono::seconds(since_epoch); } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/time_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/time_unittest.cc index d784dc591ac..321bb7b32b5 100644 --- a/chromium/third_party/openscreen/src/platform/impl/time_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/impl/time_unittest.cc @@ -11,7 +11,6 @@ using std::chrono::seconds; namespace openscreen { -namespace platform { #if __cplusplus < 202000L // Before C++20, the standard does not guarantee that time_t is the number of @@ -51,5 +50,4 @@ TEST(TimeTest, TimeTMeetsTheCpp20Standard) { } #endif -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/timeval_posix.cc b/chromium/third_party/openscreen/src/platform/impl/timeval_posix.cc index dd16532421a..28c25ffe288 100644 --- a/chromium/third_party/openscreen/src/platform/impl/timeval_posix.cc +++ b/chromium/third_party/openscreen/src/platform/impl/timeval_posix.cc @@ -7,7 +7,6 @@ #include <chrono> namespace openscreen { -namespace platform { struct timeval ToTimeval(const Clock::duration& timeout) { struct timeval tv; @@ -21,5 +20,4 @@ struct timeval ToTimeval(const Clock::duration& timeout) { return tv; } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/timeval_posix.h b/chromium/third_party/openscreen/src/platform/impl/timeval_posix.h index 3679682de07..8f42129ab26 100644 --- a/chromium/third_party/openscreen/src/platform/impl/timeval_posix.h +++ b/chromium/third_party/openscreen/src/platform/impl/timeval_posix.h @@ -10,11 +10,9 @@ #include "platform/api/time.h" namespace openscreen { -namespace platform { struct timeval ToTimeval(const Clock::duration& timeout); -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_TIMEVAL_POSIX_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/timeval_posix_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/timeval_posix_unittest.cc index 3b679366c96..db82d7ad9c1 100644 --- a/chromium/third_party/openscreen/src/platform/impl/timeval_posix_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/impl/timeval_posix_unittest.cc @@ -7,7 +7,6 @@ #include "gtest/gtest.h" namespace openscreen { -namespace platform { TEST(TimevalPosixTest, ToTimeval) { auto timespan = Clock::duration::zero(); @@ -36,5 +35,4 @@ TEST(TimevalPosixTest, ToTimeval) { EXPECT_EQ(timeval.tv_usec, 10); } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.cc b/chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.cc index 8b270dabffa..c9749f728a2 100644 --- a/chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.cc +++ b/chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.cc @@ -29,13 +29,13 @@ #include "util/trace_logging.h" namespace openscreen { -namespace platform { namespace { ErrorOr<std::vector<uint8_t>> GetDEREncodedPeerCertificate(const SSL& ssl) { X509* const peer_cert = SSL_get_peer_certificate(&ssl); - ErrorOr<std::vector<uint8_t>> der_peer_cert = ExportCertificate(*peer_cert); + ErrorOr<std::vector<uint8_t>> der_peer_cert = + ExportX509CertificateToDer(*peer_cert); X509_free(peer_cert); return der_peer_cert; } @@ -60,16 +60,19 @@ TlsConnectionFactoryPosix::TlsConnectionFactoryPosix( OSP_DCHECK(task_runner_); } -TlsConnectionFactoryPosix::~TlsConnectionFactoryPosix() = default; +TlsConnectionFactoryPosix::~TlsConnectionFactoryPosix() { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); +} // TODO(rwkeane): Add support for resuming sessions. // TODO(rwkeane): Integrate with Auth. void TlsConnectionFactoryPosix::Connect(const IPEndpoint& remote_address, const TlsConnectOptions& options) { - TRACE_SCOPED(TraceCategory::SSL, "TlsConnectionFactoryPosix::Connect"); + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + TRACE_SCOPED(TraceCategory::kSsl, "TlsConnectionFactoryPosix::Connect"); IPAddress::Version version = remote_address.address.version(); std::unique_ptr<TlsConnectionPosix> connection( - new TlsConnectionPosix(version, task_runner_, platform_client_)); + new TlsConnectionPosix(version, task_runner_)); Error connect_error = connection->socket_->Connect(remote_address); if (!connect_error.ok()) { TRACE_SET_RESULT(connect_error); @@ -89,39 +92,23 @@ void TlsConnectionFactoryPosix::Connect(const IPEndpoint& remote_address, SSL_set_verify(connection->ssl_.get(), SSL_VERIFY_PEER, nullptr); } - const int connection_status = SSL_connect(connection->ssl_.get()); - if (connection_status != 1) { - DispatchConnectionFailed(connection->GetRemoteEndpoint()); - TRACE_SET_RESULT(GetSSLError(connection->ssl_.get(), connection_status)); - return; - } - - ErrorOr<std::vector<uint8_t>> der_peer_cert = - GetDEREncodedPeerCertificate(*connection->ssl_); - if (!der_peer_cert) { - DispatchConnectionFailed(connection->GetRemoteEndpoint()); - TRACE_SET_RESULT(der_peer_cert.error()); - return; - } - - task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(), - der = std::move(der_peer_cert.value()), - moved_connection = std::move(connection)]() mutable { - if (auto* self = weak_this.get()) { - self->client_->OnConnected(self, std::move(der), - std::move(moved_connection)); - } - }); + Connect(std::move(connection)); } void TlsConnectionFactoryPosix::SetListenCredentials( const TlsCredentials& credentials) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); EnsureInitialized(); ErrorOr<bssl::UniquePtr<X509>> cert = ImportCertificate( credentials.der_x509_cert.data(), credentials.der_x509_cert.size()); - if (!cert || - SSL_CTX_use_certificate(ssl_context_.get(), cert.value().get()) != 1) { + ErrorOr<bssl::UniquePtr<EVP_PKEY>> pkey = + ImportRSAPrivateKey(credentials.der_rsa_private_key.data(), + credentials.der_rsa_private_key.size()); + + if (!cert || !pkey || + SSL_CTX_use_certificate(ssl_context_.get(), cert.value().get()) != 1 || + SSL_CTX_use_PrivateKey(ssl_context_.get(), pkey.value().get()) != 1) { DispatchError(Error::Code::kSocketListenFailure); TRACE_SET_RESULT(Error::Code::kSocketListenFailure); return; @@ -132,15 +119,23 @@ void TlsConnectionFactoryPosix::SetListenCredentials( void TlsConnectionFactoryPosix::Listen(const IPEndpoint& local_address, const TlsListenOptions& options) { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); // Credentials must be set before Listen() is called. OSP_DCHECK(listen_credentials_set_); auto socket = std::make_unique<StreamSocketPosix>(local_address); + socket->Bind(); socket->Listen(options.backlog_size); + if (socket->state() == SocketState::kClosed) { + DispatchError(Error::Code::kSocketListenFailure); + TRACE_SET_RESULT(Error::Code::kSocketListenFailure); + return; + } + OSP_DCHECK(socket->state() == SocketState::kNotConnected); OSP_DCHECK(platform_client_); if (platform_client_) { - platform_client_->tls_data_router()->RegisterSocketObserver( + platform_client_->tls_data_router()->RegisterAcceptObserver( std::move(socket), this); } } @@ -175,38 +170,16 @@ void TlsConnectionFactoryPosix::OnSocketAccepted( std::unique_ptr<StreamSocket> socket) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); - TRACE_SCOPED(TraceCategory::SSL, + TRACE_SCOPED(TraceCategory::kSsl, "TlsConnectionFactoryPosix::OnSocketAccepted"); - std::unique_ptr<TlsConnectionPosix> connection(new TlsConnectionPosix( - std::move(socket), task_runner_, platform_client_)); + std::unique_ptr<TlsConnectionPosix> connection( + new TlsConnectionPosix(std::move(socket), task_runner_)); if (!ConfigureSsl(connection.get())) { return; } - const int connection_status = SSL_accept(connection->ssl_.get()); - if (connection_status != 1) { - DispatchConnectionFailed(connection->GetRemoteEndpoint()); - TRACE_SET_RESULT(GetSSLError(connection->ssl_.get(), connection_status)); - return; - } - - ErrorOr<std::vector<uint8_t>> der_peer_cert = - GetDEREncodedPeerCertificate(*connection->ssl_); - if (!der_peer_cert) { - DispatchConnectionFailed(connection->GetRemoteEndpoint()); - TRACE_SET_RESULT(der_peer_cert.error()); - return; - } - - task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(), - der = std::move(der_peer_cert.value()), - moved_connection = std::move(connection)]() mutable { - if (auto* self = weak_this.get()) { - self->client_->OnAccepted(self, std::move(der), - std::move(moved_connection)); - } - }); + Accept(std::move(connection)); } bool TlsConnectionFactoryPosix::ConfigureSsl(TlsConnectionPosix* connection) { @@ -258,6 +231,86 @@ void TlsConnectionFactoryPosix::Initialize() { ssl_context_.reset(context); } +void TlsConnectionFactoryPosix::Connect( + std::unique_ptr<TlsConnectionPosix> connection) { + OSP_DCHECK(connection->socket_->state() == SocketState::kConnected); + const int connection_status = SSL_connect(connection->ssl_.get()); + if (connection_status != 1) { + Error error = GetSSLError(connection->ssl_.get(), connection_status); + if (error.code() == Error::Code::kAgain) { + task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(), + conn = std::move(connection)]() mutable { + if (auto* self = weak_this.get()) { + self->Connect(std::move(conn)); + } + }); + return; + } else { + OSP_DVLOG << "SSL_connect failed with error: " << error; + DispatchConnectionFailed(connection->GetRemoteEndpoint()); + TRACE_SET_RESULT(error); + return; + } + } + + ErrorOr<std::vector<uint8_t>> der_peer_cert = + GetDEREncodedPeerCertificate(*connection->ssl_); + if (!der_peer_cert) { + DispatchConnectionFailed(connection->GetRemoteEndpoint()); + TRACE_SET_RESULT(der_peer_cert.error()); + return; + } + + connection->RegisterConnectionWithDataRouter(platform_client_); + task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(), + der = std::move(der_peer_cert.value()), + moved_connection = std::move(connection)]() mutable { + if (auto* self = weak_this.get()) { + self->client_->OnConnected(self, std::move(der), + std::move(moved_connection)); + } + }); +} + +void TlsConnectionFactoryPosix::Accept( + std::unique_ptr<TlsConnectionPosix> connection) { + OSP_DCHECK(connection->socket_->state() == SocketState::kConnected); + const int connection_status = SSL_accept(connection->ssl_.get()); + if (connection_status != 1) { + Error error = GetSSLError(connection->ssl_.get(), connection_status); + if (error.code() == Error::Code::kAgain) { + task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(), + conn = std::move(connection)]() mutable { + if (auto* self = weak_this.get()) { + self->Accept(std::move(conn)); + } + }); + return; + } else { + OSP_DVLOG << "SSL_accept failed with error: " << error; + DispatchConnectionFailed(connection->GetRemoteEndpoint()); + TRACE_SET_RESULT(error); + return; + } + } + + ErrorOr<std::vector<uint8_t>> der_peer_cert = + GetDEREncodedPeerCertificate(*connection->ssl_); + std::vector<uint8_t> der; + if (der_peer_cert) { + der = std::move(der_peer_cert.value()); + } + connection->RegisterConnectionWithDataRouter(platform_client_); + task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(), + der = std::move(der), + moved_connection = std::move(connection)]() mutable { + if (auto* self = weak_this.get()) { + self->client_->OnAccepted(self, std::move(der), + std::move(moved_connection)); + } + }); +} + void TlsConnectionFactoryPosix::DispatchConnectionFailed( const IPEndpoint& remote_endpoint) { task_runner_->PostTask( @@ -277,5 +330,4 @@ void TlsConnectionFactoryPosix::DispatchError(Error error) { }); } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.h b/chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.h index 46ba6e25a58..13ac6b6ddab 100644 --- a/chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.h +++ b/chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.h @@ -14,10 +14,9 @@ #include "platform/base/error.h" #include "platform/impl/platform_client_posix.h" #include "platform/impl/tls_data_router_posix.h" -#include "platform/impl/weak_ptr.h" +#include "util/weak_ptr.h" namespace openscreen { -namespace platform { class StreamSocket; @@ -62,6 +61,11 @@ class TlsConnectionFactoryPosix : public TlsConnectionFactory, // factory. void Initialize(); + // Handles their respective SSL handshake calls. These will continue to be + // scheduled on |task_runner_| until the handshake completes. + void Connect(std::unique_ptr<TlsConnectionPosix> connection); + void Accept(std::unique_ptr<TlsConnectionPosix> connection); + // Called on any thread, to post a task to notify the Client that a connection // failure or other error has occurred. void DispatchConnectionFailed(const IPEndpoint& remote_endpoint); @@ -86,7 +90,6 @@ class TlsConnectionFactoryPosix : public TlsConnectionFactory, OSP_DISALLOW_COPY_AND_ASSIGN(TlsConnectionFactoryPosix); }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_TLS_CONNECTION_FACTORY_POSIX_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.cc b/chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.cc index 3fa39037cd2..424f2ab49cc 100644 --- a/chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.cc +++ b/chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.cc @@ -27,46 +27,26 @@ #include "util/logging.h" namespace openscreen { -namespace platform { // TODO(jophba, rwkeane): implement write blocking/unblocking TlsConnectionPosix::TlsConnectionPosix(IPEndpoint local_address, - TaskRunner* task_runner, - PlatformClientPosix* platform_client) + TaskRunner* task_runner) : task_runner_(task_runner), - platform_client_(platform_client), - socket_(std::make_unique<StreamSocketPosix>(local_address)), - buffer_(this) { + socket_(std::make_unique<StreamSocketPosix>(local_address)) { OSP_DCHECK(task_runner_); - if (platform_client_) { - platform_client_->tls_data_router()->RegisterConnection(this); - } } TlsConnectionPosix::TlsConnectionPosix(IPAddress::Version version, - TaskRunner* task_runner, - PlatformClientPosix* platform_client) + TaskRunner* task_runner) : task_runner_(task_runner), - platform_client_(platform_client), - socket_(std::make_unique<StreamSocketPosix>(version)), - buffer_(this) { + socket_(std::make_unique<StreamSocketPosix>(version)) { OSP_DCHECK(task_runner_); - if (platform_client_) { - platform_client_->tls_data_router()->RegisterConnection(this); - } } TlsConnectionPosix::TlsConnectionPosix(std::unique_ptr<StreamSocket> socket, - TaskRunner* task_runner, - PlatformClientPosix* platform_client) - : task_runner_(task_runner), - platform_client_(platform_client), - socket_(std::move(socket)), - buffer_(this) { + TaskRunner* task_runner) + : task_runner_(task_runner), socket_(std::move(socket)) { OSP_DCHECK(task_runner_); - if (platform_client_) { - platform_client_->tls_data_router()->RegisterConnection(this); - } } TlsConnectionPosix::~TlsConnectionPosix() { @@ -76,47 +56,43 @@ TlsConnectionPosix::~TlsConnectionPosix() { } void TlsConnectionPosix::TryReceiveMessage() { - const int bytes_available = SSL_pending(ssl_.get()); - if (bytes_available > 0) { - // NOTE: the pending size of the data block available is not a guarantee - // that it will receive only bytes_available or even - // any data, since not all pending bytes are application data. - std::vector<uint8_t> block(bytes_available); - - const int bytes_read = SSL_read(ssl_.get(), block.data(), bytes_available); - - // Read operator was not successful, either due to a closed connection, - // an error occurred, or we have to take an action. - if (bytes_read <= 0) { - const Error error = GetSSLError(ssl_.get(), bytes_read); - if (!error.ok() && (error != Error::Code::kAgain)) { - DispatchError(error); - } - return; + OSP_DCHECK(ssl_); + constexpr int kMaxApplicationDataBytes = 4096; + std::vector<uint8_t> block(kMaxApplicationDataBytes); + const int bytes_read = + SSL_read(ssl_.get(), block.data(), kMaxApplicationDataBytes); + + // Read operator was not successful, either due to a closed connection, + // no application data available, an error occurred, or we have to take an + // action. + if (bytes_read <= 0) { + const Error error = GetSSLError(ssl_.get(), bytes_read); + if (!error.ok() && (error != Error::Code::kAgain)) { + DispatchError(error); } + return; + } - block.resize(bytes_read); + block.resize(bytes_read); - task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(), - moved_block = std::move(block)]() mutable { - if (auto* self = weak_this.get()) { - if (auto* client = self->client_) { - client->OnRead(self, std::move(moved_block)); - } + task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(), + moved_block = std::move(block)]() mutable { + if (auto* self = weak_this.get()) { + if (auto* client = self->client_) { + client->OnRead(self, std::move(moved_block)); } - }); - } + } + }); } void TlsConnectionPosix::SetClient(Client* client) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); client_ = client; - notified_client_buffer_is_blocked_ = false; } -void TlsConnectionPosix::Write(const void* data, size_t len) { +bool TlsConnectionPosix::Send(const void* data, size_t len) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); - buffer_.Write(data, len); + return buffer_.Push(data, len); } IPEndpoint TlsConnectionPosix::GetLocalEndpoint() const { @@ -135,29 +111,11 @@ IPEndpoint TlsConnectionPosix::GetRemoteEndpoint() const { return endpoint.value(); } -void TlsConnectionPosix::NotifyWriteBufferFill(double fraction) { - // WARNING: This method is called on multiple threads. - // - // The following is very subtle/complex behavior: Only "writes" can increase - // the buffer fill, so we expect transitions into the "blocked" state to occur - // on the |task_runner_| thread, and |client_| will be notified - // *synchronously* when that happens. Likewise, only "reads" can cause - // transitions to the "unblocked" state; but these will not occur on the - // |task_runner_| thread. Thus, when unblocking, the |client_| will be - // notified *asynchronously*; but, that should be acceptable because it's only - // a race towards a buffer overrun that is of concern. - // - // TODO(rwkeane): Have Write() return a bool, and then none of this is needed. - constexpr double kBlockBufferPercentage = 0.5; - if (fraction > kBlockBufferPercentage && - !notified_client_buffer_is_blocked_) { - NotifyClientOfWriteBlockStatusSequentially(true); - } else if (fraction < kBlockBufferPercentage && - notified_client_buffer_is_blocked_) { - NotifyClientOfWriteBlockStatusSequentially(false); - } else if (fraction >= 0.99) { - DispatchError(Error::Code::kInsufficientBuffer); - } +void TlsConnectionPosix::RegisterConnectionWithDataRouter( + PlatformClientPosix* platform_client) { + OSP_DCHECK(!platform_client_); + platform_client_ = platform_client; + platform_client_->tls_data_router()->RegisterConnection(this); } void TlsConnectionPosix::SendAvailableBytes() { @@ -189,38 +147,4 @@ void TlsConnectionPosix::DispatchError(Error error) { }); } -void TlsConnectionPosix::NotifyClientOfWriteBlockStatusSequentially( - bool is_blocked) { - if (!task_runner_->IsRunningOnTaskRunner()) { - task_runner_->PostTask( - [weak_this = weak_factory_.GetWeakPtr(), is_blocked = is_blocked] { - if (auto* self = weak_this.get()) { - OSP_DCHECK(self->task_runner_->IsRunningOnTaskRunner()); - self->NotifyClientOfWriteBlockStatusSequentially(is_blocked); - } - }); - return; - } - - OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); - - if (!client_) { - return; - } - - // Check again, now that the block/unblock state change is happening - // in-sequence (it originated from parallel executions). - if (notified_client_buffer_is_blocked_ == is_blocked) { - return; - } - - notified_client_buffer_is_blocked_ = is_blocked; - if (is_blocked) { - client_->OnWriteBlocked(this); - } else { - client_->OnWriteUnblocked(this); - } -} - -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.h b/chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.h index c0089844163..c78bf5f14d3 100644 --- a/chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.h +++ b/chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.h @@ -7,23 +7,20 @@ #include <openssl/ssl.h> -#include <atomic> #include <memory> #include "platform/api/tls_connection.h" #include "platform/impl/platform_client_posix.h" #include "platform/impl/stream_socket_posix.h" #include "platform/impl/tls_write_buffer.h" -#include "platform/impl/weak_ptr.h" +#include "util/weak_ptr.h" namespace openscreen { -namespace platform { class TaskRunner; class TlsConnectionFactoryPosix; -class TlsConnectionPosix : public TlsConnection, - public TlsWriteBuffer::Observer { +class TlsConnectionPosix : public TlsConnection { public: ~TlsConnectionPosix() override; @@ -36,49 +33,37 @@ class TlsConnectionPosix : public TlsConnection, // TlsConnection overrides. void SetClient(Client* client) override; - void Write(const void* data, size_t len) override; + bool Send(const void* data, size_t len) override; IPEndpoint GetLocalEndpoint() const override; IPEndpoint GetRemoteEndpoint() const override; - // TlsWriteBuffer::Observer overrides. - void NotifyWriteBufferFill(double fraction) override; + // Registers |this| with the platform TlsDataRouterPosix. This is called + // automatically by TlsConnectionFactoryPosix after the handshake completes. + void RegisterConnectionWithDataRouter(PlatformClientPosix* platform_client); + + const SocketHandle& socket_handle() const { return socket_->socket_handle(); } protected: friend class TlsConnectionFactoryPosix; - TlsConnectionPosix(IPEndpoint local_address, - TaskRunner* task_runner, - PlatformClientPosix* platform_client = - PlatformClientPosix::GetInstance()); - TlsConnectionPosix(IPAddress::Version version, - TaskRunner* task_runner, - PlatformClientPosix* platform_client = - PlatformClientPosix::GetInstance()); + TlsConnectionPosix(IPEndpoint local_address, TaskRunner* task_runner); + TlsConnectionPosix(IPAddress::Version version, TaskRunner* task_runner); TlsConnectionPosix(std::unique_ptr<StreamSocket> socket, - TaskRunner* task_runner, - PlatformClientPosix* platform_client = - PlatformClientPosix::GetInstance()); + TaskRunner* task_runner); private: // Called on any thread, to post a task to notify the Client that an |error| // has occurred. void DispatchError(Error error); - // Helper to call OnWriteBlocked() or OnWriteUnblocked(). If this is not - // called within a task run by |task_runner_|, it trampolines by posting a - // task to call itself back via |task_runner_|. See comments in implementation - // of NotifyWriteBufferFill() for further details. - void NotifyClientOfWriteBlockStatusSequentially(bool is_blocked); - TaskRunner* const task_runner_; - PlatformClientPosix* const platform_client_; + PlatformClientPosix* platform_client_ = nullptr; Client* client_ = nullptr; std::unique_ptr<StreamSocket> socket_; bssl::UniquePtr<SSL> ssl_; - std::atomic_bool notified_client_buffer_is_blocked_{false}; TlsWriteBuffer buffer_; WeakPtrFactory<TlsConnectionPosix> weak_factory_{this}; @@ -86,7 +71,6 @@ class TlsConnectionPosix : public TlsConnection, OSP_DISALLOW_COPY_AND_ASSIGN(TlsConnectionPosix); }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_TLS_CONNECTION_POSIX_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.cc b/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.cc index 642d4a93ee0..288f1998aec 100644 --- a/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.cc +++ b/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.cc @@ -9,7 +9,6 @@ #include "util/logging.h" namespace openscreen { -namespace platform { TlsDataRouterPosix::TlsDataRouterPosix( SocketHandleWaiter* waiter, @@ -21,30 +20,46 @@ TlsDataRouterPosix::~TlsDataRouterPosix() { } void TlsDataRouterPosix::RegisterConnection(TlsConnectionPosix* connection) { - // TODO(jophba, rwkeane): implement this method. - OSP_UNIMPLEMENTED(); + { + std::lock_guard<std::mutex> lock(connections_mutex_); + OSP_DCHECK(std::find(connections_.begin(), connections_.end(), + connection) == connections_.end()); + connections_.push_back(connection); + } + + waiter_->Subscribe(this, connection->socket_handle()); } void TlsDataRouterPosix::DeregisterConnection(TlsConnectionPosix* connection) { - // TODO(jophba, rwkeane): implement this method. - OSP_UNIMPLEMENTED(); + { + std::lock_guard<std::mutex> lock(connections_mutex_); + auto it = std::remove_if( + connections_.begin(), connections_.end(), + [connection](TlsConnectionPosix* conn) { return conn == connection; }); + if (it == connections_.end()) { + return; + } + connections_.erase(it, connections_.end()); + } + + waiter_->OnHandleDeletion(this, connection->socket_handle()); } -void TlsDataRouterPosix::RegisterSocketObserver( +void TlsDataRouterPosix::RegisterAcceptObserver( std::unique_ptr<StreamSocketPosix> socket, SocketObserver* observer) { OSP_DCHECK(observer); StreamSocketPosix* socket_ptr = socket.get(); { - std::unique_lock<std::mutex> lock(socket_mutex_); - watched_stream_sockets_.push_back(std::move(socket)); - socket_mappings_[socket_ptr] = observer; + std::unique_lock<std::mutex> lock(accept_socket_mutex_); + accept_stream_sockets_.push_back(std::move(socket)); + accept_socket_mappings_[socket_ptr] = observer; } waiter_->Subscribe(this, socket_ptr->socket_handle()); } -void TlsDataRouterPosix::DeregisterSocketObserver(StreamSocketPosix* socket) { +void TlsDataRouterPosix::DeregisterAcceptObserver(StreamSocketPosix* socket) { OnSocketDestroyed(socket, false); } @@ -56,8 +71,8 @@ void TlsDataRouterPosix::OnConnectionDestroyed(TlsConnectionPosix* connection) { void TlsDataRouterPosix::OnSocketDestroyed(StreamSocketPosix* socket, bool skip_locking_for_testing) { { - std::unique_lock<std::mutex> lock(socket_mutex_); - if (!socket_mappings_.erase(socket)) { + std::unique_lock<std::mutex> lock(accept_socket_mutex_); + if (!accept_socket_mappings_.erase(socket)) { return; } } @@ -66,74 +81,38 @@ void TlsDataRouterPosix::OnSocketDestroyed(StreamSocketPosix* socket, skip_locking_for_testing); { - std::unique_lock<std::mutex> lock(socket_mutex_); + std::unique_lock<std::mutex> lock(accept_socket_mutex_); auto it = std::find_if( - watched_stream_sockets_.begin(), watched_stream_sockets_.end(), + accept_stream_sockets_.begin(), accept_stream_sockets_.end(), [socket](const std::unique_ptr<StreamSocketPosix>& ptr) { return ptr.get() == socket; }); - OSP_DCHECK(it != watched_stream_sockets_.end()); - watched_stream_sockets_.erase(it); + OSP_DCHECK(it != accept_stream_sockets_.end()); + accept_stream_sockets_.erase(it); } } void TlsDataRouterPosix::ProcessReadyHandle( SocketHandleWaiter::SocketHandleRef handle) { - std::unique_lock<std::mutex> lock(socket_mutex_); - for (const auto& pair : socket_mappings_) { - if (pair.first->socket_handle() == handle) { - pair.second->OnConnectionPending(pair.first); - break; + { + std::unique_lock<std::mutex> lock(accept_socket_mutex_); + for (const auto& pair : accept_socket_mappings_) { + if (pair.first->socket_handle() == handle) { + pair.second->OnConnectionPending(pair.first); + return; + } } } -} - -void TlsDataRouterPosix::PerformNetworkingOperations(Clock::duration timeout) { - Clock::time_point start_time = now_function_(); - - // TODO(rwkeane): Minimize time locked based on how RegisterConnection and - // DeregisterConnection are implimented. - std::lock_guard<std::mutex> lock(connections_mutex_); - if (connections_.empty()) { - return; - } - - NetworkingOperation current_operation = last_operation_; - std::vector<TlsConnectionPosix*>::iterator current_connection = - GetConnection(last_connection_processed_); - last_connection_processed_ = *current_connection; - do { - // Get the next (connection, mode) pair to use for processing. This allows - // for processing in a round-robin fashion, such that this call to - // PerformNetworkingOperations() will pick up where the previous call left - // off. - current_operation = GetNextMode(current_operation); - if (current_operation == NetworkingOperation::kReading) { - current_connection++; - if (current_connection == connections_.end()) { - current_connection = connections_.begin(); + { + std::lock_guard<std::mutex> lock(connections_mutex_); + for (TlsConnectionPosix* connection : connections_) { + if (connection->socket_handle() == handle) { + connection->TryReceiveMessage(); + connection->SendAvailableBytes(); + return; } } - - // Process the (connection, mode). - switch (current_operation) { - case NetworkingOperation::kReading: - (*current_connection)->TryReceiveMessage(); - break; - case NetworkingOperation::kWriting: - (*current_connection)->SendAvailableBytes(); - break; - } - - // If this (connection, mode) is where we started, exit. - if (last_connection_processed_ == *current_connection && - last_operation_ == current_operation) { - break; - } - } while (!HasTimedOut(start_time, timeout)); - - last_connection_processed_ = *current_connection; - last_operation_ = current_operation; + } } bool TlsDataRouterPosix::HasTimedOut(Clock::time_point start_time, @@ -142,34 +121,16 @@ bool TlsDataRouterPosix::HasTimedOut(Clock::time_point start_time, } void TlsDataRouterPosix::RemoveWatchedSocket(StreamSocketPosix* socket) { - std::unique_lock<std::mutex> lock(socket_mutex_); - const auto it = socket_mappings_.find(socket); - if (it != socket_mappings_.end()) { - socket_mappings_.erase(it); + std::unique_lock<std::mutex> lock(accept_socket_mutex_); + const auto it = accept_socket_mappings_.find(socket); + if (it != accept_socket_mappings_.end()) { + accept_socket_mappings_.erase(it); } } bool TlsDataRouterPosix::IsSocketWatched(StreamSocketPosix* socket) const { - std::unique_lock<std::mutex> lock(socket_mutex_); - return socket_mappings_.find(socket) != socket_mappings_.end(); -} - -TlsDataRouterPosix::NetworkingOperation TlsDataRouterPosix::GetNextMode( - NetworkingOperation state) { - return state == NetworkingOperation::kReading ? NetworkingOperation::kWriting - : NetworkingOperation::kReading; -} - -std::vector<TlsConnectionPosix*>::iterator TlsDataRouterPosix::GetConnection( - TlsConnectionPosix* current) { - auto current_pos = - std::find(connections_.begin(), connections_.end(), current); - if (current_pos == connections_.end()) { - return connections_.begin(); - } - - return current_pos; + std::unique_lock<std::mutex> lock(accept_socket_mutex_); + return accept_socket_mappings_.find(socket) != accept_socket_mappings_.end(); } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.h b/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.h index fb8ac45c0d9..2549bd2e5e2 100644 --- a/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.h +++ b/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.h @@ -14,7 +14,6 @@ #include "util/logging.h" namespace openscreen { -namespace platform { class StreamSocketPosix; class TlsConnectionPosix; @@ -57,24 +56,19 @@ class TlsDataRouterPosix : public SocketHandleWaiter::Subscriber { // Deregister a TlsConnection. void DeregisterConnection(TlsConnectionPosix* connection); - // Takes ownership of a StreamSocket and registers that should be watched for - // incoming Tcp Connections with the SocketHandleWaiter. - void RegisterSocketObserver(std::unique_ptr<StreamSocketPosix> socket, + // Takes ownership of a StreamSocket and registers that it should be watched + // for incoming TCP connections with the SocketHandleWaiter. + void RegisterAcceptObserver(std::unique_ptr<StreamSocketPosix> socket, SocketObserver* observer); - // Stops watching a Tcp Connections for incoming connections. + // Stops watching a TCP socket for incoming connections. // NOTE: This will destroy the StreamSocket. - virtual void DeregisterSocketObserver(StreamSocketPosix* socket); + virtual void DeregisterAcceptObserver(StreamSocketPosix* socket); // Method to be executed on TlsConnection destruction. This is expected to // block until the networking thread is not using the provided connection. void OnConnectionDestroyed(TlsConnectionPosix* connection); - // Performs Read and Write operations for all connections or until the timeout - // has been hit, whichever is first. In the latter case, the following - // iteration will continue from wherever the previous iteration left off. - void PerformNetworkingOperations(Clock::duration timeout); - // SocketHandleWaiter::Subscriber overrides. void ProcessReadyHandle(SocketHandleWaiter::SocketHandleRef handle) override; @@ -91,48 +85,35 @@ class TlsDataRouterPosix : public SocketHandleWaiter::Subscriber { friend class TestingDataRouter; private: - enum class NetworkingOperation { kReading, kWriting }; - void OnSocketDestroyed(StreamSocketPosix* socket, bool skip_locking_for_testing); void RemoveWatchedSocket(StreamSocketPosix* socket); - // Helper methods for PerformNetworkingOperations. - NetworkingOperation GetNextMode(NetworkingOperation state); - std::vector<TlsConnectionPosix*>::iterator GetConnection( - TlsConnectionPosix* current); - SocketHandleWaiter* waiter_; // Mutex guarding connections_ vector. mutable std::mutex connections_mutex_; - // Mutex guarding socket_mappings_. - mutable std::mutex socket_mutex_; + // Mutex guarding |accept_socket_mappings_|. + mutable std::mutex accept_socket_mutex_; // Function to get the current time. std::function<Clock::time_point()> now_function_; - // Information related to how much of PerformNetworkingOperations(...) was - // completed before hitting the timeout. - NetworkingOperation last_operation_ = NetworkingOperation::kReading; - TlsConnectionPosix* last_connection_processed_ = nullptr; - // Mapping from all sockets to the observer that should be called when the // socket recognizes an incoming connection. - std::unordered_map<StreamSocketPosix*, SocketObserver*> socket_mappings_ - GUARDED_BY(socket_mutex_); + std::unordered_map<StreamSocketPosix*, SocketObserver*> + accept_socket_mappings_ GUARDED_BY(accept_socket_mutex_); // Set of all TlsConnectionPosix objects currently registered. std::vector<TlsConnectionPosix*> connections_ GUARDED_BY(connections_mutex_); // StreamSockets currently owned by this object, being watched for - std::vector<std::unique_ptr<StreamSocketPosix>> watched_stream_sockets_ - GUARDED_BY(connections_mutex_); + std::vector<std::unique_ptr<StreamSocketPosix>> accept_stream_sockets_ + GUARDED_BY(accept_socket_mutex_); }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_TLS_DATA_ROUTER_POSIX_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix_unittest.cc index 257ad08b3c8..f242900928e 100644 --- a/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix_unittest.cc @@ -4,6 +4,9 @@ #include "platform/impl/tls_data_router_posix.h" +#include <memory> +#include <utility> + #include "gmock/gmock.h" #include "gtest/gtest.h" #include "platform/base/ip_address.h" @@ -13,48 +16,51 @@ #include "platform/test/fake_task_runner.h" namespace openscreen { -namespace platform { +namespace { -using testing::_; -using testing::Return; +class MockNetworkWaiter final : public SocketHandleWaiter { + public: + MockNetworkWaiter() : SocketHandleWaiter(&FakeClock::now) {} -class TestingDataRouter : public TlsDataRouterPosix { + MOCK_METHOD2( + AwaitSocketsReadable, + ErrorOr<std::vector<SocketHandleRef>>(const std::vector<SocketHandleRef>&, + const Clock::duration&)); +}; + +class MockSocket : public StreamSocketPosix { public: - TestingDataRouter(SocketHandleWaiter* waiter) : TlsDataRouterPosix(waiter) {} + MockSocket(int fd) : StreamSocketPosix(IPAddress::Version::kV4), handle(fd) {} - using TlsDataRouterPosix::IsSocketWatched; - using TlsDataRouterPosix::NetworkingOperation; + const SocketHandle& socket_handle() const override { return handle; } - void DeregisterSocketObserver(StreamSocketPosix* socket) override { - TlsDataRouterPosix::OnSocketDestroyed(socket, true); - } + SocketHandle handle; +}; - bool AnySocketsWatched() { - std::unique_lock<std::mutex> lock(socket_mutex_); - return !watched_stream_sockets_.empty() && !socket_mappings_.empty(); - } +class MockConnection : public TlsConnectionPosix { + public: + explicit MockConnection(int fd, TaskRunner* task_runner) + : TlsConnectionPosix(std::make_unique<MockSocket>(fd), task_runner) {} + MOCK_METHOD0(SendAvailableBytes, void()); + MOCK_METHOD0(TryReceiveMessage, void()); +}; - MOCK_METHOD2(HasTimedOut, bool(Clock::time_point, Clock::duration)); +} // namespace - void SetLastState(NetworkingOperation last_operation) { - last_operation_ = last_operation; - } +class TestingDataRouter : public TlsDataRouterPosix { + public: + explicit TestingDataRouter(SocketHandleWaiter* waiter) + : TlsDataRouterPosix(waiter) {} - void SetLastConnection(TlsConnectionPosix* last_connection) { - last_connection_processed_ = last_connection; - } + using TlsDataRouterPosix::IsSocketWatched; - // TODO(rwkeane): Remove these methods once RegisterConnection and - // DeregisterConnection are implimented in TlsDataRouterPosix. - void AddConnectionForTesting(TlsConnectionPosix* connection) { - std::lock_guard<std::mutex> lock(connections_mutex_); - connections_.push_back(connection); + void DeregisterAcceptObserver(StreamSocketPosix* socket) override { + TlsDataRouterPosix::OnSocketDestroyed(socket, true); } - void RemoveConnectionForTesting(TlsConnectionPosix* connection) { - std::lock_guard<std::mutex> lock(connections_mutex_); - auto it = std::find(connections_.begin(), connections_.end(), connection); - connections_.erase(it); + bool AnySocketsWatched() { + std::unique_lock<std::mutex> lock(accept_socket_mutex_); + return !accept_stream_sockets_.empty() && !accept_socket_mappings_.empty(); } }; @@ -62,135 +68,66 @@ class MockObserver : public TestingDataRouter::SocketObserver { MOCK_METHOD1(OnConnectionPending, void(StreamSocketPosix*)); }; -class MockNetworkWaiter final : public SocketHandleWaiter { +class TlsNetworkingManagerPosixTest : public testing::Test { public: - MOCK_METHOD2( - AwaitSocketsReadable, - ErrorOr<std::vector<SocketHandleRef>>(const std::vector<SocketHandleRef>&, - const Clock::duration&)); -}; + TlsNetworkingManagerPosixTest() + : clock_(Clock::now()), + task_runner_(&clock_), + network_manager_(&network_waiter_) {} -class MockConnection : public TlsConnectionPosix { - public: - MockConnection() - : TlsConnectionPosix(IPAddress::Version::kV4, &task_runner), - clock(Clock::now()), - task_runner(&clock) {} - MOCK_METHOD0(SendAvailableBytes, void()); - MOCK_METHOD0(TryReceiveMessage, void()); + FakeTaskRunner* task_runner() { return &task_runner_; } + TestingDataRouter* network_manager() { return &network_manager_; } private: - FakeClock clock; - FakeTaskRunner task_runner; + FakeClock clock_; + FakeTaskRunner task_runner_; + MockNetworkWaiter network_waiter_; + TestingDataRouter network_manager_; }; -TEST(TlsNetworkingManagerPosixTest, SocketsWatchedCorrectly) { - MockNetworkWaiter network_waiter; - TestingDataRouter network_manager(&network_waiter); +TEST_F(TlsNetworkingManagerPosixTest, SocketsWatchedCorrectly) { auto socket = std::make_unique<StreamSocketPosix>(IPAddress::Version::kV4); MockObserver observer; auto* ptr = socket.get(); - ASSERT_FALSE(network_manager.IsSocketWatched(ptr)); + ASSERT_FALSE(network_manager()->IsSocketWatched(ptr)); - network_manager.RegisterSocketObserver(std::move(socket), &observer); - ASSERT_TRUE(network_manager.IsSocketWatched(ptr)); - ASSERT_TRUE(network_manager.AnySocketsWatched()); + network_manager()->RegisterAcceptObserver(std::move(socket), &observer); + ASSERT_TRUE(network_manager()->IsSocketWatched(ptr)); + ASSERT_TRUE(network_manager()->AnySocketsWatched()); - network_manager.DeregisterSocketObserver(ptr); - ASSERT_FALSE(network_manager.IsSocketWatched(ptr)); - ASSERT_FALSE(network_manager.AnySocketsWatched()); + network_manager()->DeregisterAcceptObserver(ptr); + ASSERT_FALSE(network_manager()->IsSocketWatched(ptr)); + ASSERT_FALSE(network_manager()->AnySocketsWatched()); - network_manager.DeregisterSocketObserver(ptr); - ASSERT_FALSE(network_manager.IsSocketWatched(ptr)); - ASSERT_FALSE(network_manager.AnySocketsWatched()); + network_manager()->DeregisterAcceptObserver(ptr); + ASSERT_FALSE(network_manager()->IsSocketWatched(ptr)); + ASSERT_FALSE(network_manager()->AnySocketsWatched()); } -TEST(TlsNetworkingManagerPosixTest, ExitsAfterOneCall) { - MockNetworkWaiter network_waiter; - TestingDataRouter network_manager(&network_waiter); - MockConnection connection1; - MockConnection connection2; - MockConnection connection3; - network_manager.AddConnectionForTesting(&connection1); - network_manager.AddConnectionForTesting(&connection2); - network_manager.AddConnectionForTesting(&connection3); - - EXPECT_CALL(network_manager, HasTimedOut(_, _)).WillOnce(Return(true)); +TEST_F(TlsNetworkingManagerPosixTest, CallsReadySocket) { + MockConnection connection1(1, task_runner()); + MockConnection connection2(2, task_runner()); + MockConnection connection3(3, task_runner()); + network_manager()->RegisterConnection(&connection1); + network_manager()->RegisterConnection(&connection2); + network_manager()->RegisterConnection(&connection3); + EXPECT_CALL(connection1, SendAvailableBytes()).Times(1); - EXPECT_CALL(connection1, TryReceiveMessage()).Times(0); + EXPECT_CALL(connection1, TryReceiveMessage()).Times(1); EXPECT_CALL(connection2, SendAvailableBytes()).Times(0); EXPECT_CALL(connection2, TryReceiveMessage()).Times(0); EXPECT_CALL(connection3, SendAvailableBytes()).Times(0); EXPECT_CALL(connection3, TryReceiveMessage()).Times(0); - network_manager.PerformNetworkingOperations(Clock::duration{0}); -} + network_manager()->ProcessReadyHandle(connection1.socket_handle()); -TEST(TlsNetworkingManagerPosixTest, StartsAfterPrevious) { - MockNetworkWaiter network_waiter; - TestingDataRouter network_manager(&network_waiter); - MockConnection connection1; - MockConnection connection2; - MockConnection connection3; - network_manager.AddConnectionForTesting(&connection1); - network_manager.AddConnectionForTesting(&connection2); - network_manager.AddConnectionForTesting(&connection3); - network_manager.SetLastState( - TestingDataRouter::NetworkingOperation::kReading); - network_manager.SetLastConnection(&connection2); - - EXPECT_CALL(network_manager, HasTimedOut(_, _)).WillOnce(Return(true)); - EXPECT_CALL(connection1, TryReceiveMessage()).Times(0); EXPECT_CALL(connection1, SendAvailableBytes()).Times(0); - EXPECT_CALL(connection2, TryReceiveMessage()).Times(0); - EXPECT_CALL(connection2, SendAvailableBytes()).Times(1); - EXPECT_CALL(connection3, TryReceiveMessage()).Times(0); - EXPECT_CALL(connection3, SendAvailableBytes()).Times(0); - network_manager.PerformNetworkingOperations(Clock::duration{0}); -} - -TEST(TlsNetworkingManagerPosixTest, HitsAllCallsOnce) { - MockNetworkWaiter network_waiter; - TestingDataRouter network_manager(&network_waiter); - MockConnection connection1; - MockConnection connection2; - MockConnection connection3; - network_manager.AddConnectionForTesting(&connection1); - network_manager.AddConnectionForTesting(&connection2); - network_manager.AddConnectionForTesting(&connection3); - - EXPECT_CALL(network_manager, HasTimedOut(_, _)).WillRepeatedly(Return(false)); - EXPECT_CALL(connection1, TryReceiveMessage()).Times(1); - EXPECT_CALL(connection1, SendAvailableBytes()).Times(1); - EXPECT_CALL(connection2, TryReceiveMessage()).Times(1); + EXPECT_CALL(connection1, TryReceiveMessage()).Times(0); EXPECT_CALL(connection2, SendAvailableBytes()).Times(1); - EXPECT_CALL(connection3, TryReceiveMessage()).Times(1); - EXPECT_CALL(connection3, SendAvailableBytes()).Times(1); - network_manager.PerformNetworkingOperations(Clock::duration{0}); -} - -TEST(TlsNetworkingManagerPosixTest, HitsAllCallsOnceStartedInMiddle) { - MockNetworkWaiter network_waiter; - TestingDataRouter network_manager(&network_waiter); - MockConnection connection1; - MockConnection connection2; - MockConnection connection3; - network_manager.AddConnectionForTesting(&connection1); - network_manager.AddConnectionForTesting(&connection2); - network_manager.AddConnectionForTesting(&connection3); - network_manager.SetLastState( - TestingDataRouter::NetworkingOperation::kReading); - network_manager.SetLastConnection(&connection2); - - EXPECT_CALL(network_manager, HasTimedOut(_, _)).WillRepeatedly(Return(false)); - EXPECT_CALL(connection1, TryReceiveMessage()).Times(1); - EXPECT_CALL(connection1, SendAvailableBytes()).Times(1); EXPECT_CALL(connection2, TryReceiveMessage()).Times(1); - EXPECT_CALL(connection2, SendAvailableBytes()).Times(1); - EXPECT_CALL(connection3, TryReceiveMessage()).Times(1); - EXPECT_CALL(connection3, SendAvailableBytes()).Times(1); - network_manager.PerformNetworkingOperations(Clock::duration{0}); + EXPECT_CALL(connection3, SendAvailableBytes()).Times(0); + EXPECT_CALL(connection3, TryReceiveMessage()).Times(0); + network_manager()->ProcessReadyHandle(connection2.socket_handle()); } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.cc b/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.cc index f0370868906..b91891c36a9 100644 --- a/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.cc +++ b/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.cc @@ -11,16 +11,11 @@ #include "util/logging.h" namespace openscreen { -namespace platform { +TlsWriteBuffer::TlsWriteBuffer() = default; TlsWriteBuffer::~TlsWriteBuffer() = default; -TlsWriteBuffer::TlsWriteBuffer(TlsWriteBuffer::Observer* observer) - : observer_(observer) { - OSP_DCHECK(observer_); -} - -size_t TlsWriteBuffer::Write(const void* data, size_t len) { +bool TlsWriteBuffer::Push(const void* data, size_t len) { const size_t currently_written_bytes = bytes_written_so_far_.load(std::memory_order_relaxed); const size_t current_read_bytes = @@ -29,34 +24,32 @@ size_t TlsWriteBuffer::Write(const void* data, size_t len) { // Calculates the current size of the buffer. const size_t bytes_currently_used = currently_written_bytes - current_read_bytes; + OSP_DCHECK_LE(bytes_currently_used, kBufferSizeBytes); + if ((kBufferSizeBytes - bytes_currently_used) < len) { + return false; + } - // Calculates how many bytes out of the requested |len| can be written without - // causing a buffer overflow. - const size_t write_len = - std::min(kBufferSizeBytes - bytes_currently_used, len); - - // Calculates the number of bytes out of |write_len| to write in the first - // memcpy operation, which is either all |write_len| or the number that can be - // written before wrapping around to the beginning of the underlying array. + // Calculates the number of bytes out of |len| to write in the first memcpy + // operation, which is either all of |len| or the number that can be written + // before wrapping around to the beginning of the underlying array. const size_t current_write_index = currently_written_bytes % kBufferSizeBytes; const size_t first_write_len = - std::min(write_len, kBufferSizeBytes - current_write_index); + std::min(len, kBufferSizeBytes - current_write_index); memcpy(&buffer_[current_write_index], data, first_write_len); - // If we didn't write all |write_len| bytes in the previous memcpy, copy any + // If fewer than |len| bytes were transferred in the previous memcpy, copy any // remaining bytes to the array, starting at 0 (since the last write must have // finished at the end of the array). - if (first_write_len != write_len) { + if (first_write_len != len) { const uint8_t* new_start = static_cast<const uint8_t*>(data) + first_write_len; - memcpy(buffer_, new_start, write_len - first_write_len); + memcpy(buffer_, new_start, len - first_write_len); } // Store and return updated values. - const size_t new_write_size = currently_written_bytes + write_len; + const size_t new_write_size = currently_written_bytes + len; bytes_written_so_far_.store(new_write_size, std::memory_order_release); - NotifyWriteBufferFill(new_write_size, bytes_read_so_far_); - return write_len; + return true; } absl::Span<const uint8_t> TlsWriteBuffer::GetReadableRegion() { @@ -72,7 +65,7 @@ absl::Span<const uint8_t> TlsWriteBuffer::GetReadableRegion() { // this additional level of complexity. const size_t avail = currently_written_bytes - current_read_bytes; const size_t begin = current_read_bytes % kBufferSizeBytes; - const size_t end = std::min(begin + avail, kBufferSizeBytes - 1); + const size_t end = std::min(begin + avail, kBufferSizeBytes); return absl::Span<const uint8_t>(&buffer_[begin], end - begin); } @@ -85,16 +78,9 @@ void TlsWriteBuffer::Consume(size_t byte_count) { OSP_DCHECK_GE(currently_written_bytes - current_read_bytes, byte_count); const size_t new_read_index = current_read_bytes + byte_count; bytes_read_so_far_.store(new_read_index, std::memory_order_release); - - NotifyWriteBufferFill(currently_written_bytes, new_read_index); } -void TlsWriteBuffer::NotifyWriteBufferFill(size_t write_index, - size_t read_index) { - double fraction = - static_cast<double>(write_index - read_index) / kBufferSizeBytes; - observer_->NotifyWriteBufferFill(fraction); -} +// static +constexpr size_t TlsWriteBuffer::kBufferSizeBytes; -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.h b/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.h index 308b882dfc8..99ac45ac788 100644 --- a/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.h +++ b/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.h @@ -12,31 +12,20 @@ #include "platform/base/macros.h" namespace openscreen { -namespace platform { // This class is responsible for buffering TLS Write data. The approach taken by // this class is to allow for a single thread to act as a publisher of data and // for a separate thread to act as the consumer of that data. The data in -// question is written to a lockless FIFO queue. Whenever the capacity of the -// underlying array changes, Observer::NotifyWriteBufferFill() will be called. +// question is written to a lockless FIFO queue. class TlsWriteBuffer { public: - class Observer { - public: - virtual ~Observer() = default; - - // Signals that the write buffer has reached some percentage of being - // filled. NOTE: This method may be called from multiple threads. - virtual void NotifyWriteBufferFill(double fraction) = 0; - }; - - explicit TlsWriteBuffer(Observer* observer); - + TlsWriteBuffer(); ~TlsWriteBuffer(); - // Writes as much of the provided data as possible in the buffer, returning - // the number of bytes written. - size_t Write(const void* data, size_t len); + // Pushes the provided data into the buffer, returning true if successful. + // Returns false if there was insufficient space left. Either all or none of + // the data is pushed into the buffer. + bool Push(const void* data, size_t len); // Returns a subset of the readable region of data. At time of reading, more // data may be available for reading than what is represented in this Span. @@ -46,13 +35,9 @@ class TlsWriteBuffer { void Consume(size_t byte_count); // The amount of space to allocate in the buffer. - static constexpr size_t kBufferSizeBytes = 1 << 20; // 1 MB space. + static constexpr size_t kBufferSizeBytes = 1 << 19; // 0.5 MB. private: - // Signals that the write buffer has reached some percentage of being filled, - // as calculated based on the provided write and read indices. - void NotifyWriteBufferFill(size_t write_index, size_t read_index); - // Buffer where data to be written over the TLS connection is stored. uint8_t buffer_[kBufferSizeBytes]; @@ -63,12 +48,9 @@ class TlsWriteBuffer { std::atomic_size_t bytes_read_so_far_{0}; std::atomic_size_t bytes_written_so_far_{0}; - Observer* const observer_; - OSP_DISALLOW_COPY_AND_ASSIGN(TlsWriteBuffer); }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_TLS_WRITE_BUFFER_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer_unittest.cc index 9e81543336e..57326869f72 100644 --- a/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer_unittest.cc @@ -10,75 +10,108 @@ #include "gtest/gtest.h" namespace openscreen { -namespace platform { namespace { -class MockObserver : public TlsWriteBuffer::Observer { - public: - MOCK_METHOD1(NotifyWriteBufferFill, void(double)); -}; - -} // namespace - TEST(TlsWriteBufferTest, CheckBasicFunctionality) { - MockObserver observer; - TlsWriteBuffer buffer(&observer); + TlsWriteBuffer buffer; constexpr size_t write_size = TlsWriteBuffer::kBufferSizeBytes / 2; uint8_t write_buffer[write_size]; std::fill_n(write_buffer, write_size, uint8_t{1}); - EXPECT_CALL(observer, NotifyWriteBufferFill(0.5)).Times(1); - EXPECT_EQ(buffer.Write(write_buffer, write_size), write_size); + EXPECT_TRUE(buffer.Push(write_buffer, write_size)); absl::Span<const uint8_t> readable_data = buffer.GetReadableRegion(); - EXPECT_EQ(readable_data.size(), write_size); + ASSERT_EQ(readable_data.size(), write_size); for (size_t i = 0; i < readable_data.size(); i++) { EXPECT_EQ(readable_data[i], 1); } - EXPECT_CALL(observer, NotifyWriteBufferFill(0.25)).Times(1); buffer.Consume(write_size / 2); readable_data = buffer.GetReadableRegion(); - EXPECT_EQ(readable_data.size(), write_size / 2); + ASSERT_EQ(readable_data.size(), write_size / 2); for (size_t i = 0; i < readable_data.size(); i++) { EXPECT_EQ(readable_data[i], 1); } - EXPECT_CALL(observer, NotifyWriteBufferFill(0)).Times(1); buffer.Consume(write_size / 2); readable_data = buffer.GetReadableRegion(); - EXPECT_EQ(readable_data.size(), size_t{0}); + ASSERT_EQ(readable_data.size(), size_t{0}); + + // Test that the entire buffer can be used. + EXPECT_TRUE(buffer.Push(write_buffer, write_size)); + EXPECT_TRUE(buffer.Push(write_buffer, write_size)); + // The buffer should be 100% full at this point. Confirm that no more can be + // written. + EXPECT_FALSE(buffer.Push(write_buffer, write_size)); + EXPECT_FALSE(buffer.Push(write_buffer, 1)); } TEST(TlsWriteBufferTest, TestWrapAround) { - MockObserver observer; - TlsWriteBuffer buffer(&observer); - constexpr size_t write_size = TlsWriteBuffer::kBufferSizeBytes; - uint8_t write_buffer[write_size]; - - EXPECT_CALL(observer, NotifyWriteBufferFill(0.75)).Times(1); - constexpr size_t partial_write_size = write_size * 3 / 4; - EXPECT_EQ(buffer.Write(write_buffer, partial_write_size), partial_write_size); - - EXPECT_CALL(observer, NotifyWriteBufferFill(0.25)).Times(1); - buffer.Consume(write_size / 2); - - EXPECT_CALL(observer, NotifyWriteBufferFill(0.75)).Times(1); - EXPECT_EQ(buffer.Write(write_buffer, write_size / 2), write_size / 2); - - absl::Span<const uint8_t> readable_data = buffer.GetReadableRegion(); - - EXPECT_CALL(observer, NotifyWriteBufferFill(0.25)).Times(1); - buffer.Consume(write_size / 2); - - readable_data = buffer.GetReadableRegion(); - EXPECT_EQ(readable_data.size(), write_size / 4); - - EXPECT_CALL(observer, NotifyWriteBufferFill(0)).Times(1); - buffer.Consume(write_size / 4); + TlsWriteBuffer buffer; + constexpr size_t buffer_size = TlsWriteBuffer::kBufferSizeBytes; + uint8_t write_buffer[buffer_size]; + std::fill_n(write_buffer, buffer_size, uint8_t{1}); + + constexpr size_t partial_buffer_size = buffer_size * 3 / 4; + EXPECT_TRUE(buffer.Push(write_buffer, partial_buffer_size)); + // Buffer contents should now be: |111111111111····| + auto region = buffer.GetReadableRegion(); + auto* const buffer_begin = region.data(); + ASSERT_TRUE(buffer_begin); + EXPECT_EQ(region.size(), partial_buffer_size); + EXPECT_TRUE(std::all_of(region.begin(), region.end(), + [](uint8_t byte) { return byte == 1; })); + + buffer.Consume(buffer_size / 2); + // Buffer contents should now be: |········1111····| + region = buffer.GetReadableRegion(); + EXPECT_EQ(region.data(), buffer_begin + buffer_size / 2); + EXPECT_EQ(region.size(), buffer_size / 4); + EXPECT_TRUE(std::all_of(region.begin(), region.end(), + [](uint8_t byte) { return byte == 1; })); + + std::fill_n(write_buffer, buffer_size, uint8_t{2}); + EXPECT_TRUE(buffer.Push(write_buffer, buffer_size / 2)); + // Buffer contents should now be: |2222····11112222| + // Readable region should just be the end part. + region = buffer.GetReadableRegion(); + EXPECT_EQ(region.data(), buffer_begin + buffer_size / 2); + EXPECT_EQ(region.size(), buffer_size / 2); + EXPECT_TRUE(std::all_of(region.begin(), region.begin() + buffer_size / 4, + [](uint8_t byte) { return byte == 1; })); + EXPECT_TRUE(std::all_of(region.begin() + buffer_size / 4, region.end(), + [](uint8_t byte) { return byte == 2; })); + + buffer.Consume(buffer_size / 2); + // Buffer contents should now be: |2222············| + region = buffer.GetReadableRegion(); + EXPECT_EQ(region.data(), buffer_begin); + EXPECT_EQ(region.size(), buffer_size / 4); + EXPECT_TRUE(std::all_of(region.begin(), region.end(), + [](uint8_t byte) { return byte == 2; })); + + std::fill_n(write_buffer, buffer_size, uint8_t{3}); + // The following Push() fails (not enough room). + EXPECT_FALSE(buffer.Push(write_buffer, buffer_size)); + // Buffer contents should still be: |2222············| + EXPECT_TRUE(buffer.Push(write_buffer, buffer_size * 3 / 4)); + // Buffer contents should now be: |2222333333333333| + EXPECT_FALSE(buffer.Push(write_buffer, buffer_size)); // Not enough room. + EXPECT_FALSE(buffer.Push(write_buffer, 1)); // Not enough room. + region = buffer.GetReadableRegion(); + EXPECT_EQ(region.data(), buffer_begin); + EXPECT_EQ(region.size(), buffer_size); + EXPECT_TRUE(std::all_of(region.begin(), region.begin() + buffer_size / 4, + [](uint8_t byte) { return byte == 2; })); + EXPECT_TRUE(std::all_of(region.begin() + buffer_size / 4, region.end(), + [](uint8_t byte) { return byte == 3; })); + + buffer.Consume(buffer_size); + // Buffer contents should now be: |················| + EXPECT_TRUE(buffer.GetReadableRegion().empty()); } -} // namespace platform +} // namespace } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.cc b/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.cc index 927ab3158e6..1be7c79aa08 100644 --- a/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.cc +++ b/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.cc @@ -18,6 +18,7 @@ #include <sstream> #include <string> #include <type_traits> +#include <utility> #include "absl/types/optional.h" #include "platform/api/task_runner.h" @@ -26,7 +27,6 @@ #include "util/logging.h" namespace openscreen { -namespace platform { namespace { constexpr bool IsPowerOf2(uint32_t x) { @@ -84,9 +84,10 @@ const SocketHandle& UdpSocketPosix::GetHandle() const { } // static -ErrorOr<UdpSocketUniquePtr> UdpSocket::Create(TaskRunner* task_runner, - Client* client, - const IPEndpoint& endpoint) { +ErrorOr<std::unique_ptr<UdpSocket>> UdpSocket::Create( + TaskRunner* task_runner, + Client* client, + const IPEndpoint& endpoint) { static std::atomic_bool in_create{false}; const bool in_create_local = in_create.exchange(true); OSP_DCHECK_EQ(in_create_local, false) @@ -112,8 +113,8 @@ ErrorOr<UdpSocketUniquePtr> UdpSocket::Create(TaskRunner* task_runner, return fd.error(); } - auto socket = UdpSocketUniquePtr(static_cast<UdpSocket*>(new UdpSocketPosix( - task_runner, client, SocketHandle(fd.value()), endpoint))); + std::unique_ptr<UdpSocket> socket = std::make_unique<UdpSocketPosix>( + task_runner, client, SocketHandle(fd.value()), endpoint); in_create = false; return socket; } @@ -346,11 +347,11 @@ uint16_t GetPortFromFromSockAddr(const sockaddr_in& sa) { } IPAddress GetIPAddressFromSockAddr(const sockaddr_in6& sa) { - return IPAddress(sa.sin6_addr.s6_addr); + return IPAddress(IPAddress::Version::kV6, sa.sin6_addr.s6_addr); } IPAddress GetIPAddressFromPktInfo(const in6_pktinfo& pktinfo) { - return IPAddress(pktinfo.ipi6_addr.s6_addr); + return IPAddress(IPAddress::Version::kV6, pktinfo.ipi6_addr.s6_addr); } uint16_t GetPortFromFromSockAddr(const sockaddr_in6& sa) { @@ -618,5 +619,4 @@ void UdpSocketPosix::Close() { handle_.fd = -1; } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.h b/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.h index ba44d61ab21..6dc7ae538ce 100644 --- a/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.h +++ b/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.h @@ -10,10 +10,9 @@ #include "platform/base/macros.h" #include "platform/impl/platform_client_posix.h" #include "platform/impl/socket_handle_posix.h" -#include "platform/impl/weak_ptr.h" +#include "util/weak_ptr.h" namespace openscreen { -namespace platform { class UdpSocketReaderPosix; @@ -87,7 +86,6 @@ class UdpSocketPosix : public UdpSocket { OSP_DISALLOW_COPY_AND_ASSIGN(UdpSocketPosix); }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_UDP_SOCKET_POSIX_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.cc b/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.cc index 3892081b5bd..4667df883c9 100644 --- a/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.cc +++ b/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.cc @@ -12,7 +12,6 @@ #include "util/logging.h" namespace openscreen { -namespace platform { UdpSocketReaderPosix::UdpSocketReaderPosix(SocketHandleWaiter* waiter) : waiter_(waiter) {} @@ -34,9 +33,11 @@ void UdpSocketReaderPosix::ProcessReadyHandle(SocketHandleRef handle) { } void UdpSocketReaderPosix::OnCreate(UdpSocket* socket) { - std::lock_guard<std::mutex> lock(mutex_); UdpSocketPosix* read_socket = static_cast<UdpSocketPosix*>(socket); - sockets_.push_back(read_socket); + { + std::lock_guard<std::mutex> lock(mutex_); + sockets_.push_back(read_socket); + } waiter_->Subscribe(this, std::cref(read_socket->GetHandle())); } @@ -64,5 +65,4 @@ bool UdpSocketReaderPosix::IsMappedReadForTesting( return std::find(sockets_.begin(), sockets_.end(), socket) != sockets_.end(); } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.h b/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.h index 5b1f615b6fb..d9e065d9526 100644 --- a/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.h +++ b/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.h @@ -16,7 +16,6 @@ #include "platform/impl/udp_socket_posix.h" namespace openscreen { -namespace platform { // This is the class responsible for watching sockets for readable data, then // calling the function associated with these sockets once that data is read. @@ -72,7 +71,6 @@ class UdpSocketReaderPosix : public SocketHandleWaiter::Subscriber { friend class TestingUdpSocketReader; }; -} // namespace platform } // namespace openscreen #endif // PLATFORM_IMPL_UDP_SOCKET_READER_POSIX_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix_unittest.cc index a09fe29edf1..7faec354b7c 100644 --- a/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix_unittest.cc @@ -14,7 +14,7 @@ #include "platform/test/fake_udp_socket.h" namespace openscreen { -namespace platform { +namespace { using namespace ::testing; using ::testing::_; @@ -48,10 +48,15 @@ class MockUdpSocketPosix : public UdpSocketPosix { // Mock event waiter class MockNetworkWaiter final : public SocketHandleWaiter { public: + MockNetworkWaiter() : SocketHandleWaiter(&FakeClock::now) {} + ~MockNetworkWaiter() override = default; + MOCK_METHOD2( AwaitSocketsReadable, ErrorOr<std::vector<SocketHandleRef>>(const std::vector<SocketHandleRef>&, const Clock::duration&)); + + FakeClock fake_clock{Clock::time_point{Clock::duration{1234567}}}; }; // Mock Task Runner @@ -78,6 +83,8 @@ class MockTaskRunner final : public TaskRunner { uint32_t delayed_tasks_posted; }; +} // namespace + // Class extending NetworkWaiter to allow for looking at protected data. class TestingUdpSocketReader final : public UdpSocketReaderPosix { public: @@ -129,5 +136,4 @@ TEST(UdpSocketReaderTest, UnwatchReadableSucceeds) { EXPECT_FALSE(network_waiter.IsMappedRead(socket.get())); } -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/testing/libfuzzer/BUILD.gn b/chromium/third_party/openscreen/src/testing/libfuzzer/BUILD.gn new file mode 100644 index 00000000000..9c28a2a86f9 --- /dev/null +++ b/chromium/third_party/openscreen/src/testing/libfuzzer/BUILD.gn @@ -0,0 +1,29 @@ +# Copyright 2019 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +# LibFuzzer is a LLVM tool for coverage-guided fuzz testing. +# See http://www.chromium.org/developers/testing/libfuzzer + +import("//build_overrides/build.gni") + +source_set("fuzzing_engine_main") { + deps = [ + "//third_party/libfuzzer", + ] + sources = [] +} + +# A config used by all fuzzer_tests. +config("fuzzer_test_config") { + if (is_mac) { + ldflags = [ + "-Wl,-U,_LLVMFuzzerCustomMutator", + "-Wl,-U,_LLVMFuzzerInitialize", + ] + } +} + +# noop to tag seed corpus rules. +source_set("seed_corpus") { +} diff --git a/chromium/third_party/openscreen/src/testing/libfuzzer/archive_corpus.py b/chromium/third_party/openscreen/src/testing/libfuzzer/archive_corpus.py new file mode 100755 index 00000000000..e80848ddfa2 --- /dev/null +++ b/chromium/third_party/openscreen/src/testing/libfuzzer/archive_corpus.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python2 +# +# Copyright 2019 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +"""Archive corpus file into zip and generate .d depfile. + +Invoked by GN from fuzzer_test.gni. +""" + +from __future__ import print_function +import argparse +import os +import sys +import warnings +import zipfile + +SEED_CORPUS_LIMIT_MB = 100 + + +def main(): + parser = argparse.ArgumentParser(description="Generate fuzzer config.") + parser.add_argument('corpus_directories', metavar='corpus_dir', type=str, + nargs='+') + parser.add_argument('--output', metavar='output_archive_name.zip', + required=True) + + args = parser.parse_args() + corpus_files = [] + seed_corpus_path = args.output + + for directory in args.corpus_directories: + if not os.path.exists(directory): + raise Exception('The given seed_corpus directory (%s) does not exist.' % + directory) + for (dirpath, _, filenames) in os.walk(directory): + for filename in filenames: + full_filename = os.path.join(dirpath, filename) + corpus_files.append(full_filename) + + with zipfile.ZipFile(seed_corpus_path, 'w') as z: + # Turn warnings into errors to interrupt the build: crbug.com/653920. + with warnings.catch_warnings(): + warnings.simplefilter("error") + for i, corpus_file in enumerate(corpus_files): + # To avoid duplication of filenames inside the archive, use numbers. + arcname = '%016d' % i + z.write(corpus_file, arcname) + + if os.path.getsize(seed_corpus_path) > SEED_CORPUS_LIMIT_MB * 1024 * 1024: + print('Seed corpus %s exceeds maximum allowed size (%d MB).' % + (seed_corpus_path, SEED_CORPUS_LIMIT_MB)) + sys.exit(-1) + + +if __name__ == '__main__': + main() diff --git a/chromium/third_party/openscreen/src/testing/libfuzzer/fuzzer_test.gni b/chromium/third_party/openscreen/src/testing/libfuzzer/fuzzer_test.gni new file mode 100644 index 00000000000..4de38ac46e7 --- /dev/null +++ b/chromium/third_party/openscreen/src/testing/libfuzzer/fuzzer_test.gni @@ -0,0 +1,193 @@ +# Copyright 2015 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +# Defines fuzzer_test. +# +import("//build_overrides/build.gni") + +# fuzzer_test is used to define individual libfuzzer tests. +# +# Supported attributes: +# - (required) sources - fuzzer test source files +# - deps - test dependencies +# - libs - Additional libraries to link. +# - additional_configs - additional configs to be used for compilation +# - dict - a dictionary file for the fuzzer. +# - environment_variables - certain whitelisted environment variables for the +# fuzzer (AFL_DRIVER_DONT_DEFER is the only one allowed currently). +# - libfuzzer_options - options for the fuzzer (e.g. -close_fd_mask=N). +# - asan_options - AddressSanitizer options (e.g. allow_user_segv_handler=1). +# - msan_options - MemorySanitizer options. +# - ubsan_options - UndefinedBehaviorSanitizer options. +# - seed_corpus - a directory with seed corpus. +# - seed_corpus_deps - dependencies for generating the seed corpus. +# +# If use_libfuzzer gn flag is defined, then proper fuzzer would be build. +# Without use_libfuzzer or use_afl a unit-test style binary would be built on +# linux and the whole target is a no-op otherwise. +# +# The template wraps test() target with appropriate dependencies. +# If any test run-time options are present (dict or libfuzzer_options), then a +# config (.options file) file would be generated or modified in root output +# dir (next to test). +template("openscreen_fuzzer_test") { + if (is_clang && !build_with_chromium) { + assert(defined(invoker.sources), "Need sources in $target_name.") + + test_deps = [ "//testing/libfuzzer:fuzzing_engine_main" ] + test_data_deps = [] + + if (defined(invoker.deps)) { + test_deps += invoker.deps + } + if (defined(invoker.data_deps)) { + test_data_deps += invoker.data_deps + } + + if (defined(invoker.seed_corpus) || defined(invoker.seed_corpuses)) { + assert(!(defined(invoker.seed_corpus) && defined(invoker.seed_corpuses)), + "Do not use both seed_corpus and seed_corpuses for $target_name.") + + out = "$root_build_dir/$target_name" + "_seed_corpus.zip" + + seed_corpus_deps = [] + + if (defined(invoker.seed_corpus_deps)) { + seed_corpus_deps += invoker.seed_corpus_deps + } + + action(target_name + "_seed_corpus") { + script = "//testing/libfuzzer/archive_corpus.py" + + args = [ + "--output", + rebase_path(out, root_build_dir), + ] + + if (defined(invoker.seed_corpus)) { + args += [ rebase_path(invoker.seed_corpus, root_build_dir) ] + } + + if (defined(invoker.seed_corpuses)) { + foreach(seed_corpus_path, invoker.seed_corpuses) { + args += [ rebase_path(seed_corpus_path, root_build_dir) ] + } + } + + outputs = [ + out, + ] + + deps = [ "//testing/libfuzzer:seed_corpus" ] + seed_corpus_deps + } + + test_deps += [ ":" + target_name + "_seed_corpus" ] + } + + if (defined(invoker.dict) || defined(invoker.libfuzzer_options) || + defined(invoker.asan_options) || defined(invoker.msan_options) || + defined(invoker.ubsan_options) || + defined(invoker.environment_variables)) { + if (defined(invoker.dict)) { + # Copy dictionary to output. + copy(target_name + "_dict_copy") { + sources = [ + invoker.dict, + ] + outputs = [ + "$root_build_dir/" + target_name + ".dict", + ] + } + test_deps += [ ":" + target_name + "_dict_copy" ] + } + + # Generate .options file. + config_file_name = target_name + ".options" + action(config_file_name) { + script = "//testing/libfuzzer/gen_fuzzer_config.py" + args = [ + "--config", + rebase_path("$root_build_dir/" + config_file_name, root_build_dir), + ] + + if (defined(invoker.dict)) { + args += [ + "--dict", + rebase_path("$root_build_dir/" + invoker.target_name + ".dict", + root_build_dir), + ] + } + + if (defined(invoker.libfuzzer_options)) { + args += [ "--libfuzzer_options" ] + args += invoker.libfuzzer_options + } + + if (defined(invoker.asan_options)) { + args += [ "--asan_options" ] + args += invoker.asan_options + } + + if (defined(invoker.msan_options)) { + args += [ "--msan_options" ] + args += invoker.msan_options + } + + if (defined(invoker.ubsan_options)) { + args += [ "--ubsan_options" ] + args += invoker.ubsan_options + } + + if (defined(invoker.environment_variables)) { + args += [ "--environment_variables" ] + args += invoker.environment_variables + } + + outputs = [ + "$root_build_dir/$config_file_name", + ] + } + test_deps += [ ":" + config_file_name ] + } + + executable(target_name) { + forward_variables_from(invoker, + [ + "cflags", + "cflags_cc", + "check_includes", + "defines", + "include_dirs", + "sources", + "libs", + ]) + deps = test_deps + data_deps = test_data_deps + + if (defined(invoker.additional_configs)) { + configs += invoker.additional_configs + } + configs += [ "//testing/libfuzzer:fuzzer_test_config" ] + + if (defined(invoker.suppressed_configs)) { + configs -= invoker.suppressed_configs + } + + if (defined(invoker.generated_sources)) { + sources += invoker.generated_sources + } + + if (is_mac) { + sources += [ "//testing/libfuzzer/libfuzzer_exports.h" ] + } + } + } else { + # noop on unsupported platforms. + # mark attributes as used. + not_needed(invoker, "*") + + group(target_name) { + } + } +} diff --git a/chromium/third_party/openscreen/src/testing/libfuzzer/gen_fuzzer_config.py b/chromium/third_party/openscreen/src/testing/libfuzzer/gen_fuzzer_config.py new file mode 100755 index 00000000000..bde9e146ed2 --- /dev/null +++ b/chromium/third_party/openscreen/src/testing/libfuzzer/gen_fuzzer_config.py @@ -0,0 +1,86 @@ +#!/usr/bin/python2 +# +# Copyright (c) 2015 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. +"""Generate or update an existing config (.options file) for libfuzzer test. + +Invoked by GN from fuzzer_test.gni. +""" + +import ConfigParser +import argparse +import os +import sys + + +def AddSectionOptions(config, section_name, options): + """Add |options| to the |section_name| section of |config|. + + Throws an + assertion error if any option in |options| does not have exactly two + elements. + """ + if not options: + return + + config.add_section(section_name) + for option_and_value in options: + assert len(option_and_value) == 2, ( + '%s is not an option, value pair' % option_and_value) + + config.set(section_name, *option_and_value) + + +def main(): + parser = argparse.ArgumentParser(description='Generate fuzzer config.') + parser.add_argument('--config', required=True) + parser.add_argument('--dict') + parser.add_argument('--libfuzzer_options', nargs='+', default=[]) + parser.add_argument('--asan_options', nargs='+', default=[]) + parser.add_argument('--msan_options', nargs='+', default=[]) + parser.add_argument('--ubsan_options', nargs='+', default=[]) + parser.add_argument( + '--environment_variables', + nargs='+', + default=[], + choices=['AFL_DRIVER_DONT_DEFER=1']) + args = parser.parse_args() + + # Script shouldn't be invoked without any arguments, but just in case. + if not (args.dict or args.libfuzzer_options or args.environment_variables or + args.asan_options or args.msan_options or args.ubsan_options): + return + + config = ConfigParser.ConfigParser() + libfuzzer_options = [] + if args.dict: + libfuzzer_options.append(('dict', os.path.basename(args.dict))) + libfuzzer_options.extend( + option.split('=') for option in args.libfuzzer_options) + + AddSectionOptions(config, 'libfuzzer', libfuzzer_options) + + AddSectionOptions(config, 'asan', + [option.split('=') for option in args.asan_options]) + + AddSectionOptions(config, 'msan', + [option.split('=') for option in args.msan_options]) + + AddSectionOptions(config, 'ubsan', + [option.split('=') for option in args.ubsan_options]) + + AddSectionOptions( + config, 'env', + [option.split('=') for option in args.environment_variables]) + + # Generate .options file. + config_path = args.config + with open(config_path, 'w') as options_file: + options_file.write( + '# This is an automatically generated config for ClusterFuzz.\n') + config.write(options_file) + + +if __name__ == '__main__': + main() diff --git a/chromium/third_party/openscreen/src/testing/util/BUILD.gn b/chromium/third_party/openscreen/src/testing/util/BUILD.gn new file mode 100644 index 00000000000..b96a6b4955d --- /dev/null +++ b/chromium/third_party/openscreen/src/testing/util/BUILD.gn @@ -0,0 +1,17 @@ +# Copyright 2019 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +source_set("util") { + testonly = true + sources = [ + "read_file.cc", + "read_file.h", + ] + + public_deps = [ + "../../third_party/abseil", + ] + + public_configs = [ "../../build:openscreen_include_dirs" ] +} diff --git a/chromium/third_party/openscreen/src/testing/util/read_file.cc b/chromium/third_party/openscreen/src/testing/util/read_file.cc new file mode 100644 index 00000000000..be0df7c80da --- /dev/null +++ b/chromium/third_party/openscreen/src/testing/util/read_file.cc @@ -0,0 +1,34 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "testing/util/read_file.h" + +#include <stdio.h> + +namespace openscreen { + +std::string ReadEntireFileToString(absl::string_view filename) { + FILE* file = fopen(filename.data(), "r"); + if (file == nullptr) { + return {}; + } + fseek(file, 0, SEEK_END); + long file_size = ftell(file); + fseek(file, 0, SEEK_SET); + std::string contents(file_size, 0); + int bytes_read = 0; + while (bytes_read < file_size) { + size_t ret = fread(&contents[bytes_read], 1, file_size - bytes_read, file); + if (ret == 0 && ferror(file)) { + return {}; + } else { + bytes_read += ret; + } + } + fclose(file); + + return contents; +} + +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/testing/util/read_file.h b/chromium/third_party/openscreen/src/testing/util/read_file.h new file mode 100644 index 00000000000..7203518b47d --- /dev/null +++ b/chromium/third_party/openscreen/src/testing/util/read_file.h @@ -0,0 +1,18 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef TESTING_UTIL_READ_FILE_H_ +#define TESTING_UTIL_READ_FILE_H_ + +#include <string> + +#include "absl/strings/string_view.h" + +namespace openscreen { + +std::string ReadEntireFileToString(absl::string_view filename); + +} // namespace openscreen + +#endif // TESTING_UTIL_READ_FILE_H_ diff --git a/chromium/third_party/openscreen/src/third_party/abseil/BUILD.gn b/chromium/third_party/openscreen/src/third_party/abseil/BUILD.gn index 4c512c8b8f3..1c2ea384e0e 100644 --- a/chromium/third_party/openscreen/src/third_party/abseil/BUILD.gn +++ b/chromium/third_party/openscreen/src/third_party/abseil/BUILD.gn @@ -72,6 +72,8 @@ if (build_with_chromium) { "src/absl/strings/str_cat.cc", "src/absl/strings/str_cat.h", "src/absl/strings/str_join.h", + "src/absl/strings/str_replace.cc", + "src/absl/strings/str_replace.h", "src/absl/strings/str_split.cc", "src/absl/strings/str_split.h", "src/absl/strings/string_view.cc", diff --git a/chromium/third_party/openscreen/src/third_party/boringssl/BUILD.gn b/chromium/third_party/openscreen/src/third_party/boringssl/BUILD.gn index 6bfebca586d..22e504fed83 100644 --- a/chromium/third_party/openscreen/src/third_party/boringssl/BUILD.gn +++ b/chromium/third_party/openscreen/src/third_party/boringssl/BUILD.gn @@ -32,18 +32,8 @@ if (build_with_chromium) { "BORINGSSL_NO_STATIC_INITIALIZER", "OPENSSL_SMALL", ] - cflags = [] + cflags = [ "-w" ] # Disable all warnings. cflags_c = [ "-std=c99" ] - cflags_cc = [] - if (is_clang) { - cflags += [ "-Wno-extra-semi" ] - cflags_cc += [ "-Wno-c++98-compat-extra-semi" ] - } - - if (is_mac) { - # Necessary since trybots have an old version of clang. - cflags += [ "-Wno-unknown-warning-option" ] - } defines += [ "_XOPEN_SOURCE=700" ] } diff --git a/chromium/third_party/openscreen/src/third_party/chromium_quic/BUILD.gn b/chromium/third_party/openscreen/src/third_party/chromium_quic/BUILD.gn index 2cf14c00389..e620c3a6a94 100644 --- a/chromium/third_party/openscreen/src/third_party/chromium_quic/BUILD.gn +++ b/chromium/third_party/openscreen/src/third_party/chromium_quic/BUILD.gn @@ -5,37 +5,7 @@ import("//third_party/protobuf/proto_library.gni") config("chromium_quic_config") { - cflags_cc = [ - "-Wno-error=attributes", - "-Wno-unused-const-variable", - ] - - if (is_clang) { - cflags_cc += [ - "-Wno-defaulted-function-deleted", - "-Wno-c++98-compat-extra-semi", - "-Wno-extra-semi", - ] - - # The clang version on the Mac OS X build bots is old, causing them to - # not recognize the defaulted-function-deleted warning. This flag allows - # us to build on both older and newer Mac OS X Clang toolchains. - if (is_mac) { - cflags_cc += [ "-Wno-unknown-warning-option" ] - } - } - - if (is_gcc) { - cflags_cc += [ - "-Wno-dangling-else", - "-Wno-return-type", - "-Wno-unused-but-set-variable", - - # Don't warn about "maybe" uninitialized. Clang doesn't include this - # in -Wall but gcc does, and it gives false positives. - "-Wno-maybe-uninitialized", - ] - } + cflags = [ "-w" ] # Disable all warnings. configs = [ "//third_party/protobuf:protobuf_config" ] @@ -56,7 +26,7 @@ source_set("chromium_quic") { ] } -executable("quic_demo_client") { +executable("quic_streaming_playback_controller") { sources = [ "demo/client.cc", "demo/delegates.cc", diff --git a/chromium/third_party/openscreen/src/third_party/chromium_quic/build/base/BUILD.gn b/chromium/third_party/openscreen/src/third_party/chromium_quic/build/base/BUILD.gn index aefbef77369..72d3975419e 100644 --- a/chromium/third_party/openscreen/src/third_party/chromium_quic/build/base/BUILD.gn +++ b/chromium/third_party/openscreen/src/third_party/chromium_quic/build/base/BUILD.gn @@ -547,7 +547,11 @@ source_set("base") { "synchronization/synchronization_buildflags.h", ] - if (is_posix || is_linux || is_mac) { + if (is_posix) { + if (target_cpu == "arm") { + cflags_cc = [ "-D_FILE_OFFSET_BITS=64" ] + } + sources += [ "../../src/base/base_paths_posix.h", "../../src/base/debug/debugger_posix.cc", @@ -617,7 +621,6 @@ source_set("base") { "../../src/base/files/file_util_mac.mm", "../../src/base/mac/authorization_util.h", "../../src/base/mac/authorization_util.mm", - "../../src/base/mac/availability.h", "../../src/base/mac/bundle_locations.h", "../../src/base/mac/bundle_locations.mm", "../../src/base/mac/call_with_eh_frame.cc", diff --git a/chromium/third_party/openscreen/src/third_party/chromium_quic/demo/client.cc b/chromium/third_party/openscreen/src/third_party/chromium_quic/demo/client.cc index e57e97d34b2..d3eb2566708 100644 --- a/chromium/third_party/openscreen/src/third_party/chromium_quic/demo/client.cc +++ b/chromium/third_party/openscreen/src/third_party/chromium_quic/demo/client.cc @@ -35,7 +35,8 @@ int main(int argc, char** argv) { if (argc < 2) { dprintf(STDERR_FILENO, - "Missing port number\nusage: demo_client <server-port>\n"); + "Missing port number\nusage: streaming_playback_controller " + "<server-port>\n"); return 1; } int port = atoi(argv[1]); diff --git a/chromium/third_party/openscreen/src/third_party/googletest/BUILD.gn b/chromium/third_party/openscreen/src/third_party/googletest/BUILD.gn index f76f9e749a1..27006582301 100644 --- a/chromium/third_party/openscreen/src/third_party/googletest/BUILD.gn +++ b/chromium/third_party/openscreen/src/third_party/googletest/BUILD.gn @@ -77,6 +77,10 @@ if (build_with_chromium) { ":gtest_config", ] + public_deps = [ + ":gtest", + ] + include_dirs = [ "src/googlemock" ] } diff --git a/chromium/third_party/openscreen/src/third_party/libfuzzer/BUILD.gn b/chromium/third_party/openscreen/src/third_party/libfuzzer/BUILD.gn new file mode 100644 index 00000000000..a1a3db59394 --- /dev/null +++ b/chromium/third_party/openscreen/src/third_party/libfuzzer/BUILD.gn @@ -0,0 +1,44 @@ +# Copyright 2019 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +import("//build_overrides/build.gni") + +config("ignore_warnings") { + if (is_clang) { + cflags_cc = [ + "-Wno-unused-result", + "-Wno-exit-time-destructors", + ] + } +} + +source_set("libfuzzer") { + sources = [ + "src/FuzzerCrossOver.cpp", + "src/FuzzerDataFlowTrace.cpp", + "src/FuzzerDriver.cpp", + "src/FuzzerExtFunctionsDlsym.cpp", + "src/FuzzerExtFunctionsWeak.cpp", + "src/FuzzerExtFunctionsWindows.cpp", + "src/FuzzerExtraCounters.cpp", + "src/FuzzerFork.cpp", + "src/FuzzerIO.cpp", + "src/FuzzerIOPosix.cpp", + "src/FuzzerIOWindows.cpp", + "src/FuzzerLoop.cpp", + "src/FuzzerMain.cpp", + "src/FuzzerMerge.cpp", + "src/FuzzerMutate.cpp", + "src/FuzzerSHA1.cpp", + "src/FuzzerTracePC.cpp", + "src/FuzzerUtil.cpp", + "src/FuzzerUtilDarwin.cpp", + "src/FuzzerUtilFuchsia.cpp", + "src/FuzzerUtilLinux.cpp", + "src/FuzzerUtilPosix.cpp", + "src/FuzzerUtilWindows.cpp", + ] + + configs += [ ":ignore_warnings" ] +} diff --git a/chromium/third_party/openscreen/src/third_party/mDNSResponder/BUILD.gn b/chromium/third_party/openscreen/src/third_party/mDNSResponder/BUILD.gn index 816e0906b6e..4e2fb8d4b0b 100644 --- a/chromium/third_party/openscreen/src/third_party/mDNSResponder/BUILD.gn +++ b/chromium/third_party/openscreen/src/third_party/mDNSResponder/BUILD.gn @@ -3,20 +3,9 @@ # found in the LICENSE file. config("mdnsresponder_config") { - cflags_c = [ "-Wno-array-bounds" ] + cflags = [ "-w" ] # Disable all warnings. - if (is_gcc) { - cflags_c += [ - "-Wno-unused-but-set-variable", - "-Wno-unused-value", - ] - } - - if (is_clang) { - cflags_c += [ "-Wno-address-of-packed-member" ] - } - - cflags_c += [ + cflags_c = [ # We need to rename some linked symbols in order to avoid multiple # definitions. "-DMD5_Update=MD5_Update_mDNS", @@ -24,6 +13,10 @@ config("mdnsresponder_config") { "-DMD5_Final=MD5_Final_mDNS", "-DMD5_Transform=MD5_Transform_mDNS", ] + + if (target_cpu == "arm" || target_cpu == "arm64") { + cflags_c += [ "-Dmd5_block_data_order=md5_block_data_order" ] + } } source_set("core") { diff --git a/chromium/third_party/openscreen/src/third_party/mozilla/BUILD.gn b/chromium/third_party/openscreen/src/third_party/mozilla/BUILD.gn new file mode 100644 index 00000000000..051917d2828 --- /dev/null +++ b/chromium/third_party/openscreen/src/third_party/mozilla/BUILD.gn @@ -0,0 +1,14 @@ +# Copyright 2020 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +source_set("mozilla") { + sources = [ + "url_parse.cc", + "url_parse.h", + "url_parse_internal.cc", + "url_parse_internal.h", + ] + + public_configs = [ "../../build:openscreen_include_dirs" ] +} diff --git a/chromium/third_party/openscreen/src/third_party/mozilla/LICENSE.txt b/chromium/third_party/openscreen/src/third_party/mozilla/LICENSE.txt new file mode 100644 index 00000000000..ac40837824a --- /dev/null +++ b/chromium/third_party/openscreen/src/third_party/mozilla/LICENSE.txt @@ -0,0 +1,65 @@ +Copyright 2007, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +------------------------------------------------------------------------------- + +The file url_parse.cc is based on nsURLParsers.cc from Mozilla. This file is +licensed separately as follows: + +The contents of this file are subject to the Mozilla Public License Version +1.1 (the "License"); you may not use this file except in compliance with +the License. You may obtain a copy of the License at +http://www.mozilla.org/MPL/ + +Software distributed under the License is distributed on an "AS IS" basis, +WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License +for the specific language governing rights and limitations under the +License. + +The Original Code is mozilla.org code. + +The Initial Developer of the Original Code is +Netscape Communications Corporation. +Portions created by the Initial Developer are Copyright (C) 1998 +the Initial Developer. All Rights Reserved. + +Contributor(s): + Darin Fisher (original author) + +Alternatively, the contents of this file may be used under the terms of +either the GNU General Public License Version 2 or later (the "GPL"), or +the GNU Lesser General Public License Version 2.1 or later (the "LGPL"), +in which case the provisions of the GPL or the LGPL are applicable instead +of those above. If you wish to allow use of your version of this file only +under the terms of either the GPL or the LGPL, and not to allow others to +use your version of this file under the terms of the MPL, indicate your +decision by deleting the provisions above and replace them with the notice +and other provisions required by the GPL or the LGPL. If you do not delete +the provisions above, a recipient may use your version of this file under +the terms of any one of the MPL, the GPL or the LGPL. diff --git a/chromium/third_party/openscreen/src/third_party/mozilla/README.md b/chromium/third_party/openscreen/src/third_party/mozilla/README.md new file mode 100644 index 00000000000..ed4c24d8c06 --- /dev/null +++ b/chromium/third_party/openscreen/src/third_party/mozilla/README.md @@ -0,0 +1,7 @@ +# url_parse + +`url_parse.{h,cc}` are based on the same files in Chromium under +`//url/third_party/mozilla` but have been slightly modified for our use case. +`url_parse_internal.{h,cc}` contains additional functions needed by the former +files but aren't provided directly. These are also ported from Chromium's +version. diff --git a/chromium/third_party/openscreen/src/third_party/mozilla/url_parse.cc b/chromium/third_party/openscreen/src/third_party/mozilla/url_parse.cc new file mode 100644 index 00000000000..e6efd9e7a34 --- /dev/null +++ b/chromium/third_party/openscreen/src/third_party/mozilla/url_parse.cc @@ -0,0 +1,858 @@ +/* Based on nsURLParsers.cc from Mozilla + * ------------------------------------- + * The contents of this file are subject to the Mozilla Public License Version + * 1.1 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * http://www.mozilla.org/MPL/ + * + * Software distributed under the License is distributed on an "AS IS" basis, + * WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License + * for the specific language governing rights and limitations under the + * License. + * + * The Original Code is mozilla.org code. + * + * The Initial Developer of the Original Code is + * Netscape Communications Corporation. + * Portions created by the Initial Developer are Copyright (C) 1998 + * the Initial Developer. All Rights Reserved. + * + * Contributor(s): + * Darin Fisher (original author) + * + * Alternatively, the contents of this file may be used under the terms of + * either the GNU General Public License Version 2 or later (the "GPL"), or + * the GNU Lesser General Public License Version 2.1 or later (the "LGPL"), + * in which case the provisions of the GPL or the LGPL are applicable instead + * of those above. If you wish to allow use of your version of this file only + * under the terms of either the GPL or the LGPL, and not to allow others to + * use your version of this file under the terms of the MPL, indicate your + * decision by deleting the provisions above and replace them with the notice + * and other provisions required by the GPL or the LGPL. If you do not delete + * the provisions above, a recipient may use your version of this file under + * the terms of any one of the MPL, the GPL or the LGPL. + * + * ***** END LICENSE BLOCK ***** */ + +#include "third_party/mozilla/url_parse.h" + +#include <assert.h> +#include <ctype.h> +#include <stdlib.h> + +#include "third_party/mozilla/url_parse_internal.h" + +namespace openscreen { +namespace { + +// Returns true if the given character is a valid digit to use in a port. +bool IsPortDigit(char ch) { + return ch >= '0' && ch <= '9'; +} + +// Returns the offset of the next authority terminator in the input starting +// from start_offset. If no terminator is found, the return value will be equal +// to spec_len. +int FindNextAuthorityTerminator(const char* spec, + int start_offset, + int spec_len) { + for (int i = start_offset; i < spec_len; i++) { + if (IsAuthorityTerminator(spec[i])) + return i; + } + return spec_len; // Not found. +} + +void ParseUserInfo(const char* spec, + const Component& user, + Component* username, + Component* password) { + // Find the first colon in the user section, which separates the username and + // password. + int colon_offset = 0; + while (colon_offset < user.len && spec[user.begin + colon_offset] != ':') + colon_offset++; + + if (colon_offset < user.len) { + // Found separator: <username>:<password> + *username = Component(user.begin, colon_offset); + *password = MakeRange(user.begin + colon_offset + 1, user.begin + user.len); + } else { + // No separator, treat everything as the username + *username = user; + *password = Component(); + } +} + +void ParseServerInfo(const char* spec, + const Component& serverinfo, + Component* hostname, + Component* port_num) { + if (serverinfo.len == 0) { + // No server info, host name is empty. + hostname->reset(); + port_num->reset(); + return; + } + + // If the host starts with a left-bracket, assume the entire host is an + // IPv6 literal. Otherwise, assume none of the host is an IPv6 literal. + // This assumption will be overridden if we find a right-bracket. + // + // Our IPv6 address canonicalization code requires both brackets to exist, + // but the ability to locate an incomplete address can still be useful. + int ipv6_terminator = spec[serverinfo.begin] == '[' ? serverinfo.end() : -1; + int colon = -1; + + // Find the last right-bracket, and the last colon. + for (int i = serverinfo.begin; i < serverinfo.end(); i++) { + switch (spec[i]) { + case ']': + ipv6_terminator = i; + break; + case ':': + colon = i; + break; + } + } + + if (colon > ipv6_terminator) { + // Found a port number: <hostname>:<port> + *hostname = MakeRange(serverinfo.begin, colon); + if (hostname->len == 0) + hostname->reset(); + *port_num = MakeRange(colon + 1, serverinfo.end()); + } else { + // No port: <hostname> + *hostname = serverinfo; + port_num->reset(); + } +} + +// Given an already-identified auth section, breaks it into its consituent +// parts. The port number will be parsed and the resulting integer will be +// filled into the given *port variable, or -1 if there is no port number or it +// is invalid. +void DoParseAuthority(const char* spec, + const Component& auth, + Component* username, + Component* password, + Component* hostname, + Component* port_num) { + assert(auth.is_valid()); + if (auth.len == 0) { + username->reset(); + password->reset(); + hostname->reset(); + port_num->reset(); + return; + } + + // Search backwards for @, which is the separator between the user info and + // the server info. + int i = auth.begin + auth.len - 1; + while (i > auth.begin && spec[i] != '@') + i--; + + if (spec[i] == '@') { + // Found user info: <user-info>@<server-info> + ParseUserInfo(spec, Component(auth.begin, i - auth.begin), username, + password); + ParseServerInfo(spec, MakeRange(i + 1, auth.begin + auth.len), hostname, + port_num); + } else { + // No user info, everything is server info. + username->reset(); + password->reset(); + ParseServerInfo(spec, auth, hostname, port_num); + } +} + +inline void FindQueryAndRefParts(const char* spec, + const Component& path, + int* query_separator, + int* ref_separator) { + int path_end = path.begin + path.len; + for (int i = path.begin; i < path_end; i++) { + switch (spec[i]) { + case '?': + // Only match the query string if it precedes the reference fragment + // and when we haven't found one already. + if (*query_separator < 0) + *query_separator = i; + break; + case '#': + // Record the first # sign only. + if (*ref_separator < 0) { + *ref_separator = i; + return; + } + break; + } + } +} + +void ParsePath(const char* spec, + const Component& path, + Component* filepath, + Component* query, + Component* ref) { + // path = [/]<segment1>/<segment2>/<...>/<segmentN>;<param>?<query>#<ref> + + // Special case when there is no path. + if (path.len == -1) { + filepath->reset(); + query->reset(); + ref->reset(); + return; + } + assert(path.len > 0); + + // Search for first occurrence of either ? or #. + int query_separator = -1; // Index of the '?' + int ref_separator = -1; // Index of the '#' + FindQueryAndRefParts(spec, path, &query_separator, &ref_separator); + + // Markers pointing to the character after each of these corresponding + // components. The code below words from the end back to the beginning, + // and will update these indices as it finds components that exist. + int file_end, query_end; + + // Ref fragment: from the # to the end of the path. + int path_end = path.begin + path.len; + if (ref_separator >= 0) { + file_end = query_end = ref_separator; + *ref = MakeRange(ref_separator + 1, path_end); + } else { + file_end = query_end = path_end; + ref->reset(); + } + + // Query fragment: everything from the ? to the next boundary (either the end + // of the path or the ref fragment). + if (query_separator >= 0) { + file_end = query_separator; + *query = MakeRange(query_separator + 1, query_end); + } else { + query->reset(); + } + + // File path: treat an empty file path as no file path. + if (file_end != path.begin) + *filepath = MakeRange(path.begin, file_end); + else + filepath->reset(); +} + +bool DoExtractScheme(const char* url, int url_len, Component* scheme) { + // Skip leading whitespace and control characters. + int begin = 0; + while (begin < url_len && ShouldTrimFromURL(url[begin])) + begin++; + if (begin == url_len) + return false; // Input is empty or all whitespace. + + // Find the first colon character. + for (int i = begin; i < url_len; i++) { + if (url[i] == ':') { + *scheme = MakeRange(begin, i); + return true; + } + } + return false; // No colon found: no scheme +} + +// Fills in all members of the Parsed structure except for the scheme. +// +// |spec| is the full spec being parsed, of length |spec_len|. +// |after_scheme| is the character immediately following the scheme (after the +// colon) where we'll begin parsing. +// +// Compatability data points. I list "host", "path" extracted: +// Input IE6 Firefox Us +// ----- -------------- -------------- -------------- +// http://foo.com/ "foo.com", "/" "foo.com", "/" "foo.com", "/" +// http:foo.com/ "foo.com", "/" "foo.com", "/" "foo.com", "/" +// http:/foo.com/ fail(*) "foo.com", "/" "foo.com", "/" +// http:\foo.com/ fail(*) "\foo.com", "/"(fail) "foo.com", "/" +// http:////foo.com/ "foo.com", "/" "foo.com", "/" "foo.com", "/" +// +// (*) Interestingly, although IE fails to load these URLs, its history +// canonicalizer handles them, meaning if you've been to the corresponding +// "http://foo.com/" link, it will be colored. +void DoParseAfterScheme(const char* spec, + int spec_len, + int after_scheme, + Parsed* parsed) { + int num_slashes = CountConsecutiveSlashes(spec, after_scheme, spec_len); + int after_slashes = after_scheme + num_slashes; + + // First split into two main parts, the authority (username, password, host, + // and port) and the full path (path, query, and reference). + Component authority; + Component full_path; + + // Found "//<some data>", looks like an authority section. Treat everything + // from there to the next slash (or end of spec) to be the authority. Note + // that we ignore the number of slashes and treat it as the authority. + int end_auth = FindNextAuthorityTerminator(spec, after_slashes, spec_len); + authority = Component(after_slashes, end_auth - after_slashes); + + if (end_auth == spec_len) // No beginning of path found. + full_path = Component(); + else // Everything starting from the slash to the end is the path. + full_path = Component(end_auth, spec_len - end_auth); + + // Now parse those two sub-parts. + DoParseAuthority(spec, authority, &parsed->username, &parsed->password, + &parsed->host, &parsed->port); + ParsePath(spec, full_path, &parsed->path, &parsed->query, &parsed->ref); +} + +// The main parsing function for standard URLs. Standard URLs have a scheme, +// host, path, etc. +void DoParseStandardURL(const char* spec, int spec_len, Parsed* parsed) { + assert(spec_len >= 0); + + // Strip leading & trailing spaces and control characters. + int begin = 0; + TrimURL(spec, &begin, &spec_len); + + int after_scheme; + if (DoExtractScheme(spec, spec_len, &parsed->scheme)) { + after_scheme = parsed->scheme.end() + 1; // Skip past the colon. + } else { + // Say there's no scheme when there is no colon. We could also say that + // everything is the scheme. Both would produce an invalid URL, but this way + // seems less wrong in more cases. + parsed->scheme.reset(); + after_scheme = begin; + } + DoParseAfterScheme(spec, spec_len, after_scheme, parsed); +} + +void DoParseFileSystemURL(const char* spec, int spec_len, Parsed* parsed) { + assert(spec_len >= 0); + + // Get the unused parts of the URL out of the way. + parsed->username.reset(); + parsed->password.reset(); + parsed->host.reset(); + parsed->port.reset(); + parsed->path.reset(); // May use this; reset for convenience. + parsed->ref.reset(); // May use this; reset for convenience. + parsed->query.reset(); // May use this; reset for convenience. + parsed->clear_inner_parsed(); // May use this; reset for convenience. + + // Strip leading & trailing spaces and control characters. + int begin = 0; + TrimURL(spec, &begin, &spec_len); + + // Handle empty specs or ones that contain only whitespace or control chars. + if (begin == spec_len) { + parsed->scheme.reset(); + return; + } + + int inner_start = -1; + + // Extract the scheme. We also handle the case where there is no scheme. + if (DoExtractScheme(&spec[begin], spec_len - begin, &parsed->scheme)) { + // Offset the results since we gave ExtractScheme a substring. + parsed->scheme.begin += begin; + + if (parsed->scheme.end() == spec_len - 1) + return; + + inner_start = parsed->scheme.end() + 1; + } else { + // No scheme found; that's not valid for filesystem URLs. + parsed->scheme.reset(); + return; + } + + Component inner_scheme; + const char* inner_spec = &spec[inner_start]; + int inner_spec_len = spec_len - inner_start; + + if (DoExtractScheme(inner_spec, inner_spec_len, &inner_scheme)) { + // Offset the results since we gave ExtractScheme a substring. + inner_scheme.begin += inner_start; + + if (inner_scheme.end() == spec_len - 1) + return; + } else { + // No scheme found; that's not valid for filesystem URLs. + // The best we can do is return "filesystem://". + return; + } + + Parsed inner_parsed; + + if (CompareSchemeComponent(spec, inner_scheme, kFileScheme)) { + // File URLs are special. + ParseFileURL(inner_spec, inner_spec_len, &inner_parsed); + } else if (CompareSchemeComponent(spec, inner_scheme, kFileSystemScheme)) { + // Filesystem URLs don't nest. + return; + } else if (IsStandard(spec, inner_scheme)) { + // All "normal" URLs. + DoParseStandardURL(inner_spec, inner_spec_len, &inner_parsed); + } else { + return; + } + + // All members of inner_parsed need to be offset by inner_start. + // If we had any scheme that supported nesting more than one level deep, + // we'd have to recurse into the inner_parsed's inner_parsed when + // adjusting by inner_start. + inner_parsed.scheme.begin += inner_start; + inner_parsed.username.begin += inner_start; + inner_parsed.password.begin += inner_start; + inner_parsed.host.begin += inner_start; + inner_parsed.port.begin += inner_start; + inner_parsed.query.begin += inner_start; + inner_parsed.ref.begin += inner_start; + inner_parsed.path.begin += inner_start; + + // Query and ref move from inner_parsed to parsed. + parsed->query = inner_parsed.query; + inner_parsed.query.reset(); + parsed->ref = inner_parsed.ref; + inner_parsed.ref.reset(); + + parsed->set_inner_parsed(inner_parsed); + if (!inner_parsed.scheme.is_valid() || !inner_parsed.path.is_valid() || + inner_parsed.inner_parsed()) { + return; + } + + // The path in inner_parsed should start with a slash, then have a filesystem + // type followed by a slash. From the first slash up to but excluding the + // second should be what it keeps; the rest goes to parsed. If the path ends + // before the second slash, it's still pretty clear what the user meant, so + // we'll let that through. + if (!IsURLSlash(spec[inner_parsed.path.begin])) { + return; + } + int inner_path_end = inner_parsed.path.begin + 1; // skip the leading slash + while (inner_path_end < spec_len && !IsURLSlash(spec[inner_path_end])) + ++inner_path_end; + parsed->path.begin = inner_path_end; + int new_inner_path_length = inner_path_end - inner_parsed.path.begin; + parsed->path.len = inner_parsed.path.len - new_inner_path_length; + parsed->inner_parsed()->path.len = new_inner_path_length; +} + +// Initializes a path URL which is merely a scheme followed by a path. Examples +// include "about:foo" and "javascript:alert('bar');" +void DoParsePathURL(const char* spec, + int spec_len, + bool trim_path_end, + Parsed* parsed) { + // Get the non-path and non-scheme parts of the URL out of the way, we never + // use them. + parsed->username.reset(); + parsed->password.reset(); + parsed->host.reset(); + parsed->port.reset(); + parsed->path.reset(); + parsed->query.reset(); + parsed->ref.reset(); + + // Strip leading & trailing spaces and control characters. + int scheme_begin = 0; + TrimURL(spec, &scheme_begin, &spec_len, trim_path_end); + + // Handle empty specs or ones that contain only whitespace or control chars. + if (scheme_begin == spec_len) { + parsed->scheme.reset(); + parsed->path.reset(); + return; + } + + int path_begin; + // Extract the scheme, with the path being everything following. We also + // handle the case where there is no scheme. + if (ExtractScheme(&spec[scheme_begin], spec_len - scheme_begin, + &parsed->scheme)) { + // Offset the results since we gave ExtractScheme a substring. + parsed->scheme.begin += scheme_begin; + path_begin = parsed->scheme.end() + 1; + } else { + // No scheme case. + parsed->scheme.reset(); + path_begin = scheme_begin; + } + + if (path_begin == spec_len) + return; + assert(path_begin < spec_len); + + ParsePath(spec, MakeRange(path_begin, spec_len), &parsed->path, + &parsed->query, &parsed->ref); +} + +void DoParseMailtoURL(const char* spec, int spec_len, Parsed* parsed) { + assert(spec_len >= 0); + + // Get the non-path and non-scheme parts of the URL out of the way, we never + // use them. + parsed->username.reset(); + parsed->password.reset(); + parsed->host.reset(); + parsed->port.reset(); + parsed->ref.reset(); + parsed->query.reset(); // May use this; reset for convenience. + + // Strip leading & trailing spaces and control characters. + int begin = 0; + TrimURL(spec, &begin, &spec_len); + + // Handle empty specs or ones that contain only whitespace or control chars. + if (begin == spec_len) { + parsed->scheme.reset(); + parsed->path.reset(); + return; + } + + int path_begin = -1; + int path_end = -1; + + // Extract the scheme, with the path being everything following. We also + // handle the case where there is no scheme. + if (ExtractScheme(&spec[begin], spec_len - begin, &parsed->scheme)) { + // Offset the results since we gave ExtractScheme a substring. + parsed->scheme.begin += begin; + + if (parsed->scheme.end() != spec_len - 1) { + path_begin = parsed->scheme.end() + 1; + path_end = spec_len; + } + } else { + // No scheme found, just path. + parsed->scheme.reset(); + path_begin = begin; + path_end = spec_len; + } + + // Split [path_begin, path_end) into a path + query. + for (int i = path_begin; i < path_end; ++i) { + if (spec[i] == '?') { + parsed->query = MakeRange(i + 1, path_end); + path_end = i; + break; + } + } + + // For compatability with the standard URL parser, treat no path as + // -1, rather than having a length of 0 + if (path_begin == path_end) { + parsed->path.reset(); + } else { + parsed->path = MakeRange(path_begin, path_end); + } +} + +// Converts a port number in a string to an integer. We'd like to just call +// sscanf but our input is not NULL-terminated, which sscanf requires. Instead, +// we copy the digits to a small stack buffer (since we know the maximum number +// of digits in a valid port number) that we can NULL terminate. +int DoParsePort(const char* spec, const Component& component) { + // Easy success case when there is no port. + const int kMaxDigits = 5; + if (!component.is_nonempty()) + return PORT_UNSPECIFIED; + + // Skip over any leading 0s. + Component digits_comp(component.end(), 0); + for (int i = 0; i < component.len; i++) { + if (spec[component.begin + i] != '0') { + digits_comp = MakeRange(component.begin + i, component.end()); + break; + } + } + if (digits_comp.len == 0) + return 0; // All digits were 0. + + // Verify we don't have too many digits (we'll be copying to our buffer so + // we need to double-check). + if (digits_comp.len > kMaxDigits) + return PORT_INVALID; + + // Copy valid digits to the buffer. + char digits[kMaxDigits + 1]; // +1 for null terminator + for (int i = 0; i < digits_comp.len; i++) { + char ch = spec[digits_comp.begin + i]; + if (!IsPortDigit(ch)) { + // Invalid port digit, fail. + return PORT_INVALID; + } + digits[i] = static_cast<char>(ch); + } + + // Null-terminate the string and convert to integer. Since we guarantee + // only digits, atoi's lack of error handling is OK. + digits[digits_comp.len] = 0; + int port = atoi(digits); + if (port > 65535) + return PORT_INVALID; // Out of range. + return port; +} + +void DoExtractFileName(const char* spec, + const Component& path, + Component* file_name) { + // Handle empty paths: they have no file names. + if (!path.is_nonempty()) { + file_name->reset(); + return; + } + + // Extract the filename range from the path which is between + // the last slash and the following semicolon. + int file_end = path.end(); + for (int i = path.end() - 1; i >= path.begin; i--) { + if (spec[i] == ';') { + file_end = i; + } else if (IsURLSlash(spec[i])) { + // File name is everything following this character to the end + *file_name = MakeRange(i + 1, file_end); + return; + } + } + + // No slash found, this means the input was degenerate (generally paths + // will start with a slash). Let's call everything the file name. + *file_name = MakeRange(path.begin, file_end); + return; +} + +bool DoExtractQueryKeyValue(const char* spec, + Component* query, + Component* key, + Component* value) { + if (!query->is_nonempty()) + return false; + + int start = query->begin; + int cur = start; + int end = query->end(); + + // We assume the beginning of the input is the beginning of the "key" and we + // skip to the end of it. + key->begin = cur; + while (cur < end && spec[cur] != '&' && spec[cur] != '=') + cur++; + key->len = cur - key->begin; + + // Skip the separator after the key (if any). + if (cur < end && spec[cur] == '=') + cur++; + + // Find the value part. + value->begin = cur; + while (cur < end && spec[cur] != '&') + cur++; + value->len = cur - value->begin; + + // Finally skip the next separator if any + if (cur < end && spec[cur] == '&') + cur++; + + // Save the new query + *query = MakeRange(cur, end); + return true; +} + +} // namespace + +Parsed::Parsed() : potentially_dangling_markup(false), inner_parsed_(NULL) {} + +Parsed::Parsed(const Parsed& other) + : scheme(other.scheme), + username(other.username), + password(other.password), + host(other.host), + port(other.port), + path(other.path), + query(other.query), + ref(other.ref), + potentially_dangling_markup(other.potentially_dangling_markup), + inner_parsed_(NULL) { + if (other.inner_parsed_) + set_inner_parsed(*other.inner_parsed_); +} + +Parsed& Parsed::operator=(const Parsed& other) { + if (this != &other) { + scheme = other.scheme; + username = other.username; + password = other.password; + host = other.host; + port = other.port; + path = other.path; + query = other.query; + ref = other.ref; + potentially_dangling_markup = other.potentially_dangling_markup; + if (other.inner_parsed_) + set_inner_parsed(*other.inner_parsed_); + else + clear_inner_parsed(); + } + return *this; +} + +Parsed::~Parsed() { + delete inner_parsed_; +} + +int Parsed::Length() const { + if (ref.is_valid()) + return ref.end(); + return CountCharactersBefore(REF, false); +} + +int Parsed::CountCharactersBefore(ComponentType type, + bool include_delimiter) const { + if (type == SCHEME) + return scheme.begin; + + // There will be some characters after the scheme like "://" and we don't + // know how many. Search forwards for the next thing until we find one. + int cur = 0; + if (scheme.is_valid()) + cur = scheme.end() + 1; // Advance over the ':' at the end of the scheme. + + if (username.is_valid()) { + if (type <= USERNAME) + return username.begin; + cur = username.end() + 1; // Advance over the '@' or ':' at the end. + } + + if (password.is_valid()) { + if (type <= PASSWORD) + return password.begin; + cur = password.end() + 1; // Advance over the '@' at the end. + } + + if (host.is_valid()) { + if (type <= HOST) + return host.begin; + cur = host.end(); + } + + if (port.is_valid()) { + if (type < PORT || (type == PORT && include_delimiter)) + return port.begin - 1; // Back over delimiter. + if (type == PORT) + return port.begin; // Don't want delimiter counted. + cur = port.end(); + } + + if (path.is_valid()) { + if (type <= PATH) + return path.begin; + cur = path.end(); + } + + if (query.is_valid()) { + if (type < QUERY || (type == QUERY && include_delimiter)) + return query.begin - 1; // Back over delimiter. + if (type == QUERY) + return query.begin; // Don't want delimiter counted. + cur = query.end(); + } + + if (ref.is_valid()) { + if (type == REF && !include_delimiter) + return ref.begin; // Back over delimiter. + + // When there is a ref and we get here, the component we wanted was before + // this and not found, so we always know the beginning of the ref is right. + return ref.begin - 1; // Don't want delimiter counted. + } + + return cur; +} + +Component Parsed::GetContent() const { + const int begin = CountCharactersBefore(USERNAME, false); + const int len = Length() - begin; + // For compatability with the standard URL parser, we treat no content as + // -1, rather than having a length of 0 (we normally wouldn't care so + // much for these non-standard URLs). + return len ? Component(begin, len) : Component(); +} + +bool ExtractScheme(const char* url, int url_len, Component* scheme) { + return DoExtractScheme(url, url_len, scheme); +} + +// This handles everything that may be an authority terminator, including +// backslash. For special backslash handling see DoParseAfterScheme. +bool IsAuthorityTerminator(char ch) { + return IsURLSlash(ch) || ch == '?' || ch == '#'; +} + +void ExtractFileName(const char* url, + const Component& path, + Component* file_name) { + DoExtractFileName(url, path, file_name); +} + +bool ExtractQueryKeyValue(const char* url, + Component* query, + Component* key, + Component* value) { + return DoExtractQueryKeyValue(url, query, key, value); +} + +void ParseAuthority(const char* spec, + const Component& auth, + Component* username, + Component* password, + Component* hostname, + Component* port_num) { + DoParseAuthority(spec, auth, username, password, hostname, port_num); +} + +int ParsePort(const char* url, const Component& port) { + return DoParsePort(url, port); +} + +void ParseStandardURL(const char* url, int url_len, Parsed* parsed) { + DoParseStandardURL(url, url_len, parsed); +} + +void ParsePathURL(const char* url, + int url_len, + bool trim_path_end, + Parsed* parsed) { + DoParsePathURL(url, url_len, trim_path_end, parsed); +} + +void ParseFileSystemURL(const char* url, int url_len, Parsed* parsed) { + DoParseFileSystemURL(url, url_len, parsed); +} + +void ParseMailtoURL(const char* url, int url_len, Parsed* parsed) { + DoParseMailtoURL(url, url_len, parsed); +} + +void ParsePathInternal(const char* spec, + const Component& path, + Component* filepath, + Component* query, + Component* ref) { + ParsePath(spec, path, filepath, query, ref); +} + +void ParseAfterScheme(const char* spec, + int spec_len, + int after_scheme, + Parsed* parsed) { + DoParseAfterScheme(spec, spec_len, after_scheme, parsed); +} + +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/third_party/mozilla/url_parse.h b/chromium/third_party/openscreen/src/third_party/mozilla/url_parse.h new file mode 100644 index 00000000000..70e97adf9b8 --- /dev/null +++ b/chromium/third_party/openscreen/src/third_party/mozilla/url_parse.h @@ -0,0 +1,322 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef THIRD_PARTY_MOZILLA_URL_PARSE_H_ +#define THIRD_PARTY_MOZILLA_URL_PARSE_H_ + +namespace openscreen { + +// Component ------------------------------------------------------------------ + +// Represents a substring for URL parsing. +struct Component { + Component() : begin(0), len(-1) {} + + // Normal constructor: takes an offset and a length. + Component(int b, int l) : begin(b), len(l) {} + + int end() const { return begin + len; } + + // Returns true if this component is valid, meaning the length is given. Even + // valid components may be empty to record the fact that they exist. + bool is_valid() const { return (len != -1); } + + // Returns true if the given component is specified on false, the component + // is either empty or invalid. + bool is_nonempty() const { return (len > 0); } + + void reset() { + begin = 0; + len = -1; + } + + bool operator==(const Component& other) const { + return begin == other.begin && len == other.len; + } + + int begin; // Byte offset in the string of this component. + int len; // Will be -1 if the component is unspecified. +}; + +// Helper that returns a component created with the given begin and ending +// points. The ending point is non-inclusive. +inline Component MakeRange(int begin, int end) { + return Component(begin, end - begin); +} + +// Parsed --------------------------------------------------------------------- + +// A structure that holds the identified parts of an input URL. This structure +// does NOT store the URL itself. The caller will have to store the URL text +// and its corresponding Parsed structure separately. +// +// Typical usage would be: +// +// Parsed parsed; +// Component scheme; +// if (!ExtractScheme(url, url_len, &scheme)) +// return I_CAN_NOT_FIND_THE_SCHEME_DUDE; +// +// if (IsStandardScheme(url, scheme)) // Not provided by this component +// ParseStandardURL(url, url_len, &parsed); +// else if (IsFileURL(url, scheme)) // Not provided by this component +// ParseFileURL(url, url_len, &parsed); +// else +// ParsePathURL(url, url_len, &parsed); +// +struct Parsed { + // Identifies different components. + enum ComponentType { + SCHEME, + USERNAME, + PASSWORD, + HOST, + PORT, + PATH, + QUERY, + REF, + }; + + // The default constructor is sufficient for the components, but inner_parsed_ + // requires special handling. + Parsed(); + Parsed(const Parsed&); + Parsed& operator=(const Parsed&); + ~Parsed(); + + // Returns the length of the URL (the end of the last component). + // + // Note that for some invalid, non-canonical URLs, this may not be the length + // of the string. For example "http://": the parsed structure will only + // contain an entry for the four-character scheme, and it doesn't know about + // the "://". For all other last-components, it will return the real length. + int Length() const; + + // Returns the number of characters before the given component if it exists, + // or where the component would be if it did exist. This will return the + // string length if the component would be appended to the end. + // + // Note that this can get a little funny for the port, query, and ref + // components which have a delimiter that is not counted as part of the + // component. The |include_delimiter| flag controls if you want this counted + // as part of the component or not when the component exists. + // + // This example shows the difference between the two flags for two of these + // delimited components that is present (the port and query) and one that + // isn't (the reference). The components that this flag affects are marked + // with a *. + // 0 1 2 + // 012345678901234567890 + // Example input: http://foo:80/?query + // include_delim=true, ...=false ("<-" indicates different) + // SCHEME: 0 0 + // USERNAME: 5 5 + // PASSWORD: 5 5 + // HOST: 7 7 + // *PORT: 10 11 <- + // PATH: 13 13 + // *QUERY: 14 15 <- + // *REF: 20 20 + // + int CountCharactersBefore(ComponentType type, bool include_delimiter) const; + + // Scheme without the colon: "http://foo"/ would have a scheme of "http". + // The length will be -1 if no scheme is specified ("foo.com"), or 0 if there + // is a colon but no scheme (":foo"). Note that the scheme is not guaranteed + // to start at the beginning of the string if there are preceeding whitespace + // or control characters. + Component scheme; + + // Username. Specified in URLs with an @ sign before the host. See |password| + Component username; + + // Password. The length will be -1 if unspecified, 0 if specified but empty. + // Not all URLs with a username have a password, as in "http://me@host/". + // The password is separated form the username with a colon, as in + // "http://me:secret@host/" + Component password; + + // Host name. + Component host; + + // Port number. + Component port; + + // Path, this is everything following the host name, stopping at the query of + // ref delimiter (if any). Length will be -1 if unspecified. This includes + // the preceeding slash, so the path on http://www.google.com/asdf" is + // "/asdf". As a result, it is impossible to have a 0 length path, it will + // be -1 in cases like "http://host?foo". + // Note that we treat backslashes the same as slashes. + Component path; + + // Stuff between the ? and the # after the path. This does not include the + // preceeding ? character. Length will be -1 if unspecified, 0 if there is + // a question mark but no query string. + Component query; + + // Indicated by a #, this is everything following the hash sign (not + // including it). If there are multiple hash signs, we'll use the last one. + // Length will be -1 if there is no hash sign, or 0 if there is one but + // nothing follows it. + Component ref; + + // The URL spec from the character after the scheme: until the end of the + // URL, regardless of the scheme. This is mostly useful for 'opaque' non- + // hierarchical schemes like data: and javascript: as a convient way to get + // the string with the scheme stripped off. + Component GetContent() const; + + // True if the URL's source contained a raw `<` character, and whitespace was + // removed from the URL during parsing + // + // TODO(mkwst): Link this to something in a spec if + // https://github.com/whatwg/url/pull/284 lands. + bool potentially_dangling_markup; + + // This is used for nested URL types, currently only filesystem. If you + // parse a filesystem URL, the resulting Parsed will have a nested + // inner_parsed_ to hold the parsed inner URL's component information. + // For all other url types [including the inner URL], it will be NULL. + Parsed* inner_parsed() const { return inner_parsed_; } + + void set_inner_parsed(const Parsed& inner_parsed) { + if (!inner_parsed_) + inner_parsed_ = new Parsed(inner_parsed); + else + *inner_parsed_ = inner_parsed; + } + + void clear_inner_parsed() { + if (inner_parsed_) { + delete inner_parsed_; + inner_parsed_ = nullptr; + } + } + + private: + Parsed* inner_parsed_; // This object is owned and managed by this struct. +}; + +// Initialization functions --------------------------------------------------- +// +// These functions parse the given URL, filling in all of the structure's +// components. These functions can not fail, they will always do their best +// at interpreting the input given. +// +// The string length of the URL MUST be specified, we do not check for NULLs +// at any point in the process, and will actually handle embedded NULLs. +// +// IMPORTANT: These functions do NOT hang on to the given pointer or copy it +// in any way. See the comment above the struct. +// +// The 8-bit versions require UTF-8 encoding. + +// StandardURL is for when the scheme is known to be one that has an +// authority (host) like "http". This function will not handle weird ones +// like "about:" and "javascript:", or do the right thing for "file:" URLs. +void ParseStandardURL(const char* url, int url_len, Parsed* parsed); + +// PathURL is for when the scheme is known not to have an authority (host) +// section but that aren't file URLs either. The scheme is parsed, and +// everything after the scheme is considered as the path. This is used for +// things like "about:" and "javascript:" +void ParsePathURL(const char* url, + int url_len, + bool trim_path_end, + Parsed* parsed); + +// FileURL is for file URLs. There are some special rules for interpreting +// these. +void ParseFileURL(const char* url, int url_len, Parsed* parsed); + +// Filesystem URLs are structured differently than other URLs. +void ParseFileSystemURL(const char* url, int url_len, Parsed* parsed); + +// MailtoURL is for mailto: urls. They are made up scheme,path,query +void ParseMailtoURL(const char* url, int url_len, Parsed* parsed); + +// Helper functions ----------------------------------------------------------- + +// Locates the scheme according to the URL parser's rules. This function is +// designed so the caller can find the scheme and call the correct Init* +// function according to their known scheme types. +// +// It also does not perform any validation on the scheme. +// +// This function will return true if the scheme is found and will put the +// scheme's range into *scheme. False means no scheme could be found. Note +// that a URL beginning with a colon has a scheme, but it is empty, so this +// function will return true but *scheme will = (0,0). +// +// The scheme is found by skipping spaces and control characters at the +// beginning, and taking everything from there to the first colon to be the +// scheme. The character at scheme.end() will be the colon (we may enhance +// this to handle full width colons or something, so don't count on the +// actual character value). The character at scheme.end()+1 will be the +// beginning of the rest of the URL, be it the authority or the path (or the +// end of the string). +// +// The 8-bit version requires UTF-8 encoding. +bool ExtractScheme(const char* url, int url_len, Component* scheme); + +// Returns true if ch is a character that terminates the authority segment +// of a URL. +bool IsAuthorityTerminator(char ch); + +// Does a best effort parse of input |spec|, in range |auth|. If a particular +// component is not found, it will be set to invalid. +void ParseAuthority(const char* spec, + const Component& auth, + Component* username, + Component* password, + Component* hostname, + Component* port_num); + +// Computes the integer port value from the given port component. The port +// component should have been identified by one of the init functions on +// |Parsed| for the given input url. +// +// The return value will be a positive integer between 0 and 64K, or one of +// the two special values below. +enum SpecialPort { PORT_UNSPECIFIED = -1, PORT_INVALID = -2 }; +int ParsePort(const char* url, const Component& port); + +// Extracts the range of the file name in the given url. The path must +// already have been computed by the parse function, and the matching URL +// and extracted path are provided to this function. The filename is +// defined as being everything from the last slash/backslash of the path +// to the end of the path. +// +// The file name will be empty if the path is empty or there is nothing +// following the last slash. +// +// The 8-bit version requires UTF-8 encoding. +void ExtractFileName(const char* url, + const Component& path, + Component* file_name); + +// Extract the first key/value from the range defined by |*query|. Updates +// |*query| to start at the end of the extracted key/value pair. This is +// designed for use in a loop: you can keep calling it with the same query +// object and it will iterate over all items in the query. +// +// Some key/value pairs may have the key, the value, or both be empty (for +// example, the query string "?&"). These will be returned. Note that an empty +// last parameter "foo.com?" or foo.com?a&" will not be returned, this case +// is the same as "done." +// +// The initial query component should not include the '?' (this is the default +// for parsed URLs). +// +// If no key/value are found |*key| and |*value| will be unchanged and it will +// return false. +bool ExtractQueryKeyValue(const char* url, + Component* query, + Component* key, + Component* value); + +} // namespace openscreen + +#endif // THIRD_PARTY_MOZILLA_URL_PARSE_H_ diff --git a/chromium/third_party/openscreen/src/third_party/mozilla/url_parse_internal.cc b/chromium/third_party/openscreen/src/third_party/mozilla/url_parse_internal.cc new file mode 100644 index 00000000000..136bc62d34b --- /dev/null +++ b/chromium/third_party/openscreen/src/third_party/mozilla/url_parse_internal.cc @@ -0,0 +1,87 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "third_party/mozilla//url_parse_internal.h" + +#include <ctype.h> + +#include "third_party/mozilla/url_parse.h" + +namespace openscreen { + +namespace { + +static const char* g_standard_schemes[] = { + kHttpsScheme, kHttpScheme, kFileScheme, kFtpScheme, + kWssScheme, kWsScheme, kFileSystemScheme, +}; + +} // namespace + +bool IsURLSlash(char ch) { + return ch == '/' || ch == '\\'; +} + +bool ShouldTrimFromURL(char ch) { + return ch <= ' '; +} + +void TrimURL(const char* spec, int* begin, int* len, bool trim_path_end) { + // Strip leading whitespace and control characters. + while (*begin < *len && ShouldTrimFromURL(spec[*begin])) { + (*begin)++; + } + + if (trim_path_end) { + // Strip trailing whitespace and control characters. We need the >i test + // for when the input string is all blanks; we don't want to back past the + // input. + while (*len > *begin && ShouldTrimFromURL(spec[*len - 1])) { + (*len)--; + } + } +} + +int CountConsecutiveSlashes(const char* str, int begin_offset, int str_len) { + int count = 0; + while ((begin_offset + count) < str_len && + IsURLSlash(str[begin_offset + count])) { + ++count; + } + return count; +} + +bool CompareSchemeComponent(const char* spec, + const Component& component, + const char* compare_to) { + if (!component.is_nonempty()) { + return compare_to[0] == 0; // When component is empty, match empty scheme. + } + for (int i = 0; i < component.len; ++i) { + if (tolower(spec[i]) != compare_to[i]) { + return false; + } + } + return true; +} + +bool IsStandard(const char* spec, const Component& component) { + if (!component.is_nonempty()) { + return false; + } + + constexpr int scheme_count = + sizeof(g_standard_schemes) / sizeof(g_standard_schemes[0]); + for (int i = 0; i < scheme_count; ++i) { + if (CompareSchemeComponent(spec, component, g_standard_schemes[i])) { + return true; + } + } + return false; +} + +// NOTE: Not implemented because file URLs are currently unsupported. +void ParseFileURL(const char* url, int url_len, Parsed* parsed) {} + +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/third_party/mozilla/url_parse_internal.h b/chromium/third_party/openscreen/src/third_party/mozilla/url_parse_internal.h new file mode 100644 index 00000000000..58f9f75bc74 --- /dev/null +++ b/chromium/third_party/openscreen/src/third_party/mozilla/url_parse_internal.h @@ -0,0 +1,50 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef THIRD_PARTY_MOZILLA_URL_PARSE_INTERNAL_H_ +#define THIRD_PARTY_MOZILLA_URL_PARSE_INTERNAL_H_ + +namespace openscreen { + +struct Component; + +static constexpr char kHttpsScheme[] = "https"; +static constexpr char kHttpScheme[] = "http"; +static constexpr char kFileScheme[] = "file"; +static constexpr char kFtpScheme[] = "ftp"; +static constexpr char kWssScheme[] = "wss"; +static constexpr char kWsScheme[] = "ws"; +static constexpr char kFileSystemScheme[] = "filesystem"; +static constexpr char kMailtoScheme[] = "mailto"; + +// Returns whether the character |ch| should be treated as a slash. +bool IsURLSlash(char ch); + +// Returns whether the character |ch| can be safely removed for the URL. +bool ShouldTrimFromURL(char ch); + +// Given an already-initialized begin index and length, this shrinks the range +// to eliminate "should-be-trimmed" characters. Note that the length does *not* +// indicate the length of untrimmed data from |*begin|, but rather the position +// in the input string (so the string starts at character |*begin| in the spec, +// and goes until |*len|). +void TrimURL(const char* spec, int* begin, int* len, bool trim_path_end = true); + +// Returns the number of consecutive slashes in |str| starting from offset +// |begin_offset|. +int CountConsecutiveSlashes(const char* str, int begin_offset, int str_len); + +// Given a string and a range inside the string, compares it to the given +// lower-case |compare_to| buffer. +bool CompareSchemeComponent(const char* spec, + const Component& component, + const char* compare_to); + +// Returns whether the scheme given by (spec, component) is a standard scheme +// (i.e. https://url.spec.whatwg.org/#special-scheme). +bool IsStandard(const char* spec, const Component& component); + +} // namespace openscreen + +#endif // THIRD_PARTY_MOZILLA_URL_PARSE_INTERNAL_H_ diff --git a/chromium/third_party/openscreen/src/third_party/protobuf/BUILD.gn b/chromium/third_party/openscreen/src/third_party/protobuf/BUILD.gn index d11dd5277af..787d47ce0a0 100644 --- a/chromium/third_party/openscreen/src/third_party/protobuf/BUILD.gn +++ b/chromium/third_party/openscreen/src/third_party/protobuf/BUILD.gn @@ -2,6 +2,8 @@ # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. +import("//build_overrides/build.gni") + config("protobuf_config") { include_dirs = [ "src/src" ] defines = [ diff --git a/chromium/third_party/openscreen/src/third_party/tinycbor/BUILD.gn b/chromium/third_party/openscreen/src/third_party/tinycbor/BUILD.gn index ff8480fcba8..4521580544c 100644 --- a/chromium/third_party/openscreen/src/third_party/tinycbor/BUILD.gn +++ b/chromium/third_party/openscreen/src/third_party/tinycbor/BUILD.gn @@ -4,6 +4,11 @@ import("//build_overrides/build.gni") +config("tinycbor_internal_config") { + defines = [ "WITHOUT_OPEN_MEMSTREAM" ] + cflags = [ "-w" ] # Disable all warnings. +} + source_set("tinycbor") { sources = [ "src/src/cbor.h", @@ -16,5 +21,5 @@ source_set("tinycbor") { "src/src/utf8_p.h", ] - defines = [ "WITHOUT_OPEN_MEMSTREAM" ] + configs += [ ":tinycbor_internal_config" ] } diff --git a/chromium/third_party/openscreen/src/third_party/zlib/BUILD.gn b/chromium/third_party/openscreen/src/third_party/zlib/BUILD.gn index 29e6004b084..457fc9a8948 100644 --- a/chromium/third_party/openscreen/src/third_party/zlib/BUILD.gn +++ b/chromium/third_party/openscreen/src/third_party/zlib/BUILD.gn @@ -2,7 +2,19 @@ # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -source_set("zlib") { +config("zlib_config") { + include_dirs = [ "src" ] +} + +config("zlib_internal_config") { + defines = [ "ZLIB_IMPLEMENTATION" ] + cflags = [ "-w" ] # Disable all warnings. +} + +static_library("zlib") { + # Don't stomp on "libzlib" + output_name = "chrome_zlib" + sources = [ "src/adler32.c", "src/compress.c", @@ -26,30 +38,19 @@ source_set("zlib") { "src/trees.c", "src/trees.h", "src/uncompr.c", + "src/zconf.h", "src/zlib.h", "src/zutil.c", "src/zutil.h", ] - include_dirs = [ - "src", - ".", - ] + defines = [] + deps = [] - defines = [ - "HAVE_SYS_TYPES_H", - "HAVE_STDINT_H", - "HAVE_STDDEF_H", - "_LARGEFILE64_SOURCE", - "_FILE_OFFSET_BITS=64", - ] - if (is_mac) { - # NOTE: zlib wants to use fseeko() if available, but on Mac it isn't - # available. - defines += [ "NO_FSEEKO" ] - } -} + include_dirs = [ "." ] + configs += [ ":zlib_internal_config" ] -config("zlib_config") { - include_dirs = [ "src" ] + public_configs = [ ":zlib_config" ] + + allow_circular_includes_from = deps } diff --git a/chromium/third_party/openscreen/src/tools/cl-format.sh b/chromium/third_party/openscreen/src/tools/cl-format.sh deleted file mode 100755 index 6bc159e6953..00000000000 --- a/chromium/third_party/openscreen/src/tools/cl-format.sh +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2018 The Chromium Authors. All rights reserved. -# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. - -for f in $(git diff --name-only @{u}); do - # Skip third party files, except our custom BUILD.gns - if [[ $f =~ third_party/[^\/]*/src ]]; then - continue; - fi - - # Skip statically copied Chromium QUIC build files. - if [[ $f =~ third_party/chromium_quic/build ]]; then - continue; - fi - - # Skip files deleted in this patch - if ! [[ -f $f ]]; then - continue; - fi - - # Format cpp files - if [[ $f =~ \.(cc|h)$ ]]; then - clang-format -style=file -i "$f" - fi - - # Format gn files - if [[ $f =~ \.gn$ ]]; then - gn format $f - fi - -done diff --git a/chromium/third_party/openscreen/src/tools/clang/scripts/update.py b/chromium/third_party/openscreen/src/tools/clang/scripts/update.py deleted file mode 100755 index 3774d78deac..00000000000 --- a/chromium/third_party/openscreen/src/tools/clang/scripts/update.py +++ /dev/null @@ -1,1025 +0,0 @@ -#!/usr/bin/env python -# Copyright (c) 2012 The Chromium Authors. All rights reserved. -# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. - -# NOTE: this file is taken from the chromium repo, viewable here: -# https://cs.chromium.org/codesearch/f/chromium/src/tools/clang/scripts/update.py?cl=e485df98abd46b3f774377624ae7850e795fd4d8 - -# Some minor alterations have been made to suit our use case here. - -"""This script is used to download prebuilt clang binaries. - -It is also used by package.py to build the prebuilt clang binaries.""" - -import argparse -import distutils.spawn -import glob -import os -import pipes -import re -import shutil -import subprocess -import stat -import sys -import tarfile -import tempfile -import time -import urllib2 -import zipfile - - -# Do NOT CHANGE this if you don't know what you're doing -- see -# https://chromium.googlesource.com/chromium/src/+/master/docs/updating_clang.md -# Reverting problematic clang rolls is safe, though. -CLANG_REVISION = '352138' - -use_head_revision = bool(os.environ.get('LLVM_FORCE_HEAD_REVISION', '0') - in ('1', 'YES')) -if use_head_revision: - CLANG_REVISION = 'HEAD' - -# This is incremented when pushing a new build of Clang at the same revision. -CLANG_SUB_REVISION=2 - -PACKAGE_VERSION = "%s-%s" % (CLANG_REVISION, CLANG_SUB_REVISION) - -# Path constants. (All of these should be absolute paths.) -THIS_DIR = os.path.abspath(os.path.dirname(__file__)) -CHROMIUM_DIR = os.path.abspath(os.path.join(THIS_DIR, '..', '..', '..')) -GCLIENT_CONFIG = os.path.join(os.path.dirname(CHROMIUM_DIR), '.gclient') -THIRD_PARTY_DIR = os.path.join(CHROMIUM_DIR, 'third_party') -LLVM_DIR = os.path.join(THIRD_PARTY_DIR, 'llvm') -LLVM_BOOTSTRAP_DIR = os.path.join(THIRD_PARTY_DIR, 'llvm-bootstrap') -LLVM_BOOTSTRAP_INSTALL_DIR = os.path.join(THIRD_PARTY_DIR, - 'llvm-bootstrap-install') -CHROME_TOOLS_SHIM_DIR = os.path.join(LLVM_DIR, 'tools', 'chrometools') -LLVM_BUILD_DIR = os.path.join(CHROMIUM_DIR, 'third_party', 'llvm-build', - 'Release+Asserts') -THREADS_ENABLED_BUILD_DIR = os.path.join(LLVM_BUILD_DIR, 'threads_enabled') -COMPILER_RT_BUILD_DIR = os.path.join(LLVM_BUILD_DIR, 'compiler-rt') -CLANG_DIR = os.path.join(LLVM_DIR, 'tools', 'clang') -LLD_DIR = os.path.join(LLVM_DIR, 'tools', 'lld') -# compiler-rt is built as part of the regular LLVM build on Windows to get -# the 64-bit runtime, and out-of-tree elsewhere. -# TODO(thakis): Try to unify this. -if sys.platform == 'win32': - COMPILER_RT_DIR = os.path.join(LLVM_DIR, 'projects', 'compiler-rt') -else: - COMPILER_RT_DIR = os.path.join(LLVM_DIR, 'compiler-rt') -LIBCXX_DIR = os.path.join(LLVM_DIR, 'projects', 'libcxx') -LIBCXXABI_DIR = os.path.join(LLVM_DIR, 'projects', 'libcxxabi') -LLVM_BUILD_TOOLS_DIR = os.path.abspath( - os.path.join(LLVM_DIR, '..', 'llvm-build-tools')) -STAMP_FILE = os.path.normpath( - os.path.join(LLVM_DIR, '..', 'llvm-build', 'cr_build_revision')) -VERSION = '9.0.0' -ANDROID_NDK_DIR = os.path.join( - CHROMIUM_DIR, 'third_party', 'android_ndk') -FUCHSIA_SDK_DIR = os.path.join(CHROMIUM_DIR, 'third_party', 'fuchsia-sdk', - 'sdk') - -# URL for pre-built binaries. -CDS_URL = os.environ.get('CDS_CLANG_BUCKET_OVERRIDE', - 'https://commondatastorage.googleapis.com/chromium-browser-clang') - -LLVM_REPO_URL='https://llvm.org/svn/llvm-project' -if 'LLVM_REPO_URL' in os.environ: - LLVM_REPO_URL = os.environ['LLVM_REPO_URL'] - - - -def DownloadUrl(url, output_file): - """Download url into output_file.""" - CHUNK_SIZE = 4096 - TOTAL_DOTS = 10 - num_retries = 3 - retry_wait_s = 5 # Doubled at each retry. - - while True: - try: - sys.stdout.write('Downloading %s ' % url) - sys.stdout.write('to %s' % output_file) - sys.stdout.flush() - response = urllib2.urlopen(url) - total_size = int(response.info().getheader('Content-Length').strip()) - bytes_done = 0 - dots_printed = 0 - while True: - chunk = response.read(CHUNK_SIZE) - if not chunk: - break - output_file.write(chunk) - bytes_done += len(chunk) - num_dots = TOTAL_DOTS * bytes_done / total_size - sys.stdout.write('.' * (num_dots - dots_printed)) - sys.stdout.flush() - dots_printed = num_dots - if bytes_done != total_size: - raise urllib2.URLError("only got %d of %d bytes" % - (bytes_done, total_size)) - print ' Done.' - return - except urllib2.URLError as e: - sys.stdout.write('\n') - print e - if num_retries == 0 or isinstance(e, urllib2.HTTPError) and e.code == 404: - raise e - num_retries -= 1 - print 'Retrying in %d s ...' % retry_wait_s - time.sleep(retry_wait_s) - retry_wait_s *= 2 - - -def EnsureDirExists(path): - if not os.path.exists(path): - os.makedirs(path) - - -def DownloadAndUnpack(url, output_dir, path_prefix=None): - """Download an archive from url and extract into output_dir. If path_prefix is - not None, only extract files whose paths within the archive start with - path_prefix.""" - with tempfile.TemporaryFile() as f: - DownloadUrl(url, f) - f.seek(0) - EnsureDirExists(output_dir) - if url.endswith('.zip'): - assert path_prefix is None - zipfile.ZipFile(f).extractall(path=output_dir) - else: - t = tarfile.open(mode='r:gz', fileobj=f) - members = None - if path_prefix is not None: - members = [m for m in t.getmembers() if m.name.startswith(path_prefix)] - print("Extracting to output_dir: {}".format(output_dir)) - t.extractall(path=output_dir, members=members) - - -def ReadStampFile(path=STAMP_FILE): - """Return the contents of the stamp file, or '' if it doesn't exist.""" - try: - with open(path, 'r') as f: - return f.read().rstrip() - except IOError: - return '' - - -def WriteStampFile(s, path=STAMP_FILE): - """Write s to the stamp file.""" - EnsureDirExists(os.path.dirname(path)) - with open(path, 'w') as f: - f.write(s) - f.write('\n') - - -def GetSvnRevision(svn_repo): - """Returns current revision of the svn repo at svn_repo.""" - svn_info = subprocess.check_output('svn info ' + svn_repo, shell=True) - m = re.search(r'Revision: (\d+)', svn_info) - return m.group(1) - - -def RmTree(dir): - """Delete dir.""" - def ChmodAndRetry(func, path, _): - # Subversion can leave read-only files around. - if not os.access(path, os.W_OK): - os.chmod(path, stat.S_IWUSR) - return func(path) - raise - - shutil.rmtree(dir, onerror=ChmodAndRetry) - - -def RmCmakeCache(dir): - """Delete CMake cache related files from dir.""" - for dirpath, dirs, files in os.walk(dir): - if 'CMakeCache.txt' in files: - os.remove(os.path.join(dirpath, 'CMakeCache.txt')) - if 'CMakeFiles' in dirs: - RmTree(os.path.join(dirpath, 'CMakeFiles')) - - -def RunCommand(command, msvc_arch=None, env=None, fail_hard=True): - """Run command and return success (True) or failure; or if fail_hard is - True, exit on failure. If msvc_arch is set, runs the command in a - shell with the msvc tools for that architecture.""" - - if msvc_arch and sys.platform == 'win32': - command = [os.path.join(GetWinSDKDir(), 'bin', 'SetEnv.cmd'), - "/" + msvc_arch, '&&'] + command - - # https://docs.python.org/2/library/subprocess.html: - # "On Unix with shell=True [...] if args is a sequence, the first item - # specifies the command string, and any additional items will be treated as - # additional arguments to the shell itself. That is to say, Popen does the - # equivalent of: - # Popen(['/bin/sh', '-c', args[0], args[1], ...])" - # - # We want to pass additional arguments to command[0], not to the shell, - # so manually join everything into a single string. - # Annoyingly, for "svn co url c:\path", pipes.quote() thinks that it should - # quote c:\path but svn can't handle quoted paths on Windows. Since on - # Windows follow-on args are passed to args[0] instead of the shell, don't - # do the single-string transformation there. - if sys.platform != 'win32': - command = ' '.join([pipes.quote(c) for c in command]) - print 'Running', command - if subprocess.call(command, env=env, shell=True) == 0: - return True - print 'Failed.' - if fail_hard: - sys.exit(1) - return False - - -def CopyFile(src, dst): - """Copy a file from src to dst.""" - print "Copying %s to %s" % (src, dst) - shutil.copy(src, dst) - - -def CopyDirectoryContents(src, dst): - """Copy the files from directory src to dst.""" - dst = os.path.realpath(dst) # realpath() in case dst ends in /.. - EnsureDirExists(dst) - for f in os.listdir(src): - CopyFile(os.path.join(src, f), dst) - - -def Checkout(name, url, dir): - """Checkout the SVN module at url into dir. Use name for the log message.""" - print "Checking out %s r%s into '%s'" % (name, CLANG_REVISION, dir) - - command = ['svn', 'checkout', '--force', url + '@' + CLANG_REVISION, dir] - if RunCommand(command, fail_hard=False): - return - - if os.path.isdir(dir): - print "Removing %s." % (dir) - RmTree(dir) - - print "Retrying." - RunCommand(command) - - -def CheckoutRepos(args): - if args.skip_checkout: - return - - Checkout('LLVM', LLVM_REPO_URL + '/llvm/trunk', LLVM_DIR) - Checkout('Clang', LLVM_REPO_URL + '/cfe/trunk', CLANG_DIR) - if True: - Checkout('LLD', LLVM_REPO_URL + '/lld/trunk', LLD_DIR) - elif os.path.exists(LLD_DIR): - # In case someone sends a tryjob that temporary adds lld to the checkout, - # make sure it's not around on future builds. - RmTree(LLD_DIR) - Checkout('compiler-rt', LLVM_REPO_URL + '/compiler-rt/trunk', COMPILER_RT_DIR) - if sys.platform == 'darwin': - # clang needs a libc++ checkout, else -stdlib=libc++ won't find includes - # (i.e. this is needed for bootstrap builds). - Checkout('libcxx', LLVM_REPO_URL + '/libcxx/trunk', LIBCXX_DIR) - # We used to check out libcxxabi on OS X; we no longer need that. - if os.path.exists(LIBCXXABI_DIR): - RmTree(LIBCXXABI_DIR) - - -def DeleteChromeToolsShim(): - OLD_SHIM_DIR = os.path.join(LLVM_DIR, 'tools', 'zzz-chrometools') - shutil.rmtree(OLD_SHIM_DIR, ignore_errors=True) - shutil.rmtree(CHROME_TOOLS_SHIM_DIR, ignore_errors=True) - - -def CreateChromeToolsShim(): - """Hooks the Chrome tools into the LLVM build. - - Several Chrome tools have dependencies on LLVM/Clang libraries. The LLVM build - detects implicit tools in the tools subdirectory, so this helper install a - shim CMakeLists.txt that forwards to the real directory for the Chrome tools. - - Note that the shim directory name intentionally has no - or _. The implicit - tool detection logic munges them in a weird way.""" - assert not any(i in os.path.basename(CHROME_TOOLS_SHIM_DIR) for i in '-_') - os.mkdir(CHROME_TOOLS_SHIM_DIR) - with file(os.path.join(CHROME_TOOLS_SHIM_DIR, 'CMakeLists.txt'), 'w') as f: - f.write('# Automatically generated by tools/clang/scripts/update.py. ' + - 'Do not edit.\n') - f.write('# Since tools/clang is located in another directory, use the \n') - f.write('# two arg version to specify where build artifacts go. CMake\n') - f.write('# disallows reuse of the same binary dir for multiple source\n') - f.write('# dirs, so the build artifacts need to go into a subdirectory.\n') - f.write('if (CHROMIUM_TOOLS_SRC)\n') - f.write(' add_subdirectory(${CHROMIUM_TOOLS_SRC} ' + - '${CMAKE_CURRENT_BINARY_DIR}/a)\n') - f.write('endif (CHROMIUM_TOOLS_SRC)\n') - - -def AddSvnToPathOnWin(): - """Download svn.exe and add it to PATH.""" - if sys.platform != 'win32': - return - svn_ver = 'svn-1.6.6-win' - svn_dir = os.path.join(LLVM_BUILD_TOOLS_DIR, svn_ver) - if not os.path.exists(svn_dir): - DownloadAndUnpack(CDS_URL + '/tools/%s.zip' % svn_ver, LLVM_BUILD_TOOLS_DIR) - os.environ['PATH'] = svn_dir + os.pathsep + os.environ.get('PATH', '') - - -def AddCMakeToPath(args): - """Download CMake and add it to PATH.""" - if args.use_system_cmake: - return - - if sys.platform == 'win32': - zip_name = 'cmake-3.12.1-win32-x86.zip' - dir_name = ['cmake-3.12.1-win32-x86', 'bin'] - elif sys.platform == 'darwin': - zip_name = 'cmake-3.12.1-Darwin-x86_64.tar.gz' - dir_name = ['cmake-3.12.1-Darwin-x86_64', 'CMake.app', 'Contents', 'bin'] - else: - zip_name = 'cmake-3.12.1-Linux-x86_64.tar.gz' - dir_name = ['cmake-3.12.1-Linux-x86_64', 'bin'] - - cmake_dir = os.path.join(LLVM_BUILD_TOOLS_DIR, *dir_name) - if not os.path.exists(cmake_dir): - DownloadAndUnpack(CDS_URL + '/tools/' + zip_name, LLVM_BUILD_TOOLS_DIR) - os.environ['PATH'] = cmake_dir + os.pathsep + os.environ.get('PATH', '') - - -def AddGnuWinToPath(): - """Download some GNU win tools and add them to PATH.""" - if sys.platform != 'win32': - return - - gnuwin_dir = os.path.join(LLVM_BUILD_TOOLS_DIR, 'gnuwin') - GNUWIN_VERSION = '9' - GNUWIN_STAMP = os.path.join(gnuwin_dir, 'stamp') - if ReadStampFile(GNUWIN_STAMP) == GNUWIN_VERSION: - print 'GNU Win tools already up to date.' - else: - zip_name = 'gnuwin-%s.zip' % GNUWIN_VERSION - DownloadAndUnpack(CDS_URL + '/tools/' + zip_name, LLVM_BUILD_TOOLS_DIR) - WriteStampFile(GNUWIN_VERSION, GNUWIN_STAMP) - - os.environ['PATH'] = gnuwin_dir + os.pathsep + os.environ.get('PATH', '') - - # find.exe, mv.exe and rm.exe are from MSYS (see crrev.com/389632). MSYS uses - # Cygwin under the hood, and initializing Cygwin has a race-condition when - # getting group and user data from the Active Directory is slow. To work - # around this, use a horrible hack telling it not to do that. - # See https://crbug.com/905289 - etc = os.path.join(gnuwin_dir, '..', '..', 'etc') - EnsureDirExists(etc) - with open(os.path.join(etc, 'nsswitch.conf'), 'w') as f: - f.write('passwd: files\n') - f.write('group: files\n') - - -win_sdk_dir = None -dia_dll = None -def GetWinSDKDir(): - """Get the location of the current SDK. Sets dia_dll as a side-effect.""" - global win_sdk_dir - global dia_dll - if win_sdk_dir: - return win_sdk_dir - - # Bump after VC updates. - DIA_DLL = { - '2013': 'msdia120.dll', - '2015': 'msdia140.dll', - '2017': 'msdia140.dll', - '2019': 'msdia140.dll', - } - - # Don't let vs_toolchain overwrite our environment. - environ_bak = os.environ - - sys.path.append(os.path.join(CHROMIUM_DIR, 'build')) - import vs_toolchain - win_sdk_dir = vs_toolchain.SetEnvironmentAndGetSDKDir() - msvs_version = vs_toolchain.GetVisualStudioVersion() - - if bool(int(os.environ.get('DEPOT_TOOLS_WIN_TOOLCHAIN', '1'))): - dia_path = os.path.join(win_sdk_dir, '..', 'DIA SDK', 'bin', 'amd64') - else: - if 'GYP_MSVS_OVERRIDE_PATH' not in os.environ: - vs_path = vs_toolchain.DetectVisualStudioPath() - else: - vs_path = os.environ['GYP_MSVS_OVERRIDE_PATH'] - dia_path = os.path.join(vs_path, 'DIA SDK', 'bin', 'amd64') - - dia_dll = os.path.join(dia_path, DIA_DLL[msvs_version]) - - os.environ = environ_bak - return win_sdk_dir - - -def CopyDiaDllTo(target_dir): - # This script always wants to use the 64-bit msdia*.dll. - GetWinSDKDir() - CopyFile(dia_dll, target_dir) - - -def VeryifyVersionOfBuiltClangMatchesVERSION(): - """Checks that `clang --version` outputs VERSION. If this fails, VERSION - in this file is out-of-date and needs to be updated (possibly in an - `if use_head_revision:` block in main() first).""" - clang = os.path.join(LLVM_BUILD_DIR, 'bin', 'clang') - if sys.platform == 'win32': - clang += '.exe' - version_out = subprocess.check_output([clang, '--version']) - version_out = re.match(r'clang version ([0-9.]+)', version_out).group(1) - if version_out != VERSION: - print ('unexpected clang version %s (not %s), update VERSION in update.py' - % (version_out, VERSION)) - sys.exit(1) - - -def GetPlatformUrlPrefix(platform): - if platform == 'win32' or platform == 'cygwin': - return CDS_URL + '/Win/' - if platform == 'darwin': - return CDS_URL + '/Mac/' - assert platform.startswith('linux') - return CDS_URL + '/Linux_x64/' - - -def DownloadAndUnpackClangPackage(platform, runtimes_only=False): - cds_file = "clang-%s.tgz" % PACKAGE_VERSION - cds_full_url = GetPlatformUrlPrefix(platform) + cds_file - try: - path_prefix = None - if runtimes_only: - path_prefix = 'lib/clang/' + VERSION + '/lib/' - DownloadAndUnpack(cds_full_url, LLVM_BUILD_DIR, path_prefix) - except urllib2.URLError: - print 'Failed to download prebuilt clang %s' % cds_file - print 'Use --force-local-build if you want to build locally.' - print 'Exiting.' - sys.exit(1) - - -def UpdateClang(args): - # Read target_os from .gclient so we know which non-native runtimes we need. - # TODO(pcc): See if we can download just the runtimes instead of the entire - # clang package, and do that from DEPS instead of here. - target_os = [] - try: - env = {} - execfile(GCLIENT_CONFIG, env, env) - target_os = env.get('target_os', target_os) - except: - pass - - expected_stamp = ','.join([PACKAGE_VERSION] + target_os) - if ReadStampFile() == expected_stamp and not args.force_local_build: - return 0 - - # Reset the stamp file in case the build is unsuccessful. - WriteStampFile('') - - if not args.force_local_build: - if os.path.exists(LLVM_BUILD_DIR): - RmTree(LLVM_BUILD_DIR) - - DownloadAndUnpackClangPackage(sys.platform) - if 'win' in target_os: - DownloadAndUnpackClangPackage('win32', runtimes_only=True) - if sys.platform == 'win32': - CopyDiaDllTo(os.path.join(LLVM_BUILD_DIR, 'bin')) - WriteStampFile(expected_stamp) - return 0 - - if args.with_android and not os.path.exists(ANDROID_NDK_DIR): - print 'Android NDK not found at ' + ANDROID_NDK_DIR - print 'The Android NDK is needed to build a Clang whose -fsanitize=address' - print 'works on Android. See ' - print 'https://www.chromium.org/developers/how-tos/android-build-instructions' - print 'for how to install the NDK, or pass --without-android.' - return 1 - - if args.with_fuchsia and not os.path.exists(FUCHSIA_SDK_DIR): - print 'Fuchsia SDK not found at ' + FUCHSIA_SDK_DIR - print 'The Fuchsia SDK is needed to build libclang_rt for Fuchsia.' - print 'Install the Fuchsia SDK by adding fuchsia to the ' - print 'target_os section in your .gclient and running hooks, ' - print 'or pass --without-fuchsia.' - print 'https://chromium.googlesource.com/chromium/src/+/master/docs/fuchsia_build_instructions.md' - print 'for general Fuchsia build instructions.' - return 1 - - print 'Locally building Clang %s...' % PACKAGE_VERSION - - AddCMakeToPath(args) - AddGnuWinToPath() - - DeleteChromeToolsShim() - - CheckoutRepos(args) - - if args.skip_build: - return - - cc, cxx = None, None - libstdcpp = None - - cflags = [] - cxxflags = [] - ldflags = [] - - targets = 'AArch64;ARM;Mips;PowerPC;SystemZ;WebAssembly;X86' - base_cmake_args = ['-GNinja', - '-DCMAKE_BUILD_TYPE=Release', - '-DLLVM_ENABLE_ASSERTIONS=ON', - '-DLLVM_ENABLE_TERMINFO=OFF', - '-DLLVM_TARGETS_TO_BUILD=' + targets, - # Statically link MSVCRT to avoid DLL dependencies. - '-DLLVM_USE_CRT_RELEASE=MT', - '-DCLANG_PLUGIN_SUPPORT=OFF', - '-DCLANG_ENABLE_STATIC_ANALYZER=OFF', - '-DCLANG_ENABLE_ARCMT=OFF', - ] - - if sys.platform != 'win32': - # libxml2 is required by the Win manifest merging tool used in cross-builds. - base_cmake_args.append('-DLLVM_ENABLE_LIBXML2=FORCE_ON') - - if args.bootstrap: - print 'Building bootstrap compiler' - EnsureDirExists(LLVM_BOOTSTRAP_DIR) - os.chdir(LLVM_BOOTSTRAP_DIR) - bootstrap_args = base_cmake_args + [ - '-DLLVM_TARGETS_TO_BUILD=X86;ARM;AArch64', - '-DCMAKE_INSTALL_PREFIX=' + LLVM_BOOTSTRAP_INSTALL_DIR, - '-DCMAKE_C_FLAGS=' + ' '.join(cflags), - '-DCMAKE_CXX_FLAGS=' + ' '.join(cxxflags), - ] - if cc is not None: bootstrap_args.append('-DCMAKE_C_COMPILER=' + cc) - if cxx is not None: bootstrap_args.append('-DCMAKE_CXX_COMPILER=' + cxx) - RmCmakeCache('.') - RunCommand(['cmake'] + bootstrap_args + [LLVM_DIR], msvc_arch='x64') - RunCommand(['ninja'], msvc_arch='x64') - if args.run_tests: - if sys.platform == 'win32': - CopyDiaDllTo(os.path.join(LLVM_BOOTSTRAP_DIR, 'bin')) - RunCommand(['ninja', 'check-all'], msvc_arch='x64') - RunCommand(['ninja', 'install'], msvc_arch='x64') - - if sys.platform == 'win32': - cc = os.path.join(LLVM_BOOTSTRAP_INSTALL_DIR, 'bin', 'clang-cl.exe') - cxx = os.path.join(LLVM_BOOTSTRAP_INSTALL_DIR, 'bin', 'clang-cl.exe') - # CMake has a hard time with backslashes in compiler paths: - # https://stackoverflow.com/questions/13050827 - cc = cc.replace('\\', '/') - cxx = cxx.replace('\\', '/') - else: - cc = os.path.join(LLVM_BOOTSTRAP_INSTALL_DIR, 'bin', 'clang') - cxx = os.path.join(LLVM_BOOTSTRAP_INSTALL_DIR, 'bin', 'clang++') - - print 'Building final compiler' - - # LLVM uses C++11 starting in llvm 3.5. On Linux, this means libstdc++4.7+ is - # needed, on OS X it requires libc++. clang only automatically links to libc++ - # when targeting OS X 10.9+, so add stdlib=libc++ explicitly so clang can run - # on OS X versions as old as 10.7. - deployment_target = '' - - if sys.platform == 'darwin' and args.bootstrap: - # When building on 10.9, /usr/include usually doesn't exist, and while - # Xcode's clang automatically sets a sysroot, self-built clangs don't. - cflags = ['-isysroot', subprocess.check_output( - ['xcrun', '--show-sdk-path']).rstrip()] - cxxflags = ['-stdlib=libc++'] + cflags - ldflags += ['-stdlib=libc++'] - deployment_target = '10.7' - # Running libc++ tests takes a long time. Since it was only needed for - # the install step above, don't build it as part of the main build. - # This makes running package.py over 10% faster (30 min instead of 34 min) - RmTree(LIBCXX_DIR) - - - # If building at head, define a macro that plugins can use for #ifdefing - # out code that builds at head, but not at CLANG_REVISION or vice versa. - if use_head_revision: - cflags += ['-DLLVM_FORCE_HEAD_REVISION'] - cxxflags += ['-DLLVM_FORCE_HEAD_REVISION'] - - # Build PDBs for archival on Windows. Don't use RelWithDebInfo since it - # has different optimization defaults than Release. - # Also disable stack cookies (/GS-) for performance. - if sys.platform == 'win32': - cflags += ['/Zi', '/GS-'] - cxxflags += ['/Zi', '/GS-'] - ldflags += ['/DEBUG', '/OPT:REF', '/OPT:ICF'] - - deployment_env = None - if deployment_target: - deployment_env = os.environ.copy() - deployment_env['MACOSX_DEPLOYMENT_TARGET'] = deployment_target - - # Build lld and code coverage tools. This is done separately from the rest of - # the build because these tools require threading support. - tools_with_threading = [ 'lld', 'llvm-cov', 'llvm-profdata' ] - print 'Building the following tools with threading support: %s' % ( - str(tools_with_threading)) - - if os.path.exists(THREADS_ENABLED_BUILD_DIR): - RmTree(THREADS_ENABLED_BUILD_DIR) - EnsureDirExists(THREADS_ENABLED_BUILD_DIR) - os.chdir(THREADS_ENABLED_BUILD_DIR) - - threads_enabled_cmake_args = base_cmake_args + [ - '-DCMAKE_C_FLAGS=' + ' '.join(cflags), - '-DCMAKE_CXX_FLAGS=' + ' '.join(cxxflags), - '-DCMAKE_EXE_LINKER_FLAGS=' + ' '.join(ldflags), - '-DCMAKE_SHARED_LINKER_FLAGS=' + ' '.join(ldflags), - '-DCMAKE_MODULE_LINKER_FLAGS=' + ' '.join(ldflags)] - if cc is not None: - threads_enabled_cmake_args.append('-DCMAKE_C_COMPILER=' + cc) - if cxx is not None: - threads_enabled_cmake_args.append('-DCMAKE_CXX_COMPILER=' + cxx) - - if args.lto_lld: - # Build lld with LTO. That speeds up the linker by ~10%. - # We only use LTO for Linux now. - # - # The linker expects all archive members to have symbol tables, so the - # archiver needs to be able to create symbol tables for bitcode files. - # GNU ar and ranlib don't understand bitcode files, but llvm-ar and - # llvm-ranlib do, so use them. - ar = os.path.join(LLVM_BOOTSTRAP_INSTALL_DIR, 'bin', 'llvm-ar') - ranlib = os.path.join(LLVM_BOOTSTRAP_INSTALL_DIR, 'bin', 'llvm-ranlib') - threads_enabled_cmake_args += [ - '-DCMAKE_AR=' + ar, - '-DCMAKE_RANLIB=' + ranlib, - '-DLLVM_ENABLE_LTO=thin', - '-DLLVM_USE_LINKER=lld'] - - RmCmakeCache('.') - RunCommand(['cmake'] + threads_enabled_cmake_args + [LLVM_DIR], - msvc_arch='x64', env=deployment_env) - RunCommand(['ninja'] + tools_with_threading, msvc_arch='x64') - - # Build clang and other tools. - CreateChromeToolsShim() - - cmake_args = [] - # TODO(thakis): Unconditionally append this to base_cmake_args instead once - # compiler-rt can build with clang-cl on Windows (http://llvm.org/PR23698) - cc_args = base_cmake_args if sys.platform != 'win32' else cmake_args - if cc is not None: cc_args.append('-DCMAKE_C_COMPILER=' + cc) - if cxx is not None: cc_args.append('-DCMAKE_CXX_COMPILER=' + cxx) - default_tools = ['plugins', 'blink_gc_plugin', 'translation_unit'] - chrome_tools = list(set(default_tools + args.extra_tools)) - cmake_args += base_cmake_args + [ - '-DLLVM_ENABLE_THREADS=OFF', - '-DCMAKE_C_FLAGS=' + ' '.join(cflags), - '-DCMAKE_CXX_FLAGS=' + ' '.join(cxxflags), - '-DCMAKE_EXE_LINKER_FLAGS=' + ' '.join(ldflags), - '-DCMAKE_SHARED_LINKER_FLAGS=' + ' '.join(ldflags), - '-DCMAKE_MODULE_LINKER_FLAGS=' + ' '.join(ldflags), - '-DCMAKE_INSTALL_PREFIX=' + LLVM_BUILD_DIR, - '-DCHROMIUM_TOOLS_SRC=%s' % os.path.join(CHROMIUM_DIR, 'tools', 'clang'), - '-DCHROMIUM_TOOLS=%s' % ';'.join(chrome_tools)] - - EnsureDirExists(LLVM_BUILD_DIR) - os.chdir(LLVM_BUILD_DIR) - RmCmakeCache('.') - RunCommand(['cmake'] + cmake_args + [LLVM_DIR], - msvc_arch='x64', env=deployment_env) - RunCommand(['ninja'], msvc_arch='x64') - - # Copy in the threaded versions of lld and other tools. - if sys.platform == 'win32': - CopyFile(os.path.join(THREADS_ENABLED_BUILD_DIR, 'bin', 'lld-link.exe'), - os.path.join(LLVM_BUILD_DIR, 'bin')) - CopyFile(os.path.join(THREADS_ENABLED_BUILD_DIR, 'bin', 'lld.pdb'), - os.path.join(LLVM_BUILD_DIR, 'bin')) - else: - for tool in tools_with_threading: - CopyFile(os.path.join(THREADS_ENABLED_BUILD_DIR, 'bin', tool), - os.path.join(LLVM_BUILD_DIR, 'bin')) - - if chrome_tools: - # If any Chromium tools were built, install those now. - RunCommand(['ninja', 'cr-install'], msvc_arch='x64') - - VeryifyVersionOfBuiltClangMatchesVERSION() - - # Do an out-of-tree build of compiler-rt. - # On Windows, this is used to get the 32-bit ASan run-time. - # TODO(hans): Remove once the regular build above produces this. - # On Mac and Linux, this is used to get the regular 64-bit run-time. - # Do a clobbered build due to cmake changes. - if os.path.isdir(COMPILER_RT_BUILD_DIR): - RmTree(COMPILER_RT_BUILD_DIR) - os.makedirs(COMPILER_RT_BUILD_DIR) - os.chdir(COMPILER_RT_BUILD_DIR) - # TODO(thakis): Add this once compiler-rt can build with clang-cl (see - # above). - #if args.bootstrap and sys.platform == 'win32': - # The bootstrap compiler produces 64-bit binaries by default. - #cflags += ['-m32'] - #cxxflags += ['-m32'] - compiler_rt_args = base_cmake_args + [ - '-DLLVM_ENABLE_THREADS=OFF', - '-DCMAKE_C_FLAGS=' + ' '.join(cflags), - '-DCMAKE_CXX_FLAGS=' + ' '.join(cxxflags)] - if sys.platform == 'darwin': - compiler_rt_args += ['-DCOMPILER_RT_ENABLE_IOS=ON'] - if sys.platform != 'win32': - compiler_rt_args += ['-DLLVM_CONFIG_PATH=' + - os.path.join(LLVM_BUILD_DIR, 'bin', 'llvm-config'), - '-DSANITIZER_MIN_OSX_VERSION="10.7"'] - # compiler-rt is part of the llvm checkout on Windows but a stand-alone - # directory elsewhere, see the TODO above COMPILER_RT_DIR. - RmCmakeCache('.') - RunCommand(['cmake'] + compiler_rt_args + - [LLVM_DIR if sys.platform == 'win32' else COMPILER_RT_DIR], - msvc_arch='x86', env=deployment_env) - RunCommand(['ninja', 'compiler-rt'], msvc_arch='x86') - if sys.platform != 'win32': - RunCommand(['ninja', 'fuzzer']) - - # Copy select output to the main tree. - # TODO(hans): Make this (and the .gypi and .isolate files) version number - # independent. - if sys.platform == 'win32': - platform = 'windows' - elif sys.platform == 'darwin': - platform = 'darwin' - else: - assert sys.platform.startswith('linux') - platform = 'linux' - rt_lib_src_dir = os.path.join(COMPILER_RT_BUILD_DIR, 'lib', platform) - if sys.platform == 'win32': - # TODO(thakis): This too is due to compiler-rt being part of the checkout - # on Windows, see TODO above COMPILER_RT_DIR. - rt_lib_src_dir = os.path.join(COMPILER_RT_BUILD_DIR, 'lib', 'clang', - VERSION, 'lib', platform) - rt_lib_dst_dir = os.path.join(LLVM_BUILD_DIR, 'lib', 'clang', VERSION, 'lib', - platform) - # Blacklists: - CopyDirectoryContents(os.path.join(rt_lib_src_dir, '..', '..', 'share'), - os.path.join(rt_lib_dst_dir, '..', '..', 'share')) - # Headers: - if sys.platform != 'win32': - CopyDirectoryContents( - os.path.join(COMPILER_RT_BUILD_DIR, 'include/sanitizer'), - os.path.join(LLVM_BUILD_DIR, 'lib/clang', VERSION, 'include/sanitizer')) - # Static and dynamic libraries: - CopyDirectoryContents(rt_lib_src_dir, rt_lib_dst_dir) - if sys.platform == 'darwin': - for dylib in glob.glob(os.path.join(rt_lib_dst_dir, '*.dylib')): - # Fix LC_ID_DYLIB for the ASan dynamic libraries to be relative to - # @executable_path. - # TODO(glider): this is transitional. We'll need to fix the dylib - # name either in our build system, or in Clang. See also - # http://crbug.com/344836. - subprocess.call(['install_name_tool', '-id', - '@executable_path/' + os.path.basename(dylib), dylib]) - - if args.with_android: - make_toolchain = os.path.join( - ANDROID_NDK_DIR, 'build', 'tools', 'make_standalone_toolchain.py') - for target_arch in ['aarch64', 'arm', 'i686']: - # Make standalone Android toolchain for target_arch. - toolchain_dir = os.path.join( - LLVM_BUILD_DIR, 'android-toolchain-' + target_arch) - api_level = '21' if target_arch == 'aarch64' else '19' - RunCommand([ - make_toolchain, - '--api=' + api_level, - '--force', - '--install-dir=%s' % toolchain_dir, - '--stl=libc++', - '--arch=' + { - 'aarch64': 'arm64', - 'arm': 'arm', - 'i686': 'x86', - }[target_arch]]) - - # NDK r16 "helpfully" installs libc++ as libstdc++ "so the compiler will - # pick it up by default". Only these days, the compiler tries to find - # libc++ instead. See https://crbug.com/902270. - shutil.copy(os.path.join(toolchain_dir, 'sysroot/usr/lib/libstdc++.a'), - os.path.join(toolchain_dir, 'sysroot/usr/lib/libc++.a')) - shutil.copy(os.path.join(toolchain_dir, 'sysroot/usr/lib/libstdc++.so'), - os.path.join(toolchain_dir, 'sysroot/usr/lib/libc++.so')) - - # Build compiler-rt runtimes needed for Android in a separate build tree. - build_dir = os.path.join(LLVM_BUILD_DIR, 'android-' + target_arch) - if not os.path.exists(build_dir): - os.mkdir(os.path.join(build_dir)) - os.chdir(build_dir) - target_triple = target_arch - abi_libs = 'c++abi' - if target_arch == 'arm': - target_triple = 'armv7' - abi_libs += ';unwind' - target_triple += '-linux-android' + api_level - cflags = ['--target=%s' % target_triple, - '--sysroot=%s/sysroot' % toolchain_dir, - '-B%s' % toolchain_dir] - android_args = base_cmake_args + [ - '-DLLVM_ENABLE_THREADS=OFF', - '-DCMAKE_C_COMPILER=' + os.path.join(LLVM_BUILD_DIR, 'bin/clang'), - '-DCMAKE_CXX_COMPILER=' + os.path.join(LLVM_BUILD_DIR, 'bin/clang++'), - '-DLLVM_CONFIG_PATH=' + os.path.join(LLVM_BUILD_DIR, 'bin/llvm-config'), - '-DCMAKE_C_FLAGS=' + ' '.join(cflags), - '-DCMAKE_CXX_FLAGS=' + ' '.join(cflags), - '-DCMAKE_ASM_FLAGS=' + ' '.join(cflags), - '-DSANITIZER_CXX_ABI=none', - '-DSANITIZER_CXX_ABI_LIBRARY=' + abi_libs, - '-DCMAKE_SHARED_LINKER_FLAGS=-Wl,-u__cxa_demangle', - '-DANDROID=1'] - RmCmakeCache('.') - RunCommand(['cmake'] + android_args + [COMPILER_RT_DIR]) - - # We use ASan i686 build for fuzzing. - libs_want = ['lib/linux/libclang_rt.asan-{0}-android.so'] - if target_arch in ['aarch64', 'arm']: - libs_want += [ - 'lib/linux/libclang_rt.ubsan_standalone-{0}-android.so', - 'lib/linux/libclang_rt.profile-{0}-android.a', - ] - if target_arch == 'aarch64': - libs_want += ['lib/linux/libclang_rt.hwasan-{0}-android.so'] - libs_want = [lib.format(target_arch) for lib in libs_want] - RunCommand(['ninja'] + libs_want) - - # And copy them into the main build tree. - for p in libs_want: - shutil.copy(p, rt_lib_dst_dir) - - if args.with_fuchsia: - # Fuchsia links against libclang_rt.builtins-<arch>.a instead of libgcc.a. - for target_arch in ['aarch64', 'x86_64']: - fuchsia_arch_name = {'aarch64': 'arm64', 'x86_64': 'x64'}[target_arch] - toolchain_dir = os.path.join( - FUCHSIA_SDK_DIR, 'arch', fuchsia_arch_name, 'sysroot') - # Build clang_rt runtime for Fuchsia in a separate build tree. - build_dir = os.path.join(LLVM_BUILD_DIR, 'fuchsia-' + target_arch) - if not os.path.exists(build_dir): - os.mkdir(os.path.join(build_dir)) - os.chdir(build_dir) - target_spec = target_arch + '-fuchsia' - # TODO(thakis): Might have to pass -B here once sysroot contains - # binaries (e.g. gas for arm64?) - fuchsia_args = base_cmake_args + [ - '-DLLVM_ENABLE_THREADS=OFF', - '-DCMAKE_C_COMPILER=' + os.path.join(LLVM_BUILD_DIR, 'bin/clang'), - '-DCMAKE_CXX_COMPILER=' + os.path.join(LLVM_BUILD_DIR, 'bin/clang++'), - '-DCMAKE_LINKER=' + os.path.join(LLVM_BUILD_DIR, 'bin/clang'), - '-DCMAKE_AR=' + os.path.join(LLVM_BUILD_DIR, 'bin/llvm-ar'), - '-DLLVM_CONFIG_PATH=' + os.path.join(LLVM_BUILD_DIR, 'bin/llvm-config'), - '-DCMAKE_SYSTEM_NAME=Fuchsia', - '-DCMAKE_C_COMPILER_TARGET=%s-fuchsia' % target_arch, - '-DCMAKE_ASM_COMPILER_TARGET=%s-fuchsia' % target_arch, - '-DCOMPILER_RT_DEFAULT_TARGET_ONLY=ON', - '-DCMAKE_SYSROOT=%s' % toolchain_dir, - # TODO(thakis|scottmg): Use PER_TARGET_RUNTIME_DIR for all platforms. - # https://crbug.com/882485. - '-DLLVM_ENABLE_PER_TARGET_RUNTIME_DIR=ON', - - # These are necessary because otherwise CMake tries to build an - # executable to test to see if the compiler is working, but in doing so, - # it links against the builtins.a that we're about to build. - '-DCMAKE_C_COMPILER_WORKS=ON', - '-DCMAKE_ASM_COMPILER_WORKS=ON', - ] - RmCmakeCache('.') - RunCommand(['cmake'] + - fuchsia_args + - [os.path.join(COMPILER_RT_DIR, 'lib', 'builtins')]) - builtins_a = 'libclang_rt.builtins.a' - RunCommand(['ninja', builtins_a]) - - # And copy it into the main build tree. - fuchsia_lib_dst_dir = os.path.join(LLVM_BUILD_DIR, 'lib', 'clang', - VERSION, target_spec, 'lib') - if not os.path.exists(fuchsia_lib_dst_dir): - os.makedirs(fuchsia_lib_dst_dir) - CopyFile(os.path.join(build_dir, target_spec, 'lib', builtins_a), - fuchsia_lib_dst_dir) - - # Run tests. - if args.run_tests or use_head_revision: - os.chdir(LLVM_BUILD_DIR) - RunCommand(['ninja', 'cr-check-all'], msvc_arch='x64') - if args.run_tests: - if sys.platform == 'win32': - CopyDiaDllTo(os.path.join(LLVM_BUILD_DIR, 'bin')) - os.chdir(LLVM_BUILD_DIR) - RunCommand(['ninja', 'check-all'], msvc_arch='x64') - - WriteStampFile(PACKAGE_VERSION) - print 'Clang update was successful.' - return 0 - - -def gn_arg(v): - if v == 'True': - return True - if v == 'False': - return False - raise argparse.ArgumentTypeError('Expected one of %r or %r' % ( - 'True', 'False')) - - -def main(): - parser = argparse.ArgumentParser(description='Build Clang.') - parser.add_argument('--bootstrap', action='store_true', - help='first build clang with CC, then with itself.') - parser.add_argument('--force-local-build', action='store_true', - help="don't try to download prebuild binaries") - parser.add_argument('--gcc-toolchain', help='set the version for which gcc ' - 'version be used for building; --gcc-toolchain=/opt/foo ' - 'picks /opt/foo/bin/gcc') - parser.add_argument('--lto-lld', action='store_true', - help='build lld with LTO') - parser.add_argument('--llvm-force-head-revision', action='store_true', - help=('use the revision in the repo when printing ' - 'the revision')) - parser.add_argument('--print-revision', action='store_true', - help='print current clang revision and exit.') - parser.add_argument('--print-clang-version', action='store_true', - help='print current clang version (e.g. x.y.z) and exit.') - parser.add_argument('--run-tests', action='store_true', - help='run tests after building; only for local builds') - parser.add_argument('--skip-build', action='store_true', - help='do not build anything') - parser.add_argument('--skip-checkout', action='store_true', - help='do not create or update any checkouts') - parser.add_argument('--extra-tools', nargs='*', default=[], - help='select additional chrome tools to build') - parser.add_argument('--use-system-cmake', action='store_true', - help='use the cmake from PATH instead of downloading ' - 'and using prebuilt cmake binaries') - parser.add_argument('--verify-version', - help='verify that clang has the passed-in version') - parser.add_argument('--with-android', type=gn_arg, nargs='?', const=True, - help='build the Android ASan runtime (linux only)', - default=sys.platform.startswith('linux')) - parser.add_argument('--without-android', action='store_false', - help='don\'t build Android ASan runtime (linux only)', - dest='with_android') - parser.add_argument('--without-fuchsia', action='store_false', - help='don\'t build Fuchsia clang_rt runtime (linux/mac)', - dest='with_fuchsia', - default=sys.platform in ('linux2', 'darwin')) - args = parser.parse_args() - - if args.lto_lld and not args.bootstrap: - print '--lto-lld requires --bootstrap' - return 1 - if args.lto_lld and not sys.platform.startswith('linux'): - print '--lto-lld is only effective on Linux. Ignoring the option.' - args.lto_lld = False - - # Get svn if we're going to use it to check the revision or do a local build. - if (use_head_revision or args.llvm_force_head_revision or - args.force_local_build): - AddSvnToPathOnWin() - - if args.verify_version and args.verify_version != VERSION: - print 'VERSION is %s but --verify-version argument was %s, exiting.' % ( - VERSION, args.verify_version) - print 'clang_version in build/toolchain/toolchain.gni is likely outdated.' - return 1 - - global CLANG_REVISION, PACKAGE_VERSION - if args.print_revision: - if use_head_revision or args.llvm_force_head_revision: - print GetSvnRevision(LLVM_DIR) - else: - print PACKAGE_VERSION - return 0 - - if args.print_clang_version: - sys.stdout.write(VERSION) - return 0 - - # Don't buffer stdout, so that print statements are immediately flushed. - # Do this only after --print-revision has been handled, else we'll get - # an error message when this script is run from gn for some reason. - sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', 0) - - if use_head_revision: - # Use a real revision number rather than HEAD to make sure that the stamp - # file logic works. - CLANG_REVISION = GetSvnRevision(LLVM_REPO_URL) - PACKAGE_VERSION = CLANG_REVISION + '-0' - - args.force_local_build = True - # Don't build fuchsia runtime on ToT bots at all. - args.with_fuchsia = False - - return UpdateClang(args) - - -if __name__ == '__main__': - sys.exit(main()) diff --git a/chromium/third_party/openscreen/src/tools/download-clang-update-script.py b/chromium/third_party/openscreen/src/tools/download-clang-update-script.py new file mode 100755 index 00000000000..0d707060920 --- /dev/null +++ b/chromium/third_party/openscreen/src/tools/download-clang-update-script.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python +# Copyright 2020 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +"""This script is used to download the clang update script. It runs as a +gclient hook. + +It's equivalent to using curl to download the latest update script: + + $ curl --silent --create-dirs -o tools/clang/scripts/update.py \ + https://raw.githubusercontent.com/chromium/chromium/master/tools/clang/scripts/update.py + +The purpose of "reinventing the wheel" with this script is just so developers +aren't required to have curl installed. +""" + +import argparse +import os +import sys + +try: + from urllib2 import HTTPError, URLError, urlopen +except ImportError: # For Py3 compatibility + from urllib.error import HTTPError, URLError + from urllib.request import urlopen + +SCRIPT_DOWNLOAD_URL = ('https://raw.githubusercontent.com/' + + 'chromium/chromium/master/tools/clang/scripts/update.py') + +def main(): + parser = argparse.ArgumentParser( + description='Download clang update script from chromium master.') + parser.add_argument('--output', + help='Path to script file to create/overwrite.') + args = parser.parse_args() + + if not args.output: + print('usage: download-clang-update-script.py ' + + '--output=tools/clang/scripts/update.py'); + return 1 + + script_contents = '' + try: + response = urlopen(SCRIPT_DOWNLOAD_URL) + script_contents = response.read() + except HTTPError as e: + print e.code + print e.read() + return 1 + except URLError as e: + print 'Download failed. Reason: ', e.reason + return 1 + + directory = os.path.dirname(args.output) + if not os.path.exists(directory): + os.makedirs(directory) + + script_file = open(args.output, 'w') + script_file.write(script_contents) + script_file.close() + + return 0 + +if __name__ == '__main__': + sys.exit(main()) diff --git a/chromium/third_party/openscreen/src/tools/install-build-tools.sh b/chromium/third_party/openscreen/src/tools/install-build-tools.sh deleted file mode 100755 index 65fb4643969..00000000000 --- a/chromium/third_party/openscreen/src/tools/install-build-tools.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2018 The Chromium Authors. All rights reserved. -# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. - -uname="$(uname -s)" -case "${uname}" in - Linux*) env="linux64";; - Darwin*) env="mac";; -esac - -echo "Assuming we are running in $env..." - -ninja_version="v1.9.0" -ninja_zipfile="" -case "$env" in - linux64) ninja_zipfile="ninja-linux.zip";; - mac) ninja_zipfile="ninja-mac.zip";; -esac - -GOOGLE_STORAGE_URL="https://storage.googleapis.com" -BUILDTOOLS_ROOT=$(git rev-parse --show-toplevel)/buildtools/$env -if [ ! -d $BUILDTOOLS_ROOT ]; then - mkdir -p $BUILDTOOLS_ROOT -fi - -pushd $BUILDTOOLS_ROOT -set -x # echo on -sha1=$(tail -c+1 $BUILDTOOLS_ROOT/clang-format.sha1) -curl -Lo clang-format "$GOOGLE_STORAGE_URL/chromium-clang-format/$sha1" -chmod +x clang-format -curl -L "https://github.com/ninja-build/ninja/releases/download/${ninja_version}/${ninja_zipfile}" | funzip > ninja -chmod +x ninja -set +x # echo off -popd - diff --git a/chromium/third_party/openscreen/src/util/BUILD.gn b/chromium/third_party/openscreen/src/util/BUILD.gn index 8f1958e2a4f..12cc00a03bf 100644 --- a/chromium/third_party/openscreen/src/util/BUILD.gn +++ b/chromium/third_party/openscreen/src/util/BUILD.gn @@ -12,6 +12,8 @@ source_set("util") { "big_endian.h", "crypto/certificate_utils.cc", "crypto/certificate_utils.h", + "crypto/digest_sign.cc", + "crypto/digest_sign.h", "crypto/openssl_util.cc", "crypto/openssl_util.h", "crypto/rsa_private_key.cc", @@ -20,30 +22,38 @@ source_set("util") { "crypto/secure_hash.h", "crypto/sha2.cc", "crypto/sha2.h", + "hashing.h", "integer_division.h", - "json/json_reader.cc", - "json/json_reader.h", - "json/json_writer.cc", - "json/json_writer.h", + "json/json_serialization.cc", + "json/json_serialization.h", + "json/json_value.cc", + "json/json_value.h", "logging.h", "operation_loop.cc", "operation_loop.h", "saturate_cast.h", "serial_delete_ptr.h", + "simple_fraction.cc", + "simple_fraction.h", "std_util.h", + "stringprintf.cc", "stringprintf.h", "trace_logging.h", "trace_logging/macro_support.h", "trace_logging/scoped_trace_operations.cc", "trace_logging/scoped_trace_operations.h", + "weak_ptr.h", "yet_another_bit_vector.cc", "yet_another_bit_vector.h", ] + public_deps = [ + "../third_party/jsoncpp", + ] + deps = [ "../third_party/abseil", "../third_party/boringssl", - "../third_party/jsoncpp", ] public_configs = [ "../build:openscreen_include_dirs" ] @@ -60,13 +70,16 @@ source_set("unittests") { "crypto/secure_hash_unittest.cc", "crypto/sha2_unittest.cc", "integer_division_unittest.cc", - "json/json_reader_unittest.cc", - "json/json_writer_unittest.cc", + "json/json_serialization_unittest.cc", + "json/json_value_unittest.cc", "operation_loop_unittest.cc", "saturate_cast_unittest.cc", "serial_delete_ptr_unittest.cc", + "simple_fraction_unittest.cc", + "stringprintf_unittest.cc", "trace_logging/scoped_trace_operations_unittest.cc", "trace_logging_unittest.cc", + "weak_ptr_unittest.cc", "yet_another_bit_vector_unittest.cc", ] diff --git a/chromium/third_party/openscreen/src/util/alarm.cc b/chromium/third_party/openscreen/src/util/alarm.cc index c72815a19c5..a8427285a9b 100644 --- a/chromium/third_party/openscreen/src/util/alarm.cc +++ b/chromium/third_party/openscreen/src/util/alarm.cc @@ -58,8 +58,7 @@ class Alarm::CancelableFunctor { Alarm* alarm_; }; -Alarm::Alarm(platform::ClockNowFunctionPtr now_function, - platform::TaskRunner* task_runner) +Alarm::Alarm(ClockNowFunctionPtr now_function, TaskRunner* task_runner) : now_function_(now_function), task_runner_(task_runner) { OSP_DCHECK(now_function_); OSP_DCHECK(task_runner_); @@ -73,29 +72,30 @@ Alarm::~Alarm() { } void Alarm::Cancel() { - scheduled_task_ = platform::TaskRunner::Task(); + scheduled_task_ = TaskRunner::Task(); } -void Alarm::ScheduleWithTask(platform::TaskRunner::Task task, - platform::Clock::time_point alarm_time) { +void Alarm::ScheduleWithTask(TaskRunner::Task task, + Clock::time_point desired_alarm_time) { OSP_DCHECK(task.valid()); scheduled_task_ = std::move(task); - alarm_time_ = alarm_time; + + const Clock::time_point now = now_function_(); + alarm_time_ = std::max(now, desired_alarm_time); // Ensure that a later firing will occur, and not too late. if (queued_fire_) { - if (next_fire_time_ <= alarm_time) { + if (next_fire_time_ <= alarm_time_) { return; } queued_fire_->Cancel(); OSP_DCHECK(!queued_fire_); } - InvokeLater(now_function_(), alarm_time); + InvokeLater(now, alarm_time_); } -void Alarm::InvokeLater(platform::Clock::time_point now, - platform::Clock::time_point fire_time) { +void Alarm::InvokeLater(Clock::time_point now, Clock::time_point fire_time) { OSP_DCHECK(!queued_fire_); next_fire_time_ = fire_time; // Note: Instantiating the CancelableFunctor below sets |this->queued_fire_|. @@ -109,7 +109,7 @@ void Alarm::TryInvoke() { // If this is an early firing, re-schedule for later. This happens if // Schedule() was called again before this firing had occurred. - const platform::Clock::time_point now = now_function_(); + const Clock::time_point now = now_function_(); if (now < alarm_time_) { InvokeLater(now, alarm_time_); return; @@ -119,8 +119,11 @@ void Alarm::TryInvoke() { // itself: a) calls any Alarm methods re-entrantly, or b) causes the // destruction of this Alarm instance. // WARNING: |this| is not valid after here! - platform::TaskRunner::Task task = std::move(scheduled_task_); + TaskRunner::Task task = std::move(scheduled_task_); task(); } +// static +constexpr Clock::time_point Alarm::kImmediately; + } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/util/alarm.h b/chromium/third_party/openscreen/src/util/alarm.h index dbf75714a56..551c0b76217 100644 --- a/chromium/third_party/openscreen/src/util/alarm.h +++ b/chromium/third_party/openscreen/src/util/alarm.h @@ -31,8 +31,7 @@ namespace openscreen { // running the client's Task later; or c) runs the client's Task. class Alarm { public: - Alarm(platform::ClockNowFunctionPtr now_function, - platform::TaskRunner* task_runner); + Alarm(ClockNowFunctionPtr now_function, TaskRunner* task_runner); ~Alarm(); // The design requires that Alarm instances not be copied or moved. @@ -44,20 +43,18 @@ class Alarm { // Schedule the |functor| to be invoked at |alarm_time|. If this Alarm was // already scheduled, the prior scheduling is canceled. The Functor can be any // callable target (e.g., function, lambda-expression, std::bind result, - // etc.). + // etc.). If |alarm_time| is on or before "now," such as kImmediately, it is + // scheduled to run as soon as possible. template <typename Functor> - inline void Schedule(Functor functor, - platform::Clock::time_point alarm_time) { - ScheduleWithTask(platform::TaskRunner::Task(std::move(functor)), - alarm_time); + inline void Schedule(Functor functor, Clock::time_point alarm_time) { + ScheduleWithTask(TaskRunner::Task(std::move(functor)), alarm_time); } // Same as Schedule(), but invoke the functor at the given |delay| after right // now. template <typename Functor> - inline void ScheduleFromNow(Functor functor, - platform::Clock::duration delay) { - ScheduleWithTask(platform::TaskRunner::Task(std::move(functor)), + inline void ScheduleFromNow(Functor functor, Clock::duration delay) { + ScheduleWithTask(TaskRunner::Task(std::move(functor)), now_function_() + delay); } @@ -67,8 +64,10 @@ class Alarm { // See comments for Schedule(). Generally, callers will want to call // Schedule() instead of this, for more-convenient caller-side syntax, unless // they already have a Task to pass-in. - void ScheduleWithTask(platform::TaskRunner::Task task, - platform::Clock::time_point alarm_time); + void ScheduleWithTask(TaskRunner::Task task, Clock::time_point alarm_time); + + // A special time_point value representing "as soon as possible." + static constexpr Clock::time_point kImmediately = Clock::time_point::min(); private: // A move-only functor that holds a raw pointer back to |this| and can be @@ -77,20 +76,19 @@ class Alarm { class CancelableFunctor; // Posts a delayed call to TryInvoke() to the TaskRunner. - void InvokeLater(platform::Clock::time_point now, - platform::Clock::time_point fire_time); + void InvokeLater(Clock::time_point now, Clock::time_point fire_time); // Examines whether to invoke the client's Task now; or try again later; or // just do nothing. See class-level design comments. void TryInvoke(); - const platform::ClockNowFunctionPtr now_function_; - platform::TaskRunner* const task_runner_; + const ClockNowFunctionPtr now_function_; + TaskRunner* const task_runner_; // This is the task the client wants to have run at a specific point-in-time. // This is NOT the task that Alarm provides to the TaskRunner. - platform::TaskRunner::Task scheduled_task_; - platform::Clock::time_point alarm_time_{}; + TaskRunner::Task scheduled_task_; + Clock::time_point alarm_time_{}; // When non-null, there is a task in the TaskRunner's queue that will call // TryInvoke() some time in the future. This member is exclusively maintained @@ -99,7 +97,7 @@ class Alarm { // When the CancelableFunctor is scheduled to run. It may possibly execute // later than this, if the TaskRunner is falling behind. - platform::Clock::time_point next_fire_time_{}; + Clock::time_point next_fire_time_{}; }; } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/util/alarm_unittest.cc b/chromium/third_party/openscreen/src/util/alarm_unittest.cc index 71f570a058a..5fad74bf7f4 100644 --- a/chromium/third_party/openscreen/src/util/alarm_unittest.cc +++ b/chromium/third_party/openscreen/src/util/alarm_unittest.cc @@ -15,30 +15,28 @@ namespace { class AlarmTest : public testing::Test { public: - platform::FakeClock* clock() { return &clock_; } - platform::TaskRunner* task_runner() { return &task_runner_; } + FakeClock* clock() { return &clock_; } + FakeTaskRunner* task_runner() { return &task_runner_; } Alarm* alarm() { return &alarm_; } private: - platform::FakeClock clock_{platform::Clock::now()}; - platform::FakeTaskRunner task_runner_{&clock_}; - Alarm alarm_{&platform::FakeClock::now, &task_runner_}; + FakeClock clock_{Clock::now()}; + FakeTaskRunner task_runner_{&clock_}; + Alarm alarm_{&FakeClock::now, &task_runner_}; }; TEST_F(AlarmTest, RunsTaskAsClockAdvances) { - constexpr platform::Clock::duration kDelay = std::chrono::milliseconds(20); + constexpr Clock::duration kDelay = std::chrono::milliseconds(20); - const platform::Clock::time_point alarm_time = - platform::FakeClock::now() + kDelay; - platform::Clock::time_point actual_run_time{}; - alarm()->Schedule([&]() { actual_run_time = platform::FakeClock::now(); }, - alarm_time); + const Clock::time_point alarm_time = FakeClock::now() + kDelay; + Clock::time_point actual_run_time{}; + alarm()->Schedule([&]() { actual_run_time = FakeClock::now(); }, alarm_time); // Confirm the lambda did not run immediately. - ASSERT_EQ(platform::Clock::time_point{}, actual_run_time); + ASSERT_EQ(Clock::time_point{}, actual_run_time); // Confirm the lambda does not run until the necessary delay has elapsed. clock()->Advance(kDelay / 2); - ASSERT_EQ(platform::Clock::time_point{}, actual_run_time); + ASSERT_EQ(Clock::time_point{}, actual_run_time); // Confirm the lambda is called when the necessary delay has elapsed. clock()->Advance(kDelay / 2); @@ -49,17 +47,34 @@ TEST_F(AlarmTest, RunsTaskAsClockAdvances) { ASSERT_EQ(alarm_time, actual_run_time); } +TEST_F(AlarmTest, RunsTaskImmediately) { + const Clock::time_point expected_run_time = FakeClock::now(); + Clock::time_point actual_run_time{}; + alarm()->Schedule([&]() { actual_run_time = FakeClock::now(); }, + Alarm::kImmediately); + // Confirm the lambda did not run yet, since it should run asynchronously, in + // a separate TaskRunner task. + ASSERT_EQ(Clock::time_point{}, actual_run_time); + + // Confirm the lambda runs without the clock having to tick forward. + task_runner()->RunTasksUntilIdle(); + ASSERT_EQ(expected_run_time, actual_run_time); + + // Confirm the lambda is only run once. + clock()->Advance(std::chrono::seconds(2)); + ASSERT_EQ(expected_run_time, actual_run_time); +} + TEST_F(AlarmTest, CancelsTaskWhenGoingOutOfScope) { - constexpr platform::Clock::duration kDelay = std::chrono::milliseconds(20); - constexpr platform::Clock::time_point kNever{}; + constexpr Clock::duration kDelay = std::chrono::milliseconds(20); + constexpr Clock::time_point kNever{}; - platform::Clock::time_point actual_run_time{}; + Clock::time_point actual_run_time{}; { - Alarm scoped_alarm(&platform::FakeClock::now, task_runner()); - const platform::Clock::time_point alarm_time = - platform::FakeClock::now() + kDelay; - scoped_alarm.Schedule( - [&]() { actual_run_time = platform::FakeClock::now(); }, alarm_time); + Alarm scoped_alarm(&FakeClock::now, task_runner()); + const Clock::time_point alarm_time = FakeClock::now() + kDelay; + scoped_alarm.Schedule([&]() { actual_run_time = FakeClock::now(); }, + alarm_time); // |scoped_alarm| is destroyed. } @@ -70,31 +85,27 @@ TEST_F(AlarmTest, CancelsTaskWhenGoingOutOfScope) { } TEST_F(AlarmTest, Cancels) { - constexpr platform::Clock::duration kDelay = std::chrono::milliseconds(20); + constexpr Clock::duration kDelay = std::chrono::milliseconds(20); - const platform::Clock::time_point alarm_time = - platform::FakeClock::now() + kDelay; - platform::Clock::time_point actual_run_time{}; - alarm()->Schedule([&]() { actual_run_time = platform::FakeClock::now(); }, - alarm_time); + const Clock::time_point alarm_time = FakeClock::now() + kDelay; + Clock::time_point actual_run_time{}; + alarm()->Schedule([&]() { actual_run_time = FakeClock::now(); }, alarm_time); // Advance the clock for half the delay, and confirm the lambda has not run // yet. clock()->Advance(kDelay / 2); - ASSERT_EQ(platform::Clock::time_point{}, actual_run_time); + ASSERT_EQ(Clock::time_point{}, actual_run_time); // Cancel and then advance the clock well past the delay, and confirm the // lambda has never run. alarm()->Cancel(); clock()->Advance(kDelay * 100); - ASSERT_EQ(platform::Clock::time_point{}, actual_run_time); + ASSERT_EQ(Clock::time_point{}, actual_run_time); } TEST_F(AlarmTest, CancelsAndRearms) { - constexpr platform::Clock::duration kShorterDelay = - std::chrono::milliseconds(10); - constexpr platform::Clock::duration kLongerDelay = - std::chrono::milliseconds(100); + constexpr Clock::duration kShorterDelay = std::chrono::milliseconds(10); + constexpr Clock::duration kLongerDelay = std::chrono::milliseconds(100); // Run the test twice: Once when scheduling first with a long delay, then a // shorter delay; and once when scheduling first with a short delay, then a @@ -105,7 +116,7 @@ TEST_F(AlarmTest, CancelsAndRearms) { const auto delay2 = do_longer_then_shorter ? kShorterDelay : kLongerDelay; int count1 = 0; - alarm()->Schedule([&]() { ++count1; }, platform::FakeClock::now() + delay1); + alarm()->Schedule([&]() { ++count1; }, FakeClock::now() + delay1); // Advance the clock for half of |delay1|, and confirm the lambda that // increments the variable does not run. @@ -116,7 +127,7 @@ TEST_F(AlarmTest, CancelsAndRearms) { // Schedule a different lambda, that increments a different variable, to run // after |delay2|. int count2 = 0; - alarm()->Schedule([&]() { ++count2; }, platform::FakeClock::now() + delay2); + alarm()->Schedule([&]() { ++count2; }, FakeClock::now() + delay2); // Confirm the second scheduling will fire at the right moment. clock()->Advance(delay2 / 2); diff --git a/chromium/third_party/openscreen/src/util/big_endian.h b/chromium/third_party/openscreen/src/util/big_endian.h index 6c94ca5e6a6..b2067d7537e 100644 --- a/chromium/third_party/openscreen/src/util/big_endian.h +++ b/chromium/third_party/openscreen/src/util/big_endian.h @@ -26,43 +26,72 @@ inline bool IsBigEndianArchitecture() { return !!bytes[0]; } -// Returns the bytes of |x| in reverse order. This is only defined for 16-, 32-, -// and 64-bit unsigned integers. -template <typename Integer> -Integer ByteSwap(Integer x); +namespace internal { + +template <int size> +struct MakeSizedUnsignedInteger; + +template <> +struct MakeSizedUnsignedInteger<1> { + using type = uint8_t; +}; + +template <> +struct MakeSizedUnsignedInteger<2> { + using type = uint16_t; +}; + +template <> +struct MakeSizedUnsignedInteger<4> { + using type = uint32_t; +}; + +template <> +struct MakeSizedUnsignedInteger<8> { + using type = uint64_t; +}; + +template <int size> +inline typename MakeSizedUnsignedInteger<size>::type ByteSwap( + typename MakeSizedUnsignedInteger<size>::type x) { + static_assert(size <= 8, + "ByteSwap() specialization missing in " __FILE__ + ". " + "Are you trying to use an integer larger than 64 bits?"); +} template <> -inline uint8_t ByteSwap(uint8_t x) { +inline uint8_t ByteSwap<1>(uint8_t x) { return x; } #if defined(__clang__) || defined(__GNUC__) template <> -inline uint64_t ByteSwap(uint64_t x) { +inline uint64_t ByteSwap<8>(uint64_t x) { return __builtin_bswap64(x); } template <> -inline uint32_t ByteSwap(uint32_t x) { +inline uint32_t ByteSwap<4>(uint32_t x) { return __builtin_bswap32(x); } template <> -inline uint16_t ByteSwap(uint16_t x) { +inline uint16_t ByteSwap<2>(uint16_t x) { return __builtin_bswap16(x); } #elif defined(_MSC_VER) template <> -inline uint64_t ByteSwap(uint64_t x) { +inline uint64_t ByteSwap<8>(uint64_t x) { return _byteswap_uint64(x); } template <> -inline uint32_t ByteSwap(uint32_t x) { +inline uint32_t ByteSwap<4>(uint32_t x) { return _byteswap_ulong(x); } template <> -inline uint16_t ByteSwap(uint16_t x) { +inline uint16_t ByteSwap<2>(uint16_t x) { return _byteswap_ushort(x); } @@ -71,20 +100,30 @@ inline uint16_t ByteSwap(uint16_t x) { #include <byteswap.h> template <> -inline uint64_t ByteSwap(uint64_t x) { +inline uint64_t ByteSwap<8>(uint64_t x) { return bswap_64(x); } template <> -inline uint32_t ByteSwap(uint32_t x) { +inline uint32_t ByteSwap<4>(uint32_t x) { return bswap_32(x); } template <> -inline uint16_t ByteSwap(uint16_t x) { +inline uint16_t ByteSwap<2>(uint16_t x) { return bswap_16(x); } #endif +} // namespace internal + +// Returns the bytes of |x| in reverse order. This is only defined for 16-, 32-, +// and 64-bit unsigned integers. +template <typename Integer> +inline std::enable_if_t<std::is_unsigned<Integer>::value, Integer> ByteSwap( + Integer x) { + return internal::ByteSwap<sizeof(Integer)>(x); +} + // Read a POD integer from |src| in big-endian byte order, returning the integer // in native byte order. template <typename Integer> diff --git a/chromium/third_party/openscreen/src/util/crypto/certificate_utils.cc b/chromium/third_party/openscreen/src/util/crypto/certificate_utils.cc index 1d6873f4485..844e60554b8 100644 --- a/chromium/third_party/openscreen/src/util/crypto/certificate_utils.cc +++ b/chromium/third_party/openscreen/src/util/crypto/certificate_utils.cc @@ -6,22 +6,31 @@ #include <openssl/asn1.h> #include <openssl/bio.h> +#include <openssl/bn.h> #include <openssl/crypto.h> #include <openssl/evp.h> #include <openssl/rsa.h> #include <openssl/ssl.h> +#include <openssl/x509v3.h> #include <time.h> -#include <atomic> #include <string> #include "util/crypto/openssl_util.h" #include "util/crypto/sha2.h" +#include "util/logging.h" namespace openscreen { namespace { +// These values are bit positions from RFC 5280 4.2.1.3 and will be passed to +// ASN1_BIT_STRING_set_bit. +enum KeyUsageBits { + kDigitalSignature = 0, + kKeyCertSign = 5, +}; + // Returns whether or not the certificate field successfully was added. bool AddCertificateField(X509_NAME* certificate_name, absl::string_view field, @@ -41,12 +50,22 @@ bssl::UniquePtr<X509> CreateCertificateInternal( absl::string_view name, std::chrono::seconds certificate_duration, EVP_PKEY key_pair, - std::chrono::seconds time_since_unix_epoch) { + std::chrono::seconds time_since_unix_epoch, + bool make_ca, + X509* issuer, + EVP_PKEY* issuer_key) { + OSP_DCHECK((!!issuer) == (!!issuer_key)); bssl::UniquePtr<X509> certificate(X509_new()); + if (!issuer) { + issuer = certificate.get(); + } + if (!issuer_key) { + issuer_key = &key_pair; + } // Serial numbers must be unique for this session. As a pretend CA, we should // not issue certificates with the same serial number in the same session. - static std::atomic_int serial_number(1); + static int serial_number(1); if (ASN1_INTEGER_set(X509_get_serialNumber(certificate.get()), serial_number++) != 1) { return nullptr; @@ -66,12 +85,36 @@ bssl::UniquePtr<X509> CreateCertificateInternal( return nullptr; } - if ((X509_set_issuer_name(certificate.get(), certificate_name) != 1) || + bssl::UniquePtr<ASN1_BIT_STRING> x(ASN1_BIT_STRING_new()); + ASN1_BIT_STRING_set_bit(x.get(), KeyUsageBits::kDigitalSignature, 1); + if (make_ca) { + ASN1_BIT_STRING_set_bit(x.get(), KeyUsageBits::kKeyCertSign, 1); + } + if (X509_add1_ext_i2d(certificate.get(), NID_key_usage, x.get(), 0, 0) != 1) { + return nullptr; + } + if (make_ca) { + X509V3_CTX ctx; + X509V3_set_ctx_nodb(&ctx); + X509V3_set_ctx(&ctx, issuer, certificate.get(), nullptr, nullptr, 0); + bssl::UniquePtr<X509_EXTENSION> ex( + X509V3_EXT_nconf_nid(nullptr, &ctx, NID_basic_constraints, + const_cast<char*>("critical,CA:TRUE"))); + if (!ex) { + return nullptr; + } + void* thing = X509V3_EXT_d2i(ex.get()); + X509_add1_ext_i2d(certificate.get(), NID_basic_constraints, thing, 1, 0); + X509V3_EXT_free(NID_basic_constraints, thing); + } + + X509_NAME* issuer_name = X509_get_subject_name(issuer); + if ((X509_set_issuer_name(certificate.get(), issuer_name) != 1) || (X509_set_pubkey(certificate.get(), &key_pair) != 1) || // Unlike all of the other BoringSSL methods here, X509_sign returns // the size of the signature in bytes. - (X509_sign(certificate.get(), &key_pair, EVP_sha256()) <= 0) || - (X509_verify(certificate.get(), &key_pair) != 1)) { + (X509_sign(certificate.get(), issuer_key, EVP_sha256()) <= 0) || + (X509_verify(certificate.get(), issuer_key) != 1)) { return nullptr; } @@ -80,20 +123,57 @@ bssl::UniquePtr<X509> CreateCertificateInternal( } // namespace -ErrorOr<bssl::UniquePtr<X509>> CreateCertificate( +bssl::UniquePtr<EVP_PKEY> GenerateRsaKeyPair(int key_bits) { + bssl::UniquePtr<BIGNUM> prime(BN_new()); + if (BN_set_word(prime.get(), RSA_F4) == 0) { + return nullptr; + } + + bssl::UniquePtr<RSA> rsa(RSA_new()); + if (RSA_generate_key_ex(rsa.get(), key_bits, prime.get(), nullptr) == 0) { + return nullptr; + } + + bssl::UniquePtr<EVP_PKEY> pkey(EVP_PKEY_new()); + if (EVP_PKEY_set1_RSA(pkey.get(), rsa.get()) == 0) { + return nullptr; + } + + return pkey; +} + +ErrorOr<bssl::UniquePtr<X509>> CreateSelfSignedX509Certificate( absl::string_view name, std::chrono::seconds duration, const EVP_PKEY& key_pair, std::chrono::seconds time_since_unix_epoch) { bssl::UniquePtr<X509> certificate = CreateCertificateInternal( - name, duration, key_pair, time_since_unix_epoch); + name, duration, key_pair, time_since_unix_epoch, false, nullptr, nullptr); if (!certificate) { return Error::Code::kCertificateCreationError; } return certificate; } -ErrorOr<std::vector<uint8_t>> ExportCertificate(const X509& certificate) { +ErrorOr<bssl::UniquePtr<X509>> CreateSelfSignedX509CertificateForTest( + absl::string_view name, + std::chrono::seconds duration, + const EVP_PKEY& key_pair, + std::chrono::seconds time_since_unix_epoch, + bool make_ca, + X509* issuer, + EVP_PKEY* issuer_key) { + bssl::UniquePtr<X509> certificate = + CreateCertificateInternal(name, duration, key_pair, time_since_unix_epoch, + make_ca, issuer, issuer_key); + if (!certificate) { + return Error::Code::kCertificateCreationError; + } + return certificate; +} + +ErrorOr<std::vector<uint8_t>> ExportX509CertificateToDer( + const X509& certificate) { unsigned char* buffer = nullptr; // Casting-away the const because the legacy i2d_X509() function is not // const-correct. @@ -121,4 +201,44 @@ ErrorOr<bssl::UniquePtr<X509>> ImportCertificate(const uint8_t* der_x509_cert, return certificate; } +ErrorOr<bssl::UniquePtr<EVP_PKEY>> ImportRSAPrivateKey( + const uint8_t* der_rsa_private_key, + int key_length) { + if (!der_rsa_private_key || key_length == 0) { + return Error::Code::kParameterInvalid; + } + + RSA* rsa = RSA_private_key_from_bytes(der_rsa_private_key, key_length); + if (!rsa) { + return Error::Code::kRSAKeyParseError; + } + bssl::UniquePtr<EVP_PKEY> pkey(EVP_PKEY_new()); + EVP_PKEY_assign_RSA(pkey.get(), rsa); + return pkey; +} + +std::string GetSpkiTlv(X509* cert) { + int len = i2d_X509_PUBKEY(cert->cert_info->key, nullptr); + if (len <= 0) { + return {}; + } + std::string x(len, 0); + uint8_t* data = reinterpret_cast<uint8_t*>(&x[0]); + if (!i2d_X509_PUBKEY(cert->cert_info->key, &data)) { + return {}; + } + return x; +} + +ErrorOr<uint64_t> ParseDerUint64(ASN1_INTEGER* asn1int) { + if (asn1int->length > 8 || asn1int->length == 0) { + return Error::Code::kParameterInvalid; + } + uint64_t result = 0; + for (int i = 0; i < asn1int->length; ++i) { + result = (result << 8) | asn1int->data[i]; + } + return result; +} + } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/util/crypto/certificate_utils.h b/chromium/third_party/openscreen/src/util/crypto/certificate_utils.h index 4bfe6e37a61..e60c28c3d26 100644 --- a/chromium/third_party/openscreen/src/util/crypto/certificate_utils.h +++ b/chromium/third_party/openscreen/src/util/crypto/certificate_utils.h @@ -5,10 +5,12 @@ #ifndef UTIL_CRYPTO_CERTIFICATE_UTILS_H_ #define UTIL_CRYPTO_CERTIFICATE_UTILS_H_ +#include <openssl/evp.h> #include <openssl/x509.h> #include <stdint.h> #include <chrono> +#include <string> #include <vector> #include "absl/strings/string_view.h" @@ -18,23 +20,50 @@ namespace openscreen { +// Generates a new RSA key pair with bit width |key_bits|. +bssl::UniquePtr<EVP_PKEY> GenerateRsaKeyPair(int key_bits = 2048); + // Creates a new self-signed X509 certificate having the given |name| and -// |duration| until expiration, and based on the given |key_pair|. +// |duration| until expiration, and based on the given |key_pair|, which is +// expected to contain a valid private key. // |time_since_unix_epoch| is the current time. -ErrorOr<bssl::UniquePtr<X509>> CreateCertificate( +ErrorOr<bssl::UniquePtr<X509>> CreateSelfSignedX509Certificate( + absl::string_view name, + std::chrono::seconds duration, + const EVP_PKEY& key_pair, + std::chrono::seconds time_since_unix_epoch = GetWallTimeSinceUnixEpoch()); + +// Creates a new X509 certificate having the given |name| and |duration| until +// expiration, and based on the given |key_pair|. If |issuer| and |issuer_key| +// are provided, they are used to set the issuer information, otherwise it will +// be self-signed. |make_ca| determines whether additional extensions are added +// to make it a valid certificate authority cert. +ErrorOr<bssl::UniquePtr<X509>> CreateSelfSignedX509CertificateForTest( absl::string_view name, std::chrono::seconds duration, const EVP_PKEY& key_pair, - std::chrono::seconds time_since_unix_epoch = - platform::GetWallTimeSinceUnixEpoch()); + std::chrono::seconds time_since_unix_epoch = GetWallTimeSinceUnixEpoch(), + bool make_ca = false, + X509* issuer = nullptr, + EVP_PKEY* issuer_key = nullptr); // Exports the given X509 certificate as its DER-encoded binary form. -ErrorOr<std::vector<uint8_t>> ExportCertificate(const X509& certificate); +ErrorOr<std::vector<uint8_t>> ExportX509CertificateToDer( + const X509& certificate); // Parses a DER-encoded X509 certificate from its binary form. ErrorOr<bssl::UniquePtr<X509>> ImportCertificate(const uint8_t* der_x509_cert, int der_x509_cert_length); +// Parses a DER-encoded RSAPrivateKey (RFC 3447). +ErrorOr<bssl::UniquePtr<EVP_PKEY>> ImportRSAPrivateKey( + const uint8_t* der_rsa_private_key, + int key_length); + +std::string GetSpkiTlv(X509* cert); + +ErrorOr<uint64_t> ParseDerUint64(ASN1_INTEGER* asn1int); + } // namespace openscreen #endif // UTIL_CRYPTO_CERTIFICATE_UTILS_H_ diff --git a/chromium/third_party/openscreen/src/util/crypto/certificate_utils_unittest.cc b/chromium/third_party/openscreen/src/util/crypto/certificate_utils_unittest.cc index 7475756bce4..91f6f9cbd24 100644 --- a/chromium/third_party/openscreen/src/util/crypto/certificate_utils_unittest.cc +++ b/chromium/third_party/openscreen/src/util/crypto/certificate_utils_unittest.cc @@ -22,24 +22,12 @@ namespace { constexpr char kName[] = "test.com"; constexpr auto kDuration = std::chrono::seconds(31556952); -bssl::UniquePtr<EVP_PKEY> GenerateRsaKeypair() { - bssl::UniquePtr<BIGNUM> prime(BN_new()); - EXPECT_NE(0, BN_set_word(prime.get(), RSA_F4)); - - bssl::UniquePtr<RSA> rsa(RSA_new()); - EXPECT_NE(0, RSA_generate_key_ex(rsa.get(), 2048, prime.get(), nullptr)); - - bssl::UniquePtr<EVP_PKEY> pkey(EVP_PKEY_new()); - EXPECT_NE(0, EVP_PKEY_set1_RSA(pkey.get(), rsa.get())); - - return pkey; -} - TEST(CertificateUtilTest, CreatesValidCertificate) { - bssl::UniquePtr<EVP_PKEY> pkey = GenerateRsaKeypair(); + bssl::UniquePtr<EVP_PKEY> pkey = GenerateRsaKeyPair(); + ASSERT_TRUE(pkey); ErrorOr<bssl::UniquePtr<X509>> certificate = - CreateCertificate(kName, kDuration, *pkey); + CreateSelfSignedX509Certificate(kName, kDuration, *pkey); ASSERT_TRUE(certificate.is_value()); // Validate the generated certificate. @@ -47,13 +35,14 @@ TEST(CertificateUtilTest, CreatesValidCertificate) { } TEST(CertificateUtilTest, ExportsAndImportsCertificate) { - bssl::UniquePtr<EVP_PKEY> pkey = GenerateRsaKeypair(); + bssl::UniquePtr<EVP_PKEY> pkey = GenerateRsaKeyPair(); + ASSERT_TRUE(pkey); ErrorOr<bssl::UniquePtr<X509>> certificate = - CreateCertificate(kName, kDuration, *pkey); + CreateSelfSignedX509Certificate(kName, kDuration, *pkey); ASSERT_TRUE(certificate.is_value()); ErrorOr<std::vector<uint8_t>> exported = - ExportCertificate(*certificate.value()); + ExportX509CertificateToDer(*certificate.value()); ASSERT_TRUE(exported.is_value()) << exported.error(); EXPECT_FALSE(exported.value().empty()); diff --git a/chromium/third_party/openscreen/src/util/crypto/digest_sign.cc b/chromium/third_party/openscreen/src/util/crypto/digest_sign.cc new file mode 100644 index 00000000000..fa1ab24dfaa --- /dev/null +++ b/chromium/third_party/openscreen/src/util/crypto/digest_sign.cc @@ -0,0 +1,31 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "util/crypto/digest_sign.h" + +namespace openscreen { + +ErrorOr<std::string> SignData(const EVP_MD* digest, + EVP_PKEY* private_key, + absl::Span<const uint8_t> data) { + bssl::ScopedEVP_MD_CTX ctx; + if (!EVP_DigestSignInit(ctx.get(), nullptr, digest, nullptr, private_key)) { + return Error::Code::kEVPInitializationError; + } + size_t signature_length = 0; + if ((EVP_DigestSign(ctx.get(), nullptr, &signature_length, data.data(), + data.size()) != 1) || + signature_length == 0) { + return Error::Code::kEVPInitializationError; + } + + std::string signature(signature_length, 0); + if (EVP_DigestSign(ctx.get(), reinterpret_cast<uint8_t*>(&signature[0]), + &signature_length, data.data(), data.size()) != 1) { + return Error::Code::kCreateSignatureFailed; + } + return signature; +} + +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/util/crypto/digest_sign.h b/chromium/third_party/openscreen/src/util/crypto/digest_sign.h new file mode 100644 index 00000000000..cd722ef7318 --- /dev/null +++ b/chromium/third_party/openscreen/src/util/crypto/digest_sign.h @@ -0,0 +1,23 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef UTIL_CRYPTO_DIGEST_SIGN_H_ +#define UTIL_CRYPTO_DIGEST_SIGN_H_ + +#include <openssl/evp.h> + +#include <string> + +#include "absl/types/span.h" +#include "platform/base/error.h" + +namespace openscreen { + +ErrorOr<std::string> SignData(const EVP_MD* digest, + EVP_PKEY* private_key, + absl::Span<const uint8_t> data); + +} // namespace openscreen + +#endif // UTIL_CRYPTO_DIGEST_SIGN_H_ diff --git a/chromium/third_party/openscreen/src/util/hashing.h b/chromium/third_party/openscreen/src/util/hashing.h new file mode 100644 index 00000000000..55c7cf4c4d7 --- /dev/null +++ b/chromium/third_party/openscreen/src/util/hashing.h @@ -0,0 +1,49 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef UTIL_HASHING_H_ +#define UTIL_HASHING_H_ + +namespace openscreen { + +// Computes the aggregate hash of the provided hashable objects. +// Seed must initially use a large prime between 2^63 and 2^64 as a starting +// value, or the result of a previous call to this function. +template <typename... T> +uint64_t ComputeAggregateHash(uint64_t seed, const T&... objs) { + auto hash_combiner = [](uint64_t seed, uint64_t hash_value) -> uint64_t { + static const uint64_t kMultiplier = UINT64_C(0x9ddfea08eb382d69); + uint64_t a = (hash_value ^ seed) * kMultiplier; + a ^= (a >> 47); + uint64_t b = (seed ^ a) * kMultiplier; + b ^= (b >> 47); + b *= kMultiplier; + return b; + }; + + uint64_t result = seed; + std::vector<uint64_t> hashes{std::hash<T>()(objs)...}; + for (uint64_t hash : hashes) { + result = hash_combiner(result, hash); + } + return result; +} + +template <typename... T> +uint64_t ComputeAggregateHash(const T&... objs) { + // This value is taken from absl::Hash implementation. + constexpr uint64_t default_seed = UINT64_C(0xc3a5c85c97cb3127); + return ComputeAggregateHash(default_seed, objs...); +} + +struct PairHash { + template <typename TFirst, typename TSecond> + size_t operator()(const std::pair<TFirst, TSecond>& pair) const { + return ComputeAggregateHash(pair.first, pair.second); + } +}; + +} // namespace openscreen + +#endif // UTIL_HASHING_H_ diff --git a/chromium/third_party/openscreen/src/util/json/json_reader.cc b/chromium/third_party/openscreen/src/util/json/json_reader.cc deleted file mode 100644 index e98f4aa60f5..00000000000 --- a/chromium/third_party/openscreen/src/util/json/json_reader.cc +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2019 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "util/json/json_reader.h" - -#include <memory> -#include <string> - -#include "json/value.h" -#include "platform/base/error.h" -#include "util/logging.h" - -namespace openscreen { - -JsonReader::JsonReader() { - Json::CharReaderBuilder::strictMode(&builder_.settings_); -} - -ErrorOr<Json::Value> JsonReader::Read(absl::string_view document) { - if (document.empty()) { - return ErrorOr<Json::Value>(Error::Code::kJsonParseError, "empty document"); - } - - Json::Value root_node; - std::string error_msg; - std::unique_ptr<Json::CharReader> reader(builder_.newCharReader()); - const bool succeeded = - reader->parse(document.begin(), document.end(), &root_node, &error_msg); - if (!succeeded) { - return ErrorOr<Json::Value>(Error::Code::kJsonParseError, error_msg); - } - - return root_node; -} -} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/util/json/json_reader.h b/chromium/third_party/openscreen/src/util/json/json_reader.h deleted file mode 100644 index cb7cded00c7..00000000000 --- a/chromium/third_party/openscreen/src/util/json/json_reader.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2019 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef UTIL_JSON_JSON_READER_H_ -#define UTIL_JSON_JSON_READER_H_ - -#include <memory> - -#include "absl/strings/string_view.h" -#include "json/reader.h" - -namespace Json { -class Value; -} - -namespace openscreen { -template <typename T> -class ErrorOr; - -class JsonReader { - public: - JsonReader(); - - ErrorOr<Json::Value> Read(absl::string_view document); - - private: - Json::CharReaderBuilder builder_; -}; - -} // namespace openscreen - -#endif // UTIL_JSON_JSON_READER_H_
\ No newline at end of file diff --git a/chromium/third_party/openscreen/src/util/json/json_writer.cc b/chromium/third_party/openscreen/src/util/json/json_serialization.cc index 65130f8e58e..42ea34f49c4 100644 --- a/chromium/third_party/openscreen/src/util/json/json_writer.cc +++ b/chromium/third_party/openscreen/src/util/json/json_serialization.cc @@ -2,36 +2,56 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "util/json/json_writer.h" +#include "util/json/json_serialization.h" #include <memory> #include <sstream> #include <string> #include <utility> -#include "json/value.h" +#include "json/reader.h" +#include "json/writer.h" #include "platform/base/error.h" #include "util/logging.h" namespace openscreen { -JsonWriter::JsonWriter() { -#ifndef _DEBUG - // Default is to "pretty print" the output JSON in a human readable - // format. On non-debug builds, we can remove pretty printing by simply - // getting rid of all indentation. - factory_["indentation"] = ""; -#endif +namespace json { + +ErrorOr<Json::Value> Parse(absl::string_view document) { + Json::CharReaderBuilder builder; + Json::CharReaderBuilder::strictMode(&builder.settings_); + if (document.empty()) { + return ErrorOr<Json::Value>(Error::Code::kJsonParseError, "empty document"); + } + + Json::Value root_node; + std::string error_msg; + std::unique_ptr<Json::CharReader> reader(builder.newCharReader()); + const bool succeeded = + reader->parse(document.begin(), document.end(), &root_node, &error_msg); + if (!succeeded) { + return ErrorOr<Json::Value>(Error::Code::kJsonParseError, error_msg); + } + + return root_node; } -ErrorOr<std::string> JsonWriter::Write(const Json::Value& value) { +ErrorOr<std::string> Stringify(const Json::Value& value) { if (value.empty()) { return ErrorOr<std::string>(Error::Code::kJsonWriteError, "Empty value"); } - std::unique_ptr<Json::StreamWriter> const writer(factory_.newStreamWriter()); - std::stringstream stream; + Json::StreamWriterBuilder factory; +#ifndef _DEBUG + // Default is to "pretty print" the output JSON in a human readable + // format. On non-debug builds, we can remove pretty printing by simply + // getting rid of all indentation. + factory["indentation"] = ""; +#endif + + std::unique_ptr<Json::StreamWriter> const writer(factory.newStreamWriter()); + std::ostringstream stream; writer->write(value, &stream); - stream << std::endl; if (!stream) { // Note: jsoncpp doesn't give us more information about what actually @@ -43,4 +63,6 @@ ErrorOr<std::string> JsonWriter::Write(const Json::Value& value) { return stream.str(); } + +} // namespace json } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/util/json/json_serialization.h b/chromium/third_party/openscreen/src/util/json/json_serialization.h new file mode 100644 index 00000000000..903aeb6bb1e --- /dev/null +++ b/chromium/third_party/openscreen/src/util/json/json_serialization.h @@ -0,0 +1,25 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef UTIL_JSON_JSON_SERIALIZATION_H_ +#define UTIL_JSON_JSON_SERIALIZATION_H_ + +#include <string> + +#include "absl/strings/string_view.h" +#include "json/value.h" + +namespace openscreen { +template <typename T> +class ErrorOr; + +namespace json { + +ErrorOr<Json::Value> Parse(absl::string_view value); +ErrorOr<std::string> Stringify(const Json::Value& value); + +} // namespace json +} // namespace openscreen + +#endif // UTIL_JSON_JSON_SERIALIZATION_H_ diff --git a/chromium/third_party/openscreen/src/util/json/json_reader_unittest.cc b/chromium/third_party/openscreen/src/util/json/json_serialization_unittest.cc index a18cf2010d5..b94fe0e38a0 100644 --- a/chromium/third_party/openscreen/src/util/json/json_reader_unittest.cc +++ b/chromium/third_party/openscreen/src/util/json/json_serialization_unittest.cc @@ -2,11 +2,10 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "util/json/json_reader.h" +#include "util/json/json_serialization.h" #include <string> -#include "gmock/gmock.h" #include "gtest/gtest.h" #include "platform/base/error.h" @@ -18,21 +17,17 @@ void AssertError(ErrorOr<Value> error_or, Error::Code code) { } } // namespace -TEST(JsonReaderTest, MalformedDocumentReturnsParseError) { - JsonReader reader; - +TEST(JsonSerializationTest, MalformedDocumentReturnsParseError) { const std::array<std::string, 4> kMalformedDocuments{ {"", "{", "{ foo: bar }", R"({"foo": "bar", "foo": baz})"}}; for (auto& document : kMalformedDocuments) { - AssertError(reader.Read(document), Error::Code::kJsonParseError); + AssertError(json::Parse(document), Error::Code::kJsonParseError); } } -TEST(JsonReaderTest, ValidEmptyDocumentParsedCorrectly) { - JsonReader reader; - - const auto actual = reader.Read("{}"); +TEST(JsonSerializationTest, ValidEmptyDocumentParsedCorrectly) { + const auto actual = json::Parse("{}"); EXPECT_TRUE(actual.is_value()); EXPECT_EQ(actual.value().getMemberNames().size(), 0u); @@ -41,13 +36,26 @@ TEST(JsonReaderTest, ValidEmptyDocumentParsedCorrectly) { // Jsoncpp has its own suite of tests ensure that things are parsed correctly, // so we only do some rudimentary checks here to make sure we didn't mangle // the value. -TEST(JsonReaderTest, ValidDocumentParsedCorrectly) { - JsonReader reader; - - const auto actual = reader.Read(R"({"foo": "bar", "baz": 1337})"); +TEST(JsonSerializationTest, ValidDocumentParsedCorrectly) { + const auto actual = json::Parse(R"({"foo": "bar", "baz": 1337})"); EXPECT_TRUE(actual.is_value()); EXPECT_EQ(actual.value().getMemberNames().size(), 2u); } +TEST(JsonSerializationTest, NullValueReturnsError) { + const auto null_value = Json::Value(); + const auto actual = json::Stringify(null_value); + + EXPECT_TRUE(actual.is_error()); + EXPECT_EQ(actual.error().code(), Error::Code::kJsonWriteError); +} + +TEST(JsonSerializationTest, ValidValueReturnsString) { + const Json::Int64 value = 31337; + const auto actual = json::Stringify(value); + + EXPECT_TRUE(actual.is_value()); + EXPECT_EQ(actual.value(), "31337"); +} } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/util/json/json_value.cc b/chromium/third_party/openscreen/src/util/json/json_value.cc new file mode 100644 index 00000000000..cfde5f84cf8 --- /dev/null +++ b/chromium/third_party/openscreen/src/util/json/json_value.cc @@ -0,0 +1,43 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "util/json/json_value.h" + +namespace openscreen { + +absl::optional<int> MaybeGetInt(const Json::Value& message, + const char* first, + const char* last) { + const Json::Value* value = message.find(first, last); + absl::optional<int> result; + if (value && value->isInt()) { + result = value->asInt(); + } + return result; +} + +absl::optional<absl::string_view> MaybeGetString(const Json::Value& message) { + if (message.isString()) { + const char* begin = nullptr; + const char* end = nullptr; + message.getString(&begin, &end); + if (begin && end >= begin) { + return absl::string_view(begin, end - begin); + } + } + return absl::nullopt; +} + +absl::optional<absl::string_view> MaybeGetString(const Json::Value& message, + const char* first, + const char* last) { + const Json::Value* value = message.find(first, last); + absl::optional<absl::string_view> result; + if (value && value->isString()) { + return MaybeGetString(*value); + } + return result; +} + +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/util/json/json_value.h b/chromium/third_party/openscreen/src/util/json/json_value.h new file mode 100644 index 00000000000..d41ea27a932 --- /dev/null +++ b/chromium/third_party/openscreen/src/util/json/json_value.h @@ -0,0 +1,28 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef UTIL_JSON_JSON_VALUE_H_ +#define UTIL_JSON_JSON_VALUE_H_ + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "json/value.h" + +#define JSON_EXPAND_FIND_CONSTANT_ARGS(s) (s), ((s) + sizeof(s) - 1) + +namespace openscreen { + +absl::optional<int> MaybeGetInt(const Json::Value& message, + const char* first, + const char* last); + +absl::optional<absl::string_view> MaybeGetString(const Json::Value& message); + +absl::optional<absl::string_view> MaybeGetString(const Json::Value& message, + const char* first, + const char* last); + +} // namespace openscreen + +#endif // UTIL_JSON_JSON_VALUE_H_ diff --git a/chromium/third_party/openscreen/src/util/json/json_value_unittest.cc b/chromium/third_party/openscreen/src/util/json/json_value_unittest.cc new file mode 100644 index 00000000000..04b57456279 --- /dev/null +++ b/chromium/third_party/openscreen/src/util/json/json_value_unittest.cc @@ -0,0 +1,55 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "util/json/json_value.h" + +#include "gtest/gtest.h" +#include "platform/base/error.h" +#include "util/json/json_serialization.h" + +namespace openscreen { + +TEST(JsonValueTest, GetInt) { + absl::string_view obj(R"!({"key1": 17, "key2": 32.3, "key3": "asdf"})!"); + ErrorOr<Json::Value> value_or_error = json::Parse(obj); + ASSERT_TRUE(value_or_error); + Json::Value& value = value_or_error.value(); + absl::optional<int> result1 = + MaybeGetInt(value, JSON_EXPAND_FIND_CONSTANT_ARGS("key1")); + absl::optional<int> result2 = + MaybeGetInt(value, JSON_EXPAND_FIND_CONSTANT_ARGS("key2")); + absl::optional<int> result3 = + MaybeGetInt(value, JSON_EXPAND_FIND_CONSTANT_ARGS("key42")); + EXPECT_FALSE(result2); + EXPECT_FALSE(result3); + + ASSERT_TRUE(result1); + EXPECT_EQ(result1.value(), 17); +} + +TEST(JsonValueTest, GetString) { + absl::string_view obj( + R"!({"key1": 17, "key2": 32.3, "key3": "asdf", "key4": ""})!"); + ErrorOr<Json::Value> value_or_error = json::Parse(obj); + ASSERT_TRUE(value_or_error); + Json::Value& value = value_or_error.value(); + absl::optional<absl::string_view> result1 = + MaybeGetString(value, JSON_EXPAND_FIND_CONSTANT_ARGS("key3")); + absl::optional<absl::string_view> result2 = + MaybeGetString(value, JSON_EXPAND_FIND_CONSTANT_ARGS("key2")); + absl::optional<absl::string_view> result3 = + MaybeGetString(value, JSON_EXPAND_FIND_CONSTANT_ARGS("key42")); + absl::optional<absl::string_view> result4 = + MaybeGetString(value, JSON_EXPAND_FIND_CONSTANT_ARGS("key4")); + + EXPECT_FALSE(result2); + EXPECT_FALSE(result3); + + ASSERT_TRUE(result1); + EXPECT_EQ(result1.value(), "asdf"); + ASSERT_TRUE(result4); + EXPECT_EQ(result4.value(), ""); +} + +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/util/json/json_writer.h b/chromium/third_party/openscreen/src/util/json/json_writer.h deleted file mode 100644 index df37d9a067c..00000000000 --- a/chromium/third_party/openscreen/src/util/json/json_writer.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2019 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef UTIL_JSON_JSON_WRITER_H_ -#define UTIL_JSON_JSON_WRITER_H_ - -#include <memory> -#include <string> - -#include "absl/strings/string_view.h" -#include "json/writer.h" - -namespace Json { -class Value; -} - -namespace openscreen { -template <typename T> -class ErrorOr; - -class JsonWriter { - public: - JsonWriter(); - - ErrorOr<std::string> Write(const Json::Value& value); - - private: - Json::StreamWriterBuilder factory_; -}; - -} // namespace openscreen - -#endif // UTIL_JSON_JSON_WRITER_H_ diff --git a/chromium/third_party/openscreen/src/util/json/json_writer_unittest.cc b/chromium/third_party/openscreen/src/util/json/json_writer_unittest.cc deleted file mode 100644 index 8b75c82e0fc..00000000000 --- a/chromium/third_party/openscreen/src/util/json/json_writer_unittest.cc +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2019 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "util/json/json_writer.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "platform/base/error.h" - -namespace openscreen { - -TEST(JsonWriterTest, NullValueReturnsError) { - JsonWriter writer; - - const auto null_value = Json::Value(); - const auto actual = writer.Write(null_value); - - EXPECT_TRUE(actual.is_error()); - EXPECT_EQ(actual.error().code(), Error::Code::kJsonWriteError); -} - -TEST(JsonWriterTest, ValidValueReturnsString) { - JsonWriter writer; - - const Json::Int64 value = 31337; - const auto actual = writer.Write(value); - - EXPECT_TRUE(actual.is_value()); - EXPECT_EQ(actual.value(), "31337\n"); -} -} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/util/logging.h b/chromium/third_party/openscreen/src/util/logging.h index 5487011189f..4c76040e684 100644 --- a/chromium/third_party/openscreen/src/util/logging.h +++ b/chromium/third_party/openscreen/src/util/logging.h @@ -6,11 +6,11 @@ #define UTIL_LOGGING_H_ #include <sstream> +#include <utility> #include "platform/api/logging.h" namespace openscreen { -namespace platform { namespace internal { // The stream-based logging macros below are adapted from Chromium's @@ -21,7 +21,7 @@ class LogMessage { : level_(level), file_(file), line_(line) {} ~LogMessage() { - LogWithLevel(level_, file_, line_, stream_.str()); + LogWithLevel(level_, file_, line_, std::move(stream_)); if (level_ == LogLevel::kFatal) { Break(); } @@ -37,7 +37,7 @@ class LogMessage { // creating a copy should be safe. const char* const file_; const int line_; - std::ostringstream stream_; + std::stringstream stream_; }; // Used by the OSP_LAZY_STREAM macro to return void after evaluating an ostream @@ -48,17 +48,15 @@ class Voidify { }; } // namespace internal -} // namespace platform } // namespace openscreen #define OSP_LAZY_STREAM(condition, stream) \ - !(condition) ? (void)0 : openscreen::platform::internal::Voidify() & (stream) -#define OSP_LOG_IS_ON(level_enum) \ - openscreen::platform::IsLoggingOn( \ - openscreen::platform::LogLevel::level_enum, __FILE__) -#define OSP_LOG_STREAM(level_enum) \ - openscreen::platform::internal::LogMessage( \ - openscreen::platform::LogLevel::level_enum, __FILE__, __LINE__) \ + !(condition) ? (void)0 : openscreen::internal::Voidify() & (stream) +#define OSP_LOG_IS_ON(level_enum) \ + openscreen::IsLoggingOn(openscreen::LogLevel::level_enum, __FILE__) +#define OSP_LOG_STREAM(level_enum) \ + openscreen::internal::LogMessage(openscreen::LogLevel::level_enum, __FILE__, \ + __LINE__) \ .stream() #define OSP_VLOG \ diff --git a/chromium/third_party/openscreen/src/util/operation_loop.h b/chromium/third_party/openscreen/src/util/operation_loop.h index 155fe0783d7..ddb7846f4be 100644 --- a/chromium/third_party/openscreen/src/util/operation_loop.h +++ b/chromium/third_party/openscreen/src/util/operation_loop.h @@ -15,8 +15,6 @@ namespace openscreen { -using Clock = platform::Clock; - class OperationLoop { public: using OperationWithTimeout = std::function<void(Clock::duration)>; diff --git a/chromium/third_party/openscreen/src/util/saturate_cast.h b/chromium/third_party/openscreen/src/util/saturate_cast.h index db40660eeda..1bb373812f9 100644 --- a/chromium/third_party/openscreen/src/util/saturate_cast.h +++ b/chromium/third_party/openscreen/src/util/saturate_cast.h @@ -5,11 +5,21 @@ #ifndef UTIL_SATURATE_CAST_H_ #define UTIL_SATURATE_CAST_H_ +#include <cmath> #include <limits> #include <type_traits> namespace openscreen { +// Case 0: When To and From are the same type, saturate_cast<> is pass-through. +template <typename To, typename From> +constexpr std::enable_if_t< + std::is_same<std::remove_cv<To>, std::remove_cv<From>>::value, + To> +saturate_cast(From from) { + return from; +} + // Because of the way C++ signed versus unsigned comparison works (i.e., the // type promotion strategy employed), extra care must be taken to range-check // the input value. For example, if the current architecture is 32-bits, then @@ -21,7 +31,7 @@ namespace openscreen { // this case, the smaller of the two types will be promoted to match the // larger's size, and a valid comparison will be made. template <typename To, typename From> -constexpr typename std::enable_if_t< +constexpr std::enable_if_t< std::is_integral<From>::value && std::is_integral<To>::value && (std::is_signed<From>::value == std::is_signed<To>::value), To> @@ -37,7 +47,7 @@ saturate_cast(From from) { // Case 2: "From" is signed, but "To" is unsigned. template <typename To, typename From> -constexpr typename std::enable_if_t< +constexpr std::enable_if_t< std::is_integral<From>::value && std::is_integral<To>::value && std::is_signed<From>::value && !std::is_signed<To>::value, To> @@ -45,7 +55,7 @@ saturate_cast(From from) { if (from <= From{0}) { return To{0}; } - if (static_cast<typename std::make_unsigned_t<From>>(from) >= + if (static_cast<std::make_unsigned_t<From>>(from) >= std::numeric_limits<To>::max()) { return std::numeric_limits<To>::max(); } @@ -54,7 +64,7 @@ saturate_cast(From from) { // Case 3: "From" is unsigned, but "To" is signed. template <typename To, typename From> -constexpr typename std::enable_if_t< +constexpr std::enable_if_t< std::is_integral<From>::value && std::is_integral<To>::value && !std::is_signed<From>::value && std::is_signed<To>::value, To> @@ -66,6 +76,70 @@ saturate_cast(From from) { return static_cast<To>(from); } +// Case 4: "From" is a floating-point type, and "To" is an integer type (signed +// or unsigned). The result is truncated, per the usual C++ float-to-int +// conversion rules. +template <typename To, typename From> +constexpr std::enable_if_t<std::is_floating_point<From>::value && + std::is_integral<To>::value, + To> +saturate_cast(From from) { + // Note: It's invalid to compare the argument against + // std::numeric_limits<To>::max() because the latter, an integer value, will + // be type-promoted to the floating-point type. The problem is that the + // conversion is imprecise, as "max int" might not be exactly representable as + // a floating-point value (depending on the actual types of From and To). + // + // Thus, the strategy is to compare only floating-point values/constants to + // determine whether the bounds of the range of integers has been exceeded. + // Two assumptions here: 1) "To" is either unsigned, or is a 2's complement + // signed integer type. 2) "From" is a floating-point type that can exactly + // represent all powers of 2 within its value range. + static_assert((~To(1) + To(1)) == To(-1), "assumed 2's complement integers"); + constexpr From kMaxIntPlusOne = + From(To(1) << (std::numeric_limits<To>::digits - 1)) * From(2); + constexpr From kMaxInt = kMaxIntPlusOne - 1; + // Note: In some cases, the kMaxInt constant will equal kMaxIntPlusOne because + // there isn't an exact floating-point representation for 2^N - 1. That said, + // the following upper-bound comparison is still valid because all + // floating-point values less than 2^N would also be less than 2^N - 1. + if (from >= kMaxInt) { + return std::numeric_limits<To>::max(); + } + if (std::is_signed<To>::value) { + constexpr From kMinInt = -kMaxIntPlusOne; + if (from <= kMinInt) { + return std::numeric_limits<To>::min(); + } + } else /* if To is unsigned */ { + if (from <= From(0)) { + return To(0); + } + } + return static_cast<To>(from); +} + +// Like saturate_cast<>, but rounds to the nearest integer instead of +// truncating. +template <typename To, typename From> +constexpr std::enable_if_t<std::is_floating_point<From>::value && + std::is_integral<To>::value, + To> +rounded_saturate_cast(From from) { + const To saturated = saturate_cast<To>(from); + if (saturated == std::numeric_limits<To>::min() || + saturated == std::numeric_limits<To>::max()) { + return saturated; + } + + static_assert(sizeof(To) <= sizeof(decltype(llround(from))), + "No version of lround() for the required range of values."); + if (sizeof(To) <= sizeof(decltype(lround(from)))) { + return static_cast<To>(lround(from)); + } + return static_cast<To>(llround(from)); +} + } // namespace openscreen #endif // UTIL_SATURATE_CAST_H_ diff --git a/chromium/third_party/openscreen/src/util/saturate_cast_unittest.cc b/chromium/third_party/openscreen/src/util/saturate_cast_unittest.cc index 026d238c126..293578a09e9 100644 --- a/chromium/third_party/openscreen/src/util/saturate_cast_unittest.cc +++ b/chromium/third_party/openscreen/src/util/saturate_cast_unittest.cc @@ -175,5 +175,188 @@ TEST(SaturateCastTest, UnsignedToSigned64BitInteger) { } } +TEST(SaturateCastTest, Float32ToSigned32) { + struct ValuePair { + float from; + int32_t to; + }; + constexpr float kFloatMax = std::numeric_limits<float>::max(); + // Note: kIntMax is one larger because float cannot represent the exact value. + constexpr float kIntMax = + static_cast<float>(std::numeric_limits<int32_t>::max()); + constexpr float kIntMin = std::numeric_limits<int32_t>::min(); + const ValuePair kValuePairs[] = { + {kFloatMax, std::numeric_limits<int32_t>::max()}, + {std::nextafter(kIntMax, kFloatMax), std::numeric_limits<int32_t>::max()}, + {kIntMax, std::numeric_limits<int32_t>::max()}, + {std::nextafter(kIntMax, 0.f), 2147483520}, + {42, 42}, + {0, 0}, + {-42, -42}, + {std::nextafter(kIntMin, 0.f), -2147483520}, + {kIntMin, std::numeric_limits<int32_t>::min()}, + {std::nextafter(kIntMin, -kFloatMax), + std::numeric_limits<int32_t>::min()}, + {-kFloatMax, std::numeric_limits<int32_t>::min()}, + }; + for (const ValuePair& value_pair : kValuePairs) { + EXPECT_EQ(value_pair.to, saturate_cast<int32_t>(value_pair.from)); + } +} + +TEST(SaturateCastTest, Float32ToSigned64) { + struct ValuePair { + float from; + int64_t to; + }; + constexpr float kFloatMax = std::numeric_limits<float>::max(); + // Note: kIntMax is one larger because float cannot represent the exact value. + constexpr float kIntMax = + static_cast<float>(std::numeric_limits<int64_t>::max()); + constexpr float kIntMin = std::numeric_limits<int64_t>::min(); + const ValuePair kValuePairs[] = { + {kFloatMax, std::numeric_limits<int64_t>::max()}, + {std::nextafter(kIntMax, kFloatMax), std::numeric_limits<int64_t>::max()}, + {kIntMax, std::numeric_limits<int64_t>::max()}, + {std::nextafter(kIntMax, 0.f), INT64_C(9223371487098961920)}, + {42, 42}, + {0, 0}, + {-42, -42}, + {std::nextafter(kIntMin, 0.f), INT64_C(-9223371487098961920)}, + {kIntMin, std::numeric_limits<int64_t>::min()}, + {std::nextafter(kIntMin, -kFloatMax), + std::numeric_limits<int64_t>::min()}, + {-kFloatMax, std::numeric_limits<int64_t>::min()}, + }; + for (const ValuePair& value_pair : kValuePairs) { + EXPECT_EQ(value_pair.to, saturate_cast<int64_t>(value_pair.from)); + } +} + +TEST(SaturateCastTest, Float64ToSigned32) { + struct ValuePair { + double from; + int32_t to; + }; + constexpr double kDoubleMax = std::numeric_limits<double>::max(); + constexpr double kIntMax = std::numeric_limits<int32_t>::max(); + constexpr double kIntMin = std::numeric_limits<int32_t>::min(); + const ValuePair kValuePairs[] = { + {kDoubleMax, std::numeric_limits<int32_t>::max()}, + {std::nextafter(kIntMax, kDoubleMax), + std::numeric_limits<int32_t>::max()}, + {kIntMax, std::numeric_limits<int32_t>::max()}, + {std::nextafter(kIntMax, 0.0), std::numeric_limits<int32_t>::max() - 1}, + {42, 42}, + {0, 0}, + {-42, -42}, + {std::nextafter(kIntMin, 0.0), std::numeric_limits<int32_t>::min() + 1}, + {kIntMin, std::numeric_limits<int32_t>::min()}, + {std::nextafter(kIntMin, -kDoubleMax), + std::numeric_limits<int32_t>::min()}, + {-kDoubleMax, std::numeric_limits<int32_t>::min()}, + }; + for (const ValuePair& value_pair : kValuePairs) { + EXPECT_EQ(value_pair.to, saturate_cast<int32_t>(value_pair.from)); + } +} + +TEST(SaturateCastTest, Float64ToSigned64) { + struct ValuePair { + double from; + int64_t to; + }; + constexpr double kDoubleMax = std::numeric_limits<double>::max(); + // Note: kIntMax is one larger because double cannot represent the exact + // value. + constexpr double kIntMax = + static_cast<double>(std::numeric_limits<int64_t>::max()); + constexpr double kIntMin = std::numeric_limits<int64_t>::min(); + const ValuePair kValuePairs[] = { + {kDoubleMax, std::numeric_limits<int64_t>::max()}, + {std::nextafter(kIntMax, kDoubleMax), + std::numeric_limits<int64_t>::max()}, + {kIntMax, std::numeric_limits<int64_t>::max()}, + {std::nextafter(kIntMax, 0.0), INT64_C(9223372036854774784)}, + {42, 42}, + {0, 0}, + {-42, -42}, + {std::nextafter(kIntMin, 0.0), INT64_C(-9223372036854774784)}, + {kIntMin, std::numeric_limits<int64_t>::min()}, + {std::nextafter(kIntMin, -kDoubleMax), + std::numeric_limits<int64_t>::min()}, + {-kDoubleMax, std::numeric_limits<int64_t>::min()}, + }; + for (const ValuePair& value_pair : kValuePairs) { + EXPECT_EQ(value_pair.to, saturate_cast<int64_t>(value_pair.from)); + } +} + +TEST(SaturateCastTest, Float32ToUnsigned64) { + struct ValuePair { + float from; + uint64_t to; + }; + constexpr float kFloatMax = std::numeric_limits<float>::max(); + // Note: kIntMax is one larger because float cannot represent the exact value. + constexpr float kIntMax = + static_cast<float>(std::numeric_limits<uint64_t>::max()); + const ValuePair kValuePairs[] = { + {kFloatMax, std::numeric_limits<uint64_t>::max()}, + {std::nextafter(kIntMax, kFloatMax), + std::numeric_limits<uint64_t>::max()}, + {kIntMax, std::numeric_limits<uint64_t>::max()}, + {std::nextafter(kIntMax, 0.f), UINT64_C(18446742974197923840)}, + {42, 42}, + {0, 0}, + {-42, 0}, + {-kFloatMax, 0}, + }; + for (const ValuePair& value_pair : kValuePairs) { + EXPECT_EQ(value_pair.to, saturate_cast<uint64_t>(value_pair.from)); + } +} + +TEST(SaturateCastTest, RoundingFloat32ToSigned64) { + struct ValuePair { + float from; + int64_t to; + }; + constexpr float kFloatMax = std::numeric_limits<float>::max(); + // Note: kIntMax is one larger because float cannot represent the exact value. + constexpr float kIntMax = + static_cast<float>(std::numeric_limits<int64_t>::max()); + constexpr float kIntMin = std::numeric_limits<int64_t>::min(); + const ValuePair kValuePairs[] = { + {kFloatMax, std::numeric_limits<int64_t>::max()}, + {std::nextafter(kIntMax, kFloatMax), std::numeric_limits<int64_t>::max()}, + {kIntMax, std::numeric_limits<int64_t>::max()}, + {std::nextafter(kIntMax, 0.f), INT64_C(9223371487098961920)}, + {41.9, 42}, + {42, 42}, + {42.6, 43}, + {42.5, 43}, + {42.4, 42}, + {0.5, 1}, + {0.1, 0}, + {0, 0}, + {-0.1, 0}, + {-0.5, -1}, + {-41.9, -42}, + {-42, -42}, + {-42.4, -42}, + {-42.5, -43}, + {-42.6, -43}, + {std::nextafter(kIntMin, 0.f), INT64_C(-9223371487098961920)}, + {kIntMin, std::numeric_limits<int64_t>::min()}, + {std::nextafter(kIntMin, -kFloatMax), + std::numeric_limits<int64_t>::min()}, + {-kFloatMax, std::numeric_limits<int64_t>::min()}, + }; + for (const ValuePair& value_pair : kValuePairs) { + EXPECT_EQ(value_pair.to, rounded_saturate_cast<int64_t>(value_pair.from)); + } +} + } // namespace } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/util/serial_delete_ptr.h b/chromium/third_party/openscreen/src/util/serial_delete_ptr.h index d33781ee361..1aff24959d5 100644 --- a/chromium/third_party/openscreen/src/util/serial_delete_ptr.h +++ b/chromium/third_party/openscreen/src/util/serial_delete_ptr.h @@ -19,24 +19,30 @@ namespace openscreen { template <typename Type, typename DeleterType> class SerialDelete { public: - explicit SerialDelete(platform::TaskRunner* task_runner) + SerialDelete() : deleter_() {} + + explicit SerialDelete(TaskRunner* task_runner) : task_runner_(task_runner), deleter_() { assert(task_runner); } template <typename DT> - SerialDelete(platform::TaskRunner* task_runner, DT&& deleter) + SerialDelete(TaskRunner* task_runner, DT&& deleter) : task_runner_(task_runner), deleter_(std::forward<DT>(deleter)) { assert(task_runner); } void operator()(Type* pointer) const { - // Deletion of the object depends on the task being run by the task runner. - task_runner_->PostTask([pointer, deleter = deleter_] { deleter(pointer); }); + if (task_runner_) { + // Deletion of the object depends on the task being run by the task + // runner. + task_runner_->PostTask( + [pointer, deleter = std::move(deleter_)] { deleter(pointer); }); + } } private: - platform::TaskRunner* task_runner_; + TaskRunner* task_runner_; DeleterType deleter_; }; @@ -46,21 +52,26 @@ template <typename Type, typename DeleterType = std::default_delete<Type>> class SerialDeletePtr : public std::unique_ptr<Type, SerialDelete<Type, DeleterType>> { public: - explicit SerialDeletePtr(platform::TaskRunner* task_runner) noexcept + SerialDeletePtr() noexcept + : std::unique_ptr<Type, SerialDelete<Type, DeleterType>>( + nullptr, + SerialDelete<Type, DeleterType>()) {} + + explicit SerialDeletePtr(TaskRunner* task_runner) noexcept : std::unique_ptr<Type, SerialDelete<Type, DeleterType>>( nullptr, SerialDelete<Type, DeleterType>(task_runner)) { assert(task_runner); } - SerialDeletePtr(platform::TaskRunner* task_runner, std::nullptr_t) noexcept + SerialDeletePtr(TaskRunner* task_runner, std::nullptr_t) noexcept : std::unique_ptr<Type, SerialDelete<Type, DeleterType>>( nullptr, SerialDelete<Type, DeleterType>(task_runner)) { assert(task_runner); } - SerialDeletePtr(platform::TaskRunner* task_runner, Type* pointer) noexcept + SerialDeletePtr(TaskRunner* task_runner, Type* pointer) noexcept : std::unique_ptr<Type, SerialDelete<Type, DeleterType>>( pointer, SerialDelete<Type, DeleterType>(task_runner)) { @@ -68,7 +79,7 @@ class SerialDeletePtr } SerialDeletePtr( - platform::TaskRunner* task_runner, + TaskRunner* task_runner, Type* pointer, typename std::conditional<std::is_reference<DeleterType>::value, DeleterType, @@ -80,7 +91,7 @@ class SerialDeletePtr } SerialDeletePtr( - platform::TaskRunner* task_runner, + TaskRunner* task_runner, Type* pointer, typename std::remove_reference<DeleterType>::type&& deleter) noexcept : std::unique_ptr<Type, SerialDelete<Type, DeleterType>>( @@ -91,7 +102,7 @@ class SerialDeletePtr }; template <typename Type, typename... Args> -SerialDeletePtr<Type> MakeSerialDelete(platform::TaskRunner* task_runner, +SerialDeletePtr<Type> MakeSerialDelete(TaskRunner* task_runner, Args&&... args) { return SerialDeletePtr<Type>(task_runner, new Type(std::forward<Args>(args)...)); diff --git a/chromium/third_party/openscreen/src/util/serial_delete_ptr_unittest.cc b/chromium/third_party/openscreen/src/util/serial_delete_ptr_unittest.cc index 4713814cd85..c8b899c5d36 100644 --- a/chromium/third_party/openscreen/src/util/serial_delete_ptr_unittest.cc +++ b/chromium/third_party/openscreen/src/util/serial_delete_ptr_unittest.cc @@ -10,10 +10,6 @@ namespace openscreen { -using openscreen::platform::Clock; -using openscreen::platform::FakeClock; -using openscreen::platform::FakeTaskRunner; - class SerialDeletePtrTest : public ::testing::Test { public: SerialDeletePtrTest() : clock_(Clock::now()), task_runner_(&clock_) {} diff --git a/chromium/third_party/openscreen/src/util/simple_fraction.cc b/chromium/third_party/openscreen/src/util/simple_fraction.cc new file mode 100644 index 00000000000..b45e1982eff --- /dev/null +++ b/chromium/third_party/openscreen/src/util/simple_fraction.cc @@ -0,0 +1,69 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "util/simple_fraction.h" + +#include <cmath> +#include <limits> +#include <vector> + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "util/logging.h" + +namespace openscreen { + +// static +ErrorOr<SimpleFraction> SimpleFraction::FromString(absl::string_view value) { + std::vector<absl::string_view> fields = absl::StrSplit(value, '/'); + if (fields.size() != 1 && fields.size() != 2) { + return Error::Code::kParameterInvalid; + } + + int numerator; + int denominator = 1; + if (!absl::SimpleAtoi(fields[0], &numerator)) { + return Error::Code::kParameterInvalid; + } + + if (fields.size() == 2) { + if (!absl::SimpleAtoi(fields[1], &denominator)) { + return Error::Code::kParameterInvalid; + } + } + + return SimpleFraction{numerator, denominator}; +} + +std::string SimpleFraction::ToString() const { + if (denominator == 1) { + return std::to_string(numerator); + } + return absl::StrCat(numerator, "/", denominator); +} + +bool SimpleFraction::operator==(const SimpleFraction& other) const { + return numerator == other.numerator && denominator == other.denominator; +} + +bool SimpleFraction::operator!=(const SimpleFraction& other) const { + return !(*this == other); +} + +bool SimpleFraction::is_defined() const { + return denominator != 0; +} + +bool SimpleFraction::is_positive() const { + return is_defined() && (numerator >= 0) && (denominator > 0); +} + +SimpleFraction::operator double() const { + if (denominator == 0) { + return nan(""); + } + return static_cast<double>(numerator) / static_cast<double>(denominator); +} + +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/util/simple_fraction.h b/chromium/third_party/openscreen/src/util/simple_fraction.h new file mode 100644 index 00000000000..f8ab50832f9 --- /dev/null +++ b/chromium/third_party/openscreen/src/util/simple_fraction.h @@ -0,0 +1,45 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef UTIL_SIMPLE_FRACTION_H_ +#define UTIL_SIMPLE_FRACTION_H_ + +#include <string> + +#include "absl/strings/string_view.h" +#include "platform/base/error.h" + +namespace openscreen { + +// SimpleFraction is used to represent simple (or "common") fractions, composed +// of a rational number written a/b where a and b are both integers. + +// Note: Since SimpleFraction is a trivial type, it comes with a +// default constructor and is copyable, as well as allowing static +// initialization. + +// Some helpful notes on SimpleFraction assumptions/limitations: +// 1. SimpleFraction does not perform reductions. 2/4 != 1/2, and -1/-1 != 1/1. +// 2. denominator = 0 is considered undefined. +// 3. numerator = saturates range to int min or int max +// 4. A SimpleFraction is "positive" if and only if it is defined and at least +// equal to zero. Since reductions are not performed, -1/-1 is negative. +struct SimpleFraction { + static ErrorOr<SimpleFraction> FromString(absl::string_view value); + std::string ToString() const; + + bool operator==(const SimpleFraction& other) const; + bool operator!=(const SimpleFraction& other) const; + + bool is_defined() const; + bool is_positive() const; + explicit operator double() const; + + int numerator = 0; + int denominator = 0; +}; + +} // namespace openscreen + +#endif // UTIL_SIMPLE_FRACTION_H_ diff --git a/chromium/third_party/openscreen/src/util/simple_fraction_unittest.cc b/chromium/third_party/openscreen/src/util/simple_fraction_unittest.cc new file mode 100644 index 00000000000..7cdbfeeccad --- /dev/null +++ b/chromium/third_party/openscreen/src/util/simple_fraction_unittest.cc @@ -0,0 +1,98 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "util/simple_fraction.h" + +#include <limits> + +#include "gtest/gtest.h" + +namespace openscreen { + +namespace { + +constexpr int kMin = std::numeric_limits<int>::min(); +constexpr int kMax = std::numeric_limits<int>::max(); + +void ExpectFromStringEquals(absl::string_view s, + const SimpleFraction& expected) { + const ErrorOr<SimpleFraction> f = SimpleFraction::FromString(s); + EXPECT_TRUE(f.is_value()); + EXPECT_EQ(expected, f.value()); +} + +void ExpectFromStringError(absl::string_view s) { + const auto f = SimpleFraction::FromString(s); + EXPECT_TRUE(f.is_error()); +} +} // namespace + +TEST(SimpleFractionTest, FromStringParsesCorrectFractions) { + ExpectFromStringEquals("1/2", SimpleFraction{1, 2}); + ExpectFromStringEquals("99/3", SimpleFraction{99, 3}); + ExpectFromStringEquals("-1/2", SimpleFraction{-1, 2}); + ExpectFromStringEquals("-13/-37", SimpleFraction{-13, -37}); + ExpectFromStringEquals("1/0", SimpleFraction{1, 0}); + ExpectFromStringEquals("1", SimpleFraction{1, 1}); + ExpectFromStringEquals("0", SimpleFraction{0, 1}); + ExpectFromStringEquals("-20", SimpleFraction{-20, 1}); + ExpectFromStringEquals("100", SimpleFraction{100, 1}); +} + +TEST(SimpleFractionTest, FromStringErrorsOnInvalid) { + ExpectFromStringError(""); + ExpectFromStringError("/"); + ExpectFromStringError("1/"); + ExpectFromStringError("/1"); + ExpectFromStringError("888/"); + ExpectFromStringError("not a fraction at all"); +} + +TEST(SimpleFractionTest, Equality) { + EXPECT_EQ((SimpleFraction{1, 2}), (SimpleFraction{1, 2})); + EXPECT_EQ((SimpleFraction{1, 0}), (SimpleFraction{1, 0})); + EXPECT_NE((SimpleFraction{1, 2}), (SimpleFraction{1, 3})); + + // We currently don't do any reduction. + EXPECT_NE((SimpleFraction{2, 4}), (SimpleFraction{1, 2})); + EXPECT_NE((SimpleFraction{9, 10}), (SimpleFraction{-9, -10})); +} + +TEST(SimpleFractionTest, Definition) { + EXPECT_TRUE((SimpleFraction{kMin, 1}).is_defined()); + EXPECT_TRUE((SimpleFraction{kMax, 1}).is_defined()); + + EXPECT_FALSE((SimpleFraction{kMin, 0}).is_defined()); + EXPECT_FALSE((SimpleFraction{kMax, 0}).is_defined()); + EXPECT_FALSE((SimpleFraction{0, 0}).is_defined()); + EXPECT_FALSE((SimpleFraction{-0, -0}).is_defined()); +} + +TEST(SimpleFractionTest, Positivity) { + EXPECT_TRUE((SimpleFraction{1234, 20}).is_positive()); + EXPECT_TRUE((SimpleFraction{kMax - 1, 20}).is_positive()); + EXPECT_TRUE((SimpleFraction{0, kMax}).is_positive()); + EXPECT_TRUE((SimpleFraction{kMax, 1}).is_positive()); + + // Since C++ doesn't have a truly negative zero, this is positive. + EXPECT_TRUE((SimpleFraction{-0, 1}).is_positive()); + + EXPECT_FALSE((SimpleFraction{0, kMin}).is_positive()); + EXPECT_FALSE((SimpleFraction{-0, -1}).is_positive()); + EXPECT_FALSE((SimpleFraction{kMin + 1, 20}).is_positive()); + EXPECT_FALSE((SimpleFraction{kMin, 1}).is_positive()); + EXPECT_FALSE((SimpleFraction{kMin, 0}).is_positive()); + EXPECT_FALSE((SimpleFraction{kMax, 0}).is_positive()); + EXPECT_FALSE((SimpleFraction{0, 0}).is_positive()); + EXPECT_FALSE((SimpleFraction{-0, -0}).is_positive()); +} + +TEST(SimpleFractionTest, CastToDouble) { + EXPECT_DOUBLE_EQ(0.0, static_cast<double>(SimpleFraction{0, 1})); + EXPECT_DOUBLE_EQ(1.0, static_cast<double>(SimpleFraction{1, 1})); + EXPECT_DOUBLE_EQ(1.0, static_cast<double>(SimpleFraction{kMax, kMax})); + EXPECT_DOUBLE_EQ(1.0, static_cast<double>(SimpleFraction{kMin, kMin})); +} + +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/util/stringprintf.cc b/chromium/third_party/openscreen/src/util/stringprintf.cc new file mode 100644 index 00000000000..452c32824fc --- /dev/null +++ b/chromium/third_party/openscreen/src/util/stringprintf.cc @@ -0,0 +1,21 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "util/stringprintf.h" + +#include <iomanip> +#include <sstream> + +namespace openscreen { + +std::string HexEncode(absl::Span<const uint8_t> bytes) { + std::ostringstream hex_dump; + hex_dump << std::setfill('0') << std::hex; + for (const uint8_t byte : bytes) { + hex_dump << std::setw(2) << static_cast<int>(byte); + } + return hex_dump.str(); +} + +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/util/stringprintf.h b/chromium/third_party/openscreen/src/util/stringprintf.h index 93e5eb93677..df24ac042a5 100644 --- a/chromium/third_party/openscreen/src/util/stringprintf.h +++ b/chromium/third_party/openscreen/src/util/stringprintf.h @@ -5,7 +5,12 @@ #ifndef UTIL_STRINGPRINTF_H_ #define UTIL_STRINGPRINTF_H_ +#include <stdint.h> + #include <ostream> +#include <string> + +#include "absl/types/span.h" namespace openscreen { @@ -36,6 +41,9 @@ void PrettyPrintAsciiHex(std::ostream& os, It first, It last) { } } +// Returns a hex string representation of the given |bytes|. +std::string HexEncode(absl::Span<const uint8_t> bytes); + } // namespace openscreen #endif // UTIL_STRINGPRINTF_H_ diff --git a/chromium/third_party/openscreen/src/util/stringprintf_unittest.cc b/chromium/third_party/openscreen/src/util/stringprintf_unittest.cc new file mode 100644 index 00000000000..df0d60f8b41 --- /dev/null +++ b/chromium/third_party/openscreen/src/util/stringprintf_unittest.cc @@ -0,0 +1,24 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "util/stringprintf.h" + +#include "gtest/gtest.h" + +namespace openscreen { +namespace { + +TEST(HexEncode, ProducesEmptyStringFromEmptyByteArray) { + const uint8_t kSomeMemoryLocation = 0; + EXPECT_EQ("", HexEncode(absl::Span<const uint8_t>(&kSomeMemoryLocation, 0))); +} + +TEST(HexEncode, ProducesHexStringsFromBytes) { + const uint8_t kMessage[] = "Hello world!"; + const char kMessageInHex[] = "48656c6c6f20776f726c642100"; + EXPECT_EQ(kMessageInHex, HexEncode(kMessage)); +} + +} // namespace +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/util/trace_logging.h b/chromium/third_party/openscreen/src/util/trace_logging.h index cb454e1b296..527bfcb2ac1 100644 --- a/chromium/third_party/openscreen/src/util/trace_logging.h +++ b/chromium/third_party/openscreen/src/util/trace_logging.h @@ -28,29 +28,31 @@ #include "util/trace_logging/macro_support.h" #undef INCLUDING_FROM_UTIL_TRACE_LOGGING_H_ -#define TRACE_SET_RESULT(result) \ - do { \ - if (TRACE_IS_ENABLED(openscreen::platform::TraceCategory::Value::Any)) { \ - openscreen::internal::ScopedTraceOperation::set_result(result); \ - } \ +#define TRACE_SET_RESULT(result) \ + do { \ + if (TRACE_IS_ENABLED(openscreen::TraceCategory::Value::kAny)) { \ + openscreen::internal::ScopedTraceOperation::set_result(result); \ + } \ } while (false) #define TRACE_SET_HIERARCHY(ids) TRACE_SET_HIERARCHY_INTERNAL(__LINE__, ids) -#define TRACE_HIERARCHY \ - (TRACE_IS_ENABLED(openscreen::platform::TraceCategory::Value::Any) \ - ? openscreen::internal::ScopedTraceOperation::hierarchy() \ - : openscreen::platform::TraceIdHierarchy::Empty()) -#define TRACE_CURRENT_ID \ - (TRACE_IS_ENABLED(openscreen::platform::TraceCategory::Value::Any) \ - ? openscreen::internal::ScopedTraceOperation::current_id() \ +#define TRACE_HIERARCHY \ + (TRACE_IS_ENABLED(openscreen::TraceCategory::Value::kAny) \ + ? openscreen::internal::ScopedTraceOperation::hierarchy() \ + : openscreen::TraceIdHierarchy::Empty()) +#define TRACE_CURRENT_ID \ + (TRACE_IS_ENABLED(openscreen::TraceCategory::Value::kAny) \ + ? openscreen::internal::ScopedTraceOperation::current_id() \ : kEmptyTraceId) -#define TRACE_ROOT_ID \ - (TRACE_IS_ENABLED(openscreen::platform::TraceCategory::Value::Any) \ - ? openscreen::internal::ScopedTraceOperation::root_id() \ +#define TRACE_ROOT_ID \ + (TRACE_IS_ENABLED(openscreen::TraceCategory::Value::kAny) \ + ? openscreen::internal::ScopedTraceOperation::root_id() \ : kEmptyTraceId) -// Synchronous Trace Macro. +// Synchronous Trace Macros. #define TRACE_SCOPED(category, name, ...) \ TRACE_SCOPED_INTERNAL(__LINE__, category, name, ##__VA_ARGS__) +#define TRACE_DEFAULT_SCOPED(category, ...) \ + TRACE_SCOPED(category, __PRETTY_FUNCTION__, ##__VA_ARGS__) // Asynchronous Trace Macros. #define TRACE_ASYNC_START(category, name, ...) \ @@ -76,9 +78,9 @@ inline void DoNothingForTracing(Args... args) {} #define TRACE_SET_RESULT(result) \ openscreen::internal::DoNothingForTracing(result) #define TRACE_SET_HIERARCHY(ids) openscreen::internal::DoNothingForTracing(ids) -#define TRACE_HIERARCHY openscreen::platform::TraceIdHierarchy::Empty() -#define TRACE_CURRENT_ID openscreen::platform::kEmptyTraceId -#define TRACE_ROOT_ID openscreen::platform::kEmptyTraceId +#define TRACE_HIERARCHY openscreen::TraceIdHierarchy::Empty() +#define TRACE_CURRENT_ID openscreen::kEmptyTraceId +#define TRACE_ROOT_ID openscreen::kEmptyTraceId #define TRACE_SCOPED(category, name, ...) \ openscreen::internal::DoNothingForTracing(category, name, ##__VA_ARGS__) #define TRACE_ASYNC_START(category, name, ...) \ diff --git a/chromium/third_party/openscreen/src/util/trace_logging/macro_support.h b/chromium/third_party/openscreen/src/util/trace_logging/macro_support.h index 30865f4c9c9..d265081f428 100644 --- a/chromium/third_party/openscreen/src/util/trace_logging/macro_support.h +++ b/chromium/third_party/openscreen/src/util/trace_logging/macro_support.h @@ -42,8 +42,8 @@ namespace openscreen { namespace internal { -inline bool IsTraceLoggingEnabled(platform::TraceCategory::Value category) { - auto* const destination = platform::GetTracingDestination(); +inline bool IsTraceLoggingEnabled(TraceCategory::Value category) { + const CurrentTracingDestination destination; return destination && destination->IsTraceLoggingEnabled(category); } @@ -59,7 +59,7 @@ inline bool IsTraceLoggingEnabled(platform::TraceCategory::Value category) { tracing_storage, line)[sizeof(openscreen::internal::TraceIdSetter)]; \ TRACE_INTERNAL_IGNORE_UNUSED_VAR \ const auto TRACE_INTERNAL_UNIQUE_VAR_NAME(trace_ref_) = \ - TRACE_IS_ENABLED(openscreen::platform::TraceCategory::Value::Any) \ + TRACE_IS_ENABLED(openscreen::TraceCategory::Value::kAny) \ ? openscreen::internal::TraceInstanceHelper< \ openscreen::internal::TraceIdSetter>:: \ Create(TRACE_INTERNAL_CONCAT_CONST(tracing_storage, line), \ diff --git a/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.cc b/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.cc index 7b19b74ddad..22c3f5b35c4 100644 --- a/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.cc +++ b/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.cc @@ -12,11 +12,6 @@ #if defined(ENABLE_TRACE_LOGGING) -using openscreen::platform::kUnsetTraceId; -using openscreen::platform::TraceCategory; -using openscreen::platform::TraceId; -using openscreen::platform::TraceIdHierarchy; - namespace openscreen { namespace internal { @@ -25,13 +20,13 @@ bool ScopedTraceOperation::TraceAsyncEnd(const uint32_t line, const char* file, TraceId id, Error::Code e) { - auto end_time = platform::Clock::now(); - auto* const current_platform = platform::GetTracingDestination(); - if (current_platform == nullptr) { - return false; + auto end_time = Clock::now(); + const CurrentTracingDestination destination; + if (destination) { + destination->LogAsyncEnd(line, file, end_time, id, e); + return true; } - current_platform->LogAsyncEnd(line, file, end_time, id, e); - return true; + return false; } ScopedTraceOperation::ScopedTraceOperation(TraceId trace_id, @@ -95,7 +90,7 @@ TraceLoggerBase::TraceLoggerBase(TraceCategory::Value category, TraceId parent, TraceId root) : ScopedTraceOperation(current, parent, root), - start_time_(platform::Clock::now()), + start_time_(Clock::now()), result_(Error::Code::kNone), name_(name), file_name_(file), @@ -116,24 +111,22 @@ TraceLoggerBase::TraceLoggerBase(TraceCategory::Value category, ids.root) {} SynchronousTraceLogger::~SynchronousTraceLogger() { - auto* const current_platform = platform::GetTracingDestination(); - if (current_platform == nullptr) { - return; + const CurrentTracingDestination destination; + if (destination) { + auto end_time = Clock::now(); + destination->LogTrace(this->name_, this->line_number_, this->file_name_, + this->start_time_, end_time, this->to_hierarchy(), + this->result_); } - auto end_time = platform::Clock::now(); - current_platform->LogTrace(this->name_, this->line_number_, this->file_name_, - this->start_time_, end_time, this->to_hierarchy(), - this->result_); } AsynchronousTraceLogger::~AsynchronousTraceLogger() { - auto* const current_platform = platform::GetTracingDestination(); - if (current_platform == nullptr) { - return; + const CurrentTracingDestination destination; + if (destination) { + destination->LogAsyncStart(this->name_, this->line_number_, + this->file_name_, this->start_time_, + this->to_hierarchy()); } - current_platform->LogAsyncStart(this->name_, this->line_number_, - this->file_name_, this->start_time_, - this->to_hierarchy()); } TraceIdSetter::~TraceIdSetter() = default; diff --git a/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.h b/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.h index f8af42da926..692fbf55fe9 100644 --- a/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.h +++ b/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.h @@ -7,7 +7,9 @@ #include <atomic> #include <cstring> +#include <memory> #include <stack> +#include <utility> #include <vector> #include "build/config/features.h" @@ -34,19 +36,17 @@ class ScopedTraceOperation { // Getters the current Trace Hierarchy. If the traces_ stack hasn't been // created yet, return as if the empty root node is there. - static platform::TraceId current_id() { - return traces_ == nullptr ? platform::kEmptyTraceId - : traces_->top()->trace_id_; + static TraceId current_id() { + return traces_ == nullptr ? kEmptyTraceId : traces_->top()->trace_id_; } - static platform::TraceId root_id() { - return traces_ == nullptr ? platform::kEmptyTraceId - : traces_->top()->root_id_; + static TraceId root_id() { + return traces_ == nullptr ? kEmptyTraceId : traces_->top()->root_id_; } - static platform::TraceIdHierarchy hierarchy() { + static TraceIdHierarchy hierarchy() { if (traces_ == nullptr) { - return platform::TraceIdHierarchy::Empty(); + return TraceIdHierarchy::Empty(); } return traces_->top()->to_hierarchy(); @@ -66,7 +66,7 @@ class ScopedTraceOperation { // the ternary operator in the macros simpler. static bool TraceAsyncEnd(const uint32_t line, const char* file, - platform::TraceId id, + TraceId id, Error::Code e); protected: @@ -77,18 +77,16 @@ class ScopedTraceOperation { virtual void SetTraceResult(Error::Code error) = 0; // Constructor to set all trace id information. - ScopedTraceOperation(platform::TraceId current_id = platform::kUnsetTraceId, - platform::TraceId parent_id = platform::kUnsetTraceId, - platform::TraceId root_id = platform::kUnsetTraceId); + ScopedTraceOperation(TraceId current_id = kUnsetTraceId, + TraceId parent_id = kUnsetTraceId, + TraceId root_id = kUnsetTraceId); // Current TraceId information. - platform::TraceId trace_id_; - platform::TraceId parent_id_; - platform::TraceId root_id_; + TraceId trace_id_; + TraceId parent_id_; + TraceId root_id_; - platform::TraceIdHierarchy to_hierarchy() { - return {trace_id_, parent_id_, root_id_}; - } + TraceIdHierarchy to_hierarchy() { return {trace_id_, parent_id_, root_id_}; } private: // NOTE: A std::vector is used for backing the stack because it provides the @@ -113,26 +111,26 @@ class ScopedTraceOperation { // The class which does actual trace logging. class TraceLoggerBase : public ScopedTraceOperation { public: - TraceLoggerBase(platform::TraceCategory::Value category, + TraceLoggerBase(TraceCategory::Value category, const char* name, const char* file, uint32_t line, - platform::TraceId current = platform::kUnsetTraceId, - platform::TraceId parent = platform::kUnsetTraceId, - platform::TraceId root = platform::kUnsetTraceId); + TraceId current = kUnsetTraceId, + TraceId parent = kUnsetTraceId, + TraceId root = kUnsetTraceId); - TraceLoggerBase(platform::TraceCategory::Value category, + TraceLoggerBase(TraceCategory::Value category, const char* name, const char* file, uint32_t line, - platform::TraceIdHierarchy ids); + TraceIdHierarchy ids); protected: // Set the result. void SetTraceResult(Error::Code error) override { result_ = error; } // Timestamp for when the object was created. - platform::Clock::time_point start_time_; + Clock::time_point start_time_; // Result of this operation. Error::Code result_; @@ -147,7 +145,7 @@ class TraceLoggerBase : public ScopedTraceOperation { uint32_t line_number_; // Category of this trace log. - platform::TraceCategory::Value category_; + TraceCategory::Value category_; private: OSP_DISALLOW_COPY_AND_ASSIGN(TraceLoggerBase); @@ -175,9 +173,9 @@ class AsynchronousTraceLogger : public TraceLoggerBase { // Inserts a fake element into the ScopedTraceOperation stack to set // the current TraceId Hierarchy manually. -class TraceIdSetter : public ScopedTraceOperation { +class TraceIdSetter final : public ScopedTraceOperation { public: - explicit TraceIdSetter(platform::TraceIdHierarchy ids) + explicit TraceIdSetter(TraceIdHierarchy ids) : ScopedTraceOperation(ids.current, ids.parent, ids.root) {} ~TraceIdSetter() final; diff --git a/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations_unittest.cc b/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations_unittest.cc index 63fa1cba5d3..7e52a6f48c3 100644 --- a/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations_unittest.cc +++ b/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations_unittest.cc @@ -21,12 +21,9 @@ using ::testing::_; using ::testing::DoAll; using ::testing::Invoke; -using platform::kEmptyTraceId; -using platform::MockLoggingPlatform; - // These tests validate that parameters are passed correctly by using the Trace // Internals. -constexpr auto category = platform::TraceCategory::mDNS; +constexpr auto category = TraceCategory::kMdns; constexpr uint32_t line = 10; TEST(TraceLoggingInternalTest, CreatingNoTraceObjectValid) { @@ -38,15 +35,13 @@ TEST(TraceLoggingInternalTest, TestMacroStyleInitializationTrue) { MockLoggingPlatform platform; EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)) .Times(1) - .WillOnce( - DoAll(Invoke(platform::ValidateTraceTimestampDiff<delay_in_ms>), - Invoke(platform::ValidateTraceErrorCode<Error::Code::kNone>))); + .WillOnce(DoAll(Invoke(ValidateTraceTimestampDiff<delay_in_ms>), + Invoke(ValidateTraceErrorCode<Error::Code::kNone>))); { uint8_t temp[sizeof(SynchronousTraceLogger)]; - auto ptr = true ? TraceInstanceHelper<SynchronousTraceLogger>::Create( - temp, category, "Name", __FILE__, line) - : TraceInstanceHelper<SynchronousTraceLogger>::Empty(); + auto ptr = TraceInstanceHelper<SynchronousTraceLogger>::Create( + temp, category, "Name", __FILE__, line); std::this_thread::sleep_for(std::chrono::milliseconds(delay_in_ms)); auto ids = ScopedTraceOperation::hierarchy(); EXPECT_NE(ids.current, kEmptyTraceId); @@ -62,10 +57,7 @@ TEST(TraceLoggingInternalTest, TestMacroStyleInitializationFalse) { EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)).Times(0); { - uint8_t temp[sizeof(SynchronousTraceLogger)]; - auto ptr = false ? TraceInstanceHelper<SynchronousTraceLogger>::Create( - temp, category, "Name", __FILE__, line) - : TraceInstanceHelper<SynchronousTraceLogger>::Empty(); + auto ptr = TraceInstanceHelper<SynchronousTraceLogger>::Empty(); auto ids = ScopedTraceOperation::hierarchy(); EXPECT_EQ(ids.current, kEmptyTraceId); EXPECT_EQ(ids.parent, kEmptyTraceId); @@ -81,7 +73,7 @@ TEST(TraceLoggingInternalTest, ExpectParametersPassedToResult) { MockLoggingPlatform platform; EXPECT_CALL(platform, LogTrace(testing::StrEq("Name"), line, testing::StrEq(__FILE__), _, _, _, _)) - .WillOnce(Invoke(platform::ValidateTraceErrorCode<Error::Code::kNone>)); + .WillOnce(Invoke(ValidateTraceErrorCode<Error::Code::kNone>)); { SynchronousTraceLogger{category, "Name", __FILE__, line}; } } diff --git a/chromium/third_party/openscreen/src/util/trace_logging_unittest.cc b/chromium/third_party/openscreen/src/util/trace_logging_unittest.cc index d34ed303f93..386b19e6424 100644 --- a/chromium/third_party/openscreen/src/util/trace_logging_unittest.cc +++ b/chromium/third_party/openscreen/src/util/trace_logging_unittest.cc @@ -13,7 +13,6 @@ #include "platform/test/trace_logging_helpers.h" namespace openscreen { -namespace platform { namespace { #if defined(ENABLE_TRACE_LOGGING) @@ -38,37 +37,48 @@ using StrictMockLoggingPlatform = ::testing::StrictMock<MockLoggingPlatform>; TEST(TraceLoggingTest, MacroCallScopedDoesntSegFault) { StrictMockLoggingPlatform platform; #if defined(ENABLE_TRACE_LOGGING) - EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any)) + EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny)) .Times(AtLeast(1)); EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)).Times(1); #endif - { TRACE_SCOPED(TraceCategory::Value::Any, "test"); } + { TRACE_SCOPED(TraceCategory::Value::kAny, "test"); } +} + +TEST(TraceLoggingTest, MacroCallDefaultScopedDoesntSegFault) { + StrictMockLoggingPlatform platform; +#if defined(ENABLE_TRACE_LOGGING) + EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny)) + .Times(AtLeast(1)); + EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)).Times(1); +#endif + { TRACE_DEFAULT_SCOPED(TraceCategory::Value::kAny); } } TEST(TraceLoggingTest, MacroCallUnscopedDoesntSegFault) { StrictMockLoggingPlatform platform; #if defined(ENABLE_TRACE_LOGGING) - EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any)) + EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny)) .Times(AtLeast(1)); EXPECT_CALL(platform, LogAsyncStart(_, _, _, _, _)).Times(1); #endif - { TRACE_ASYNC_START(TraceCategory::Value::Any, "test"); } + { TRACE_ASYNC_START(TraceCategory::Value::kAny, "test"); } } TEST(TraceLoggingTest, MacroVariablesUniquelyNames) { StrictMockLoggingPlatform platform; #if defined(ENABLE_TRACE_LOGGING) - EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any)) + EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny)) .Times(AtLeast(1)); - EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)).Times(2); + EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)).Times(3); EXPECT_CALL(platform, LogAsyncStart(_, _, _, _, _)).Times(2); #endif { - TRACE_SCOPED(TraceCategory::Value::Any, "test1"); - TRACE_SCOPED(TraceCategory::Value::Any, "test2"); - TRACE_ASYNC_START(TraceCategory::Value::Any, "test3"); - TRACE_ASYNC_START(TraceCategory::Value::Any, "test4"); + TRACE_SCOPED(TraceCategory::Value::kAny, "test1"); + TRACE_SCOPED(TraceCategory::Value::kAny, "test2"); + TRACE_ASYNC_START(TraceCategory::Value::kAny, "test3"); + TRACE_ASYNC_START(TraceCategory::Value::kAny, "test4"); + TRACE_DEFAULT_SCOPED(TraceCategory::Value::kAny); } } @@ -76,7 +86,7 @@ TEST(TraceLoggingTest, ExpectTimestampsReflectDelay) { constexpr uint32_t delay_in_ms = 50; StrictMockLoggingPlatform platform; #if defined(ENABLE_TRACE_LOGGING) - EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any)) + EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny)) .Times(AtLeast(1)); EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)) .WillOnce(DoAll(Invoke(ValidateTraceTimestampDiff<delay_in_ms>), @@ -84,7 +94,7 @@ TEST(TraceLoggingTest, ExpectTimestampsReflectDelay) { #endif { - TRACE_SCOPED(TraceCategory::Value::Any, "Name"); + TRACE_SCOPED(TraceCategory::Value::kAny, "Name"); std::this_thread::sleep_for(std::chrono::milliseconds(delay_in_ms)); } } @@ -93,14 +103,14 @@ TEST(TraceLoggingTest, ExpectErrorsPassedToResult) { constexpr Error::Code result_code = Error::Code::kParseError; StrictMockLoggingPlatform platform; #if defined(ENABLE_TRACE_LOGGING) - EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any)) + EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny)) .Times(AtLeast(1)); EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)) .WillOnce(Invoke(ValidateTraceErrorCode<result_code>)); #endif { - TRACE_SCOPED(TraceCategory::Value::Any, "Name"); + TRACE_SCOPED(TraceCategory::Value::kAny, "Name"); TRACE_SET_RESULT(result_code); } } @@ -108,14 +118,14 @@ TEST(TraceLoggingTest, ExpectErrorsPassedToResult) { TEST(TraceLoggingTest, ExpectUnsetTraceIdNotSet) { StrictMockLoggingPlatform platform; #if defined(ENABLE_TRACE_LOGGING) - EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any)) + EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny)) .Times(AtLeast(1)); EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)).Times(1); #endif TraceIdHierarchy h = {kUnsetTraceId, kUnsetTraceId, kUnsetTraceId}; { - TRACE_SCOPED(TraceCategory::Value::Any, "Name", h); + TRACE_SCOPED(TraceCategory::Value::kAny, "Name", h); auto ids = TRACE_HIERARCHY; EXPECT_NE(ids.current, kUnsetTraceId); @@ -130,7 +140,7 @@ TEST(TraceLoggingTest, ExpectCreationWithIdsToWork) { constexpr TraceId root = 0x84; StrictMockLoggingPlatform platform; #if defined(ENABLE_TRACE_LOGGING) - EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any)) + EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny)) .Times(AtLeast(1)); EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)) .WillOnce( @@ -141,7 +151,7 @@ TEST(TraceLoggingTest, ExpectCreationWithIdsToWork) { { TraceIdHierarchy h = {current, parent, root}; - TRACE_SCOPED(TraceCategory::Value::Any, "Name", h); + TRACE_SCOPED(TraceCategory::Value::kAny, "Name", h); #if defined(ENABLE_TRACE_LOGGING) auto ids = TRACE_HIERARCHY; @@ -161,7 +171,7 @@ TEST(TraceLoggingTest, ExpectHirearchyToBeApplied) { constexpr TraceId root = 0x84; StrictMockLoggingPlatform platform; #if defined(ENABLE_TRACE_LOGGING) - EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any)) + EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny)) .Times(AtLeast(1)); EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)) .WillOnce(DoAll( @@ -176,7 +186,7 @@ TEST(TraceLoggingTest, ExpectHirearchyToBeApplied) { { TraceIdHierarchy h = {current, parent, root}; - TRACE_SCOPED(TraceCategory::Value::Any, "Name", h); + TRACE_SCOPED(TraceCategory::Value::kAny, "Name", h); #if defined(ENABLE_TRACE_LOGGING) auto ids = TRACE_HIERARCHY; EXPECT_EQ(ids.current, current); @@ -184,7 +194,7 @@ TEST(TraceLoggingTest, ExpectHirearchyToBeApplied) { EXPECT_EQ(ids.root, root); #endif - TRACE_SCOPED(TraceCategory::Value::Any, "Name"); + TRACE_SCOPED(TraceCategory::Value::kAny, "Name"); #if defined(ENABLE_TRACE_LOGGING) ids = TRACE_HIERARCHY; EXPECT_NE(ids.current, current); @@ -200,7 +210,7 @@ TEST(TraceLoggingTest, ExpectHirearchyToEndAfterScopeWhenSetWithSetter) { constexpr TraceId root = 0x84; StrictMockLoggingPlatform platform; #if defined(ENABLE_TRACE_LOGGING) - EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any)) + EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny)) .Times(AtLeast(1)); EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)) .WillOnce(DoAll( @@ -213,7 +223,7 @@ TEST(TraceLoggingTest, ExpectHirearchyToEndAfterScopeWhenSetWithSetter) { TraceIdHierarchy ids = {current, parent, root}; TRACE_SET_HIERARCHY(ids); { - TRACE_SCOPED(TraceCategory::Value::Any, "Name"); + TRACE_SCOPED(TraceCategory::Value::kAny, "Name"); ids = TRACE_HIERARCHY; #if defined(ENABLE_TRACE_LOGGING) EXPECT_NE(ids.current, current); @@ -231,13 +241,13 @@ TEST(TraceLoggingTest, ExpectHirearchyToEndAfterScopeWhenSetWithSetter) { } } -TEST(TraceLoggingTest, ExpectHirearchyToEndAfterScope) { +TEST(TraceLoggingTest, ExpectHierarchyToEndAfterScope) { constexpr TraceId current = 0x32; constexpr TraceId parent = 0x47; constexpr TraceId root = 0x84; StrictMockLoggingPlatform platform; #if defined(ENABLE_TRACE_LOGGING) - EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any)) + EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny)) .Times(AtLeast(1)); EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)) .WillOnce(DoAll( @@ -252,9 +262,9 @@ TEST(TraceLoggingTest, ExpectHirearchyToEndAfterScope) { { TraceIdHierarchy ids = {current, parent, root}; - TRACE_SCOPED(TraceCategory::Value::Any, "Name", ids); + TRACE_SCOPED(TraceCategory::Value::kAny, "Name", ids); { - TRACE_SCOPED(TraceCategory::Value::Any, "Name"); + TRACE_SCOPED(TraceCategory::Value::kAny, "Name"); ids = TRACE_HIERARCHY; #if defined(ENABLE_TRACE_LOGGING) EXPECT_NE(ids.current, current); @@ -278,7 +288,7 @@ TEST(TraceLoggingTest, ExpectSetHierarchyToApply) { constexpr TraceId root = 0x84; StrictMockLoggingPlatform platform; #if defined(ENABLE_TRACE_LOGGING) - EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any)) + EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny)) .Times(AtLeast(1)); EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)) .WillOnce(DoAll( @@ -297,7 +307,7 @@ TEST(TraceLoggingTest, ExpectSetHierarchyToApply) { EXPECT_EQ(ids.root, root); #endif - TRACE_SCOPED(TraceCategory::Value::Any, "Name"); + TRACE_SCOPED(TraceCategory::Value::kAny, "Name"); ids = TRACE_HIERARCHY; #if defined(ENABLE_TRACE_LOGGING) EXPECT_NE(ids.current, current); @@ -310,12 +320,12 @@ TEST(TraceLoggingTest, ExpectSetHierarchyToApply) { TEST(TraceLoggingTest, CheckTraceAsyncStartLogsCorrectly) { StrictMockLoggingPlatform platform; #if defined(ENABLE_TRACE_LOGGING) - EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any)) + EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny)) .Times(AtLeast(1)); EXPECT_CALL(platform, LogAsyncStart(_, _, _, _, _)).Times(1); #endif - { TRACE_ASYNC_START(TraceCategory::Value::Any, "Name"); } + { TRACE_ASYNC_START(TraceCategory::Value::kAny, "Name"); } } TEST(TraceLoggingTest, CheckTraceAsyncStartSetsHierarchy) { @@ -324,7 +334,7 @@ TEST(TraceLoggingTest, CheckTraceAsyncStartSetsHierarchy) { constexpr TraceId root = 84; StrictMockLoggingPlatform platform; #if defined(ENABLE_TRACE_LOGGING) - EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any)) + EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny)) .Times(AtLeast(1)); EXPECT_CALL(platform, LogAsyncStart(_, _, _, _, _)) .WillOnce( @@ -336,7 +346,7 @@ TEST(TraceLoggingTest, CheckTraceAsyncStartSetsHierarchy) { TraceIdHierarchy ids = {current, parent, root}; TRACE_SET_HIERARCHY(ids); { - TRACE_ASYNC_START(TraceCategory::Value::Any, "Name"); + TRACE_ASYNC_START(TraceCategory::Value::kAny, "Name"); ids = TRACE_HIERARCHY; #if defined(ENABLE_TRACE_LOGGING) EXPECT_NE(ids.current, current); @@ -359,14 +369,13 @@ TEST(TraceLoggingTest, CheckTraceAsyncEndLogsCorrectly) { constexpr Error::Code result = Error::Code::kAgain; StrictMockLoggingPlatform platform; #if defined(ENABLE_TRACE_LOGGING) - EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any)) + EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny)) .Times(AtLeast(1)); EXPECT_CALL(platform, LogAsyncEnd(_, _, _, id, result)).Times(1); #endif - TRACE_ASYNC_END(TraceCategory::Value::Any, id, result); + TRACE_ASYNC_END(TraceCategory::Value::kAny, id, result); } } // namespace -} // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/weak_ptr.h b/chromium/third_party/openscreen/src/util/weak_ptr.h index 173c57e8f48..93b97e58c42 100644 --- a/chromium/third_party/openscreen/src/platform/impl/weak_ptr.h +++ b/chromium/third_party/openscreen/src/util/weak_ptr.h @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef PLATFORM_IMPL_WEAK_PTR_H_ -#define PLATFORM_IMPL_WEAK_PTR_H_ +#ifndef UTIL_WEAK_PTR_H_ +#define UTIL_WEAK_PTR_H_ #include <memory> @@ -213,4 +213,4 @@ class WeakPtrFactory { } // namespace openscreen -#endif // PLATFORM_IMPL_WEAK_PTR_H_ +#endif // UTIL_WEAK_PTR_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/weak_ptr_unittest.cc b/chromium/third_party/openscreen/src/util/weak_ptr_unittest.cc index b06574bb1c5..328802695d7 100644 --- a/chromium/third_party/openscreen/src/platform/impl/weak_ptr_unittest.cc +++ b/chromium/third_party/openscreen/src/util/weak_ptr_unittest.cc @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "platform/impl/weak_ptr.h" +#include "util/weak_ptr.h" #include "gtest/gtest.h" @@ -15,7 +15,7 @@ class SomeClass { virtual int GetValue() const { return 42; } }; -struct SomeSubclass : public SomeClass { +struct SomeSubclass final : public SomeClass { public: ~SomeSubclass() final = default; int GetValue() const override { return 999; } |