summaryrefslogtreecommitdiff
path: root/chromium/third_party/openscreen
diff options
context:
space:
mode:
authorAllan Sandfeld Jensen <allan.jensen@qt.io>2020-07-16 11:45:35 +0200
committerAllan Sandfeld Jensen <allan.jensen@qt.io>2020-07-17 08:59:23 +0000
commit552906b0f222c5d5dd11b9fd73829d510980461a (patch)
tree3a11e6ed0538a81dd83b20cf3a4783e297f26d91 /chromium/third_party/openscreen
parent1b05827804eaf047779b597718c03e7d38344261 (diff)
downloadqtwebengine-chromium-552906b0f222c5d5dd11b9fd73829d510980461a.tar.gz
BASELINE: Update Chromium to 83.0.4103.122
Change-Id: Ie3a82f5bb0076eec2a7c6a6162326b4301ee291e Reviewed-by: Michael Brüning <michael.bruning@qt.io>
Diffstat (limited to 'chromium/third_party/openscreen')
-rw-r--r--chromium/third_party/openscreen/src/BUILD.gn55
-rw-r--r--chromium/third_party/openscreen/src/COMMITTERS1
-rw-r--r--chromium/third_party/openscreen/src/DEPS284
-rwxr-xr-xchromium/third_party/openscreen/src/PRESUBMIT.sh5
-rw-r--r--chromium/third_party/openscreen/src/README.md140
-rw-r--r--chromium/third_party/openscreen/src/build/config/BUILD.gn86
-rw-r--r--chromium/third_party/openscreen/src/build/config/BUILDCONFIG.gn54
-rw-r--r--chromium/third_party/openscreen/src/build/config/arm.gni45
-rw-r--r--chromium/third_party/openscreen/src/build/config/external_libraries.gni48
-rw-r--r--chromium/third_party/openscreen/src/build/config/sysroot.gni35
-rwxr-xr-xchromium/third_party/openscreen/src/build/scripts/dir_exists.py33
-rwxr-xr-xchromium/third_party/openscreen/src/build/scripts/install-sysroot.py172
-rwxr-xr-xchromium/third_party/openscreen/src/build/scripts/sysroot_ld_path.py67
-rw-r--r--chromium/third_party/openscreen/src/build/scripts/sysroots.json13
-rw-r--r--chromium/third_party/openscreen/src/build/toolchain/linux/BUILD.gn284
-rw-r--r--chromium/third_party/openscreen/src/build/toolchain/mac/BUILD.gn3
-rw-r--r--chromium/third_party/openscreen/src/cast/DEPS4
-rw-r--r--chromium/third_party/openscreen/src/cast/common/BUILD.gn80
-rw-r--r--chromium/third_party/openscreen/src/cast/common/DEPS4
-rw-r--r--chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.cc73
-rw-r--r--chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.h13
-rw-r--r--chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.cc92
-rw-r--r--chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.h17
-rw-r--r--chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_unittest.cc82
-rw-r--r--chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.cc48
-rw-r--r--chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.h11
-rw-r--r--chromium/third_party/openscreen/src/cast/common/certificate/cast_crl_unittest.cc59
-rw-r--r--chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.cc66
-rw-r--r--chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.h39
-rw-r--r--chromium/third_party/openscreen/src/cast/common/certificate/test_helpers.cc130
-rw-r--r--chromium/third_party/openscreen/src/cast/common/certificate/testing/test_helpers.cc130
-rw-r--r--chromium/third_party/openscreen/src/cast/common/certificate/testing/test_helpers.h (renamed from chromium/third_party/openscreen/src/cast/common/certificate/test_helpers.h)21
-rw-r--r--chromium/third_party/openscreen/src/cast/common/certificate/types.cc21
-rw-r--r--chromium/third_party/openscreen/src/cast/common/certificate/types.h9
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/BUILD.gn62
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/cast_message_handler.h9
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/cast_socket.cc65
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/cast_socket.h47
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/cast_socket_unittest.cc82
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler.cc259
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler.h64
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler_unittest.cc226
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/message_framer.cc16
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/message_framer.h13
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer.cc14
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/03d4b4028b559489768e2cccd6015c907f70a2c0bin0 -> 2272 bytes
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/333be5dfffb2c6eeadf31be2dc219ef841c99ea0bin0 -> 162 bytes
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/b03aaebaa88ca4f4b8d63c7a63fc55ba402cfbb4bin0 -> 226 bytes
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_len1bin0 -> 104 bytes
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_len2bin0 -> 104 bytes
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_protobin0 -> 88 bytes
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/cf93596ce5bbb0d4c91f3ee493e01f0674d36c0cbin0 -> 90 bytes
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/e14c401475d86e0f279691c168c7122ceb77c2c6bin0 -> 126 bytes
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/e9b451d1575019d52e0e072ce5b22a2418d237c7bin0 -> 88 bytes
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/message_framer_unittest.cc6
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/message_util.cc71
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/message_util.h163
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/namespace_router.cc35
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/namespace_router.h37
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/namespace_router_unittest.cc98
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/proto/authority_keys.proto6
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/proto/cast_channel.proto25
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/testing/fake_cast_socket.h108
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/testing/mock_cast_message_handler.h28
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/testing/mock_socket_error_handler.h25
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/virtual_connection.h14
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.cc8
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.h10
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager_unittest.cc15
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.cc37
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.h15
-rw-r--r--chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router_unittest.cc74
-rw-r--r--chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/DEPS3
-rw-r--r--chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/tests.cc578
-rw-r--r--chromium/third_party/openscreen/src/cast/common/public/DEPS8
-rw-r--r--chromium/third_party/openscreen/src/cast/common/public/service_info.cc243
-rw-r--r--chromium/third_party/openscreen/src/cast/common/public/service_info.h124
-rw-r--r--chromium/third_party/openscreen/src/cast/common/public/service_info_unittest.cc209
-rw-r--r--chromium/third_party/openscreen/src/cast/common/public/testing/discovery_utils.cc56
-rw-r--r--chromium/third_party/openscreen/src/cast/common/public/testing/discovery_utils.h48
-rw-r--r--chromium/third_party/openscreen/src/cast/receiver/BUILD.gn66
-rw-r--r--chromium/third_party/openscreen/src/cast/receiver/DEPS2
-rw-r--r--chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler.cc155
-rw-r--r--chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler.h57
-rw-r--r--chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler_unittest.cc219
-rw-r--r--chromium/third_party/openscreen/src/cast/receiver/channel/message_util.cc65
-rw-r--r--chromium/third_party/openscreen/src/cast/receiver/channel/message_util.h30
-rw-r--r--chromium/third_party/openscreen/src/cast/receiver/channel/receiver_socket_factory.cc52
-rw-r--r--chromium/third_party/openscreen/src/cast/receiver/channel/receiver_socket_factory.h51
-rw-r--r--chromium/third_party/openscreen/src/cast/receiver/channel/testing/device_auth_test_helpers.cc51
-rw-r--r--chromium/third_party/openscreen/src/cast/receiver/channel/testing/device_auth_test_helpers.h46
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/BUILD.gn58
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/DEPS6
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker.cc165
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker.h125
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker_unittest.cc159
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl.cc211
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl.h99
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl_unittest.cc350
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/cast_platform_client.cc224
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/cast_platform_client.h97
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/cast_platform_client_unittest.cc110
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.cc327
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.h71
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util_unittest.cc260
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/channel/message_util.cc41
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/channel/message_util.h14
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.cc48
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.h42
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/public/DEPS9
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/public/README.md2
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/public/cast_app_discovery_service.cc48
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/public/cast_app_discovery_service.h75
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/public/cast_media_source.cc45
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/public/cast_media_source.h42
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/testing/DEPS4
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/testing/test_helpers.cc87
-rw-r--r--chromium/third_party/openscreen/src/cast/sender/testing/test_helpers.h43
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/BUILD.gn37
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/DEPS2
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/avcodec_glue.h4
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.cc162
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.h84
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.cc59
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.h44
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.cc160
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.h27
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.cc4
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.h6
-rwxr-xr-xchromium/third_party/openscreen/src/cast/standalone_receiver/install_demo_deps_debian.sh7
-rwxr-xr-xchromium/third_party/openscreen/src/cast/standalone_receiver/install_demo_deps_raspian.sh8
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/main.cc306
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/private.derbin0 -> 1192 bytes
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/private_key_der.h125
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.cc18
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.h18
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.cc12
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.h11
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.cc77
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.h54
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.cc16
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.h15
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/streaming_playback_controller.cc99
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_receiver/streaming_playback_controller.h68
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_sender/BUILD.gn44
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_sender/DEPS8
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_sender/ffmpeg_glue.cc32
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_sender/ffmpeg_glue.h71
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_sender/main.cc480
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.cc378
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.h205
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.cc224
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.h123
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.cc492
-rw-r--r--chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.h302
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/BUILD.gn145
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/DEPS3
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/answer_messages.cc208
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/answer_messages.h119
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/answer_messages_unittest.cc180
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.cc157
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.h170
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator_unittest.cc232
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.cc19
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.h6
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.cc9
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.h16
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder_unittest.cc6
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.cc10
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.h13
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_fuzzer.cc10
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_unittest.cc7
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/constants.h17
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/encoded_frame.cc4
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/encoded_frame.h6
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/environment.cc22
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/environment.h71
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/expanded_value_base.h4
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/expanded_value_base_unittest.cc4
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/frame_collector.cc4
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/frame_collector.h4
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/frame_collector_unittest.cc4
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/frame_crypto.cc7
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/frame_crypto.h4
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/frame_crypto_unittest.cc4
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/frame_id.cc4
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/frame_id.h14
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/message_port.h38
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/message_util.h77
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/mock_compound_rtcp_parser_client.h6
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/mock_environment.cc17
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/mock_environment.h30
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/ntp_time.cc5
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/ntp_time.h31
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/ntp_time_unittest.cc8
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/offer_messages.cc437
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/offer_messages.h121
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/offer_messages_unittest.cc429
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.cc6
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.h15
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker_unittest.cc6
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/packet_util.cc6
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/packet_util.h8
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/packet_util_unittest.cc4
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/receiver.cc24
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/receiver.h50
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.cc22
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.h8
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/receiver_session.cc312
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/receiver_session.h166
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/receiver_session_unittest.cc536
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/receiver_unittest.cc166
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/rtcp_common.cc10
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/rtcp_common.h9
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/rtcp_common_unittest.cc6
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/rtcp_session.cc6
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/rtcp_session.h6
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/rtp_defines.cc5
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/rtp_defines.h6
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.cc6
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.h4
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_fuzzer.cc4
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_unittest.cc7
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.cc8
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.h22
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer_unittest.cc6
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/rtp_time.cc4
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/rtp_time.h32
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/rtp_time_unittest.cc4
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/sender.cc541
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/sender.h318
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.cc274
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.h196
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/sender_packet_router_unittest.cc616
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.cc32
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.h10
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.cc4
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.h4
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/sender_report_parser_fuzzer.cc10
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/sender_report_unittest.cc37
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/sender_unittest.cc1138
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/session_config.cc6
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/session_config.h9
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/ssrc.cc6
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/ssrc.h4
-rw-r--r--chromium/third_party/openscreen/src/cast/streaming/ssrc_unittest.cc6
-rw-r--r--chromium/third_party/openscreen/src/cast/test/BUILD.gn61
-rw-r--r--chromium/third_party/openscreen/src/discovery/BUILD.gn61
-rw-r--r--chromium/third_party/openscreen/src/discovery/DEPS8
-rw-r--r--chromium/third_party/openscreen/src/discovery/common/config.h56
-rw-r--r--chromium/third_party/openscreen/src/discovery/common/reporting_client.h37
-rw-r--r--chromium/third_party/openscreen/src/discovery/common/testing/mock_reporting_client.h22
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/DEPS6
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.cc69
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.h14
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer_unittest.cc77
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_unittest.cc78
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.cc56
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.h29
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key_unittest.cc12
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.cc261
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.h32
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl_unittest.cc158
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.cc159
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.h35
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl_unittest.cc77
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.cc52
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.h18
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.cc49
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.h17
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.cc64
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.h56
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_publisher.h32
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_querier.h10
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_service.h16
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.cc73
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.h62
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record_unittest.cc68
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/testing/DEPS7
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.cc29
-rw-r--r--chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.h23
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/DEPS5
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/multi_answer.bin20
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/multi_question.bin7
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/probe.bin10
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/ptr_response.bin8
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_domain_confirmed_provider.h28
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.cc113
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.h129
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.cc255
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.h150
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager_unittest.cc355
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_unittest.cc128
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.cc362
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.h171
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher_unittest.cc465
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.cc612
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.h162
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_querier_unittest.cc349
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_random.h12
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_random_unittest.cc2
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.cc169
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.h9
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_reader_fuzztest.cc12
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_reader_unittest.cc61
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.cc52
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.h28
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver_unittest.cc36
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_record_changed_callback.h17
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_records.cc363
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_records.h164
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_records_unittest.cc164
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.cc377
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.h43
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_responder_unittest.cc762
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.cc33
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.h12
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_sender_unittest.cc99
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.cc142
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.h68
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.cc334
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.h181
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers_unittest.cc296
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.cc29
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.h5
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/mdns_writer_unittest.cc47
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/public/mdns_constants.h107
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/public/mdns_service.cc15
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/public/mdns_service.h75
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/testing/DEPS5
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.cc32
-rw-r--r--chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.h13
-rw-r--r--chromium/third_party/openscreen/src/discovery/public/DEPS5
-rw-r--r--chromium/third_party/openscreen/src/discovery/public/dns_sd_service_factory.h28
-rw-r--r--chromium/third_party/openscreen/src/discovery/public/dns_sd_service_publisher.h93
-rw-r--r--chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher.h215
-rw-r--r--chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher_unittest.cc335
-rw-r--r--chromium/third_party/openscreen/src/docs/style_guide.md15
-rw-r--r--chromium/third_party/openscreen/src/docs/threading.md21
-rw-r--r--chromium/third_party/openscreen/src/docs/trace_logging.md42
-rw-r--r--chromium/third_party/openscreen/src/infra/config/global/commit-queue.cfg7
-rw-r--r--chromium/third_party/openscreen/src/infra/config/global/cr-buildbucket.cfg38
-rw-r--r--chromium/third_party/openscreen/src/infra/config/global/luci-milo.cfg24
-rw-r--r--chromium/third_party/openscreen/src/infra/config/global/luci-scheduler.cfg11
-rw-r--r--chromium/third_party/openscreen/src/osp/demo/osp_demo.cc24
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/BUILD.gn2
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/discovery/mdns/BUILD.gn1
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_demo.cc46
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.cc2
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.h36
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc126
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h44
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl_unittest.cc2
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.cc15
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.h2
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/internal_services.cc66
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/internal_services.h31
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.cc6
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.h14
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.cc40
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.h44
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/mdns_responder_service_unittest.cc49
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/mdns_service_listener_factory.cc4
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/mdns_service_publisher_factory.cc4
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection.cc13
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection_unittest.cc15
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller.cc6
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller_unittest.cc112
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver.cc14
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver_unittest.cc47
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.cc3
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.h12
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester_unittest.cc37
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/protocol_connection_client_factory.cc4
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/protocol_connection_server_factory.cc4
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/BUILD.gn2
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/quic_client.cc7
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/quic_client.h9
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/quic_client_unittest.cc36
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/quic_connection.h2
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory.h2
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.cc48
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.h17
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.cc36
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.h13
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/quic_server.cc7
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/quic_server.h9
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/quic_server_unittest.cc36
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.cc7
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.h7
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.cc20
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.h16
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.cc16
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.h9
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/service_listener_impl.cc2
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.cc2
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.h4
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service_unittest.cc35
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.cc53
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.h57
-rw-r--r--chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter_unittest.cc24
-rw-r--r--chromium/third_party/openscreen/src/osp/msgs/request_response_handler.h2
-rw-r--r--chromium/third_party/openscreen/src/osp/public/client_config.h4
-rw-r--r--chromium/third_party/openscreen/src/osp/public/mdns_service_listener_factory.h4
-rw-r--r--chromium/third_party/openscreen/src/osp/public/mdns_service_publisher_factory.h4
-rw-r--r--chromium/third_party/openscreen/src/osp/public/message_demuxer.cc2
-rw-r--r--chromium/third_party/openscreen/src/osp/public/message_demuxer.h18
-rw-r--r--chromium/third_party/openscreen/src/osp/public/message_demuxer_unittest.cc178
-rw-r--r--chromium/third_party/openscreen/src/osp/public/presentation/presentation_connection.h2
-rw-r--r--chromium/third_party/openscreen/src/osp/public/presentation/presentation_controller.h2
-rw-r--r--chromium/third_party/openscreen/src/osp/public/presentation/presentation_receiver.h2
-rw-r--r--chromium/third_party/openscreen/src/osp/public/protocol_connection_client_factory.h4
-rw-r--r--chromium/third_party/openscreen/src/osp/public/protocol_connection_server_factory.h4
-rw-r--r--chromium/third_party/openscreen/src/osp/public/server_config.h4
-rw-r--r--chromium/third_party/openscreen/src/osp/public/service_info.cc9
-rw-r--r--chromium/third_party/openscreen/src/osp/public/service_info.h5
-rw-r--r--chromium/third_party/openscreen/src/osp/public/service_listener.h3
-rw-r--r--chromium/third_party/openscreen/src/osp/public/service_publisher.h5
-rw-r--r--chromium/third_party/openscreen/src/osp/public/testing/message_demuxer_test_support.h2
-rw-r--r--chromium/third_party/openscreen/src/platform/BUILD.gn101
-rw-r--r--chromium/third_party/openscreen/src/platform/api/DEPS13
-rw-r--r--chromium/third_party/openscreen/src/platform/api/logging.h13
-rw-r--r--chromium/third_party/openscreen/src/platform/api/network_interface.h2
-rw-r--r--chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.cc2
-rw-r--r--chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.h2
-rw-r--r--chromium/third_party/openscreen/src/platform/api/socket_integration_unittest.cc8
-rw-r--r--chromium/third_party/openscreen/src/platform/api/task_runner.h13
-rw-r--r--chromium/third_party/openscreen/src/platform/api/time.h2
-rw-r--r--chromium/third_party/openscreen/src/platform/api/time_unittest.cc2
-rw-r--r--chromium/third_party/openscreen/src/platform/api/tls_connection.cc2
-rw-r--r--chromium/third_party/openscreen/src/platform/api/tls_connection.h13
-rw-r--r--chromium/third_party/openscreen/src/platform/api/tls_connection_factory.cc2
-rw-r--r--chromium/third_party/openscreen/src/platform/api/tls_connection_factory.h4
-rw-r--r--chromium/third_party/openscreen/src/platform/api/trace_logging_platform.cc2
-rw-r--r--chromium/third_party/openscreen/src/platform/api/trace_logging_platform.h4
-rw-r--r--chromium/third_party/openscreen/src/platform/api/udp_socket.cc2
-rw-r--r--chromium/third_party/openscreen/src/platform/api/udp_socket.h19
-rw-r--r--chromium/third_party/openscreen/src/platform/base/DEPS7
-rw-r--r--chromium/third_party/openscreen/src/platform/base/error.cc18
-rw-r--r--chromium/third_party/openscreen/src/platform/base/error.h103
-rw-r--r--chromium/third_party/openscreen/src/platform/base/interface_info.cc45
-rw-r--r--chromium/third_party/openscreen/src/platform/base/interface_info.h17
-rw-r--r--chromium/third_party/openscreen/src/platform/base/ip_address.cc331
-rw-r--r--chromium/third_party/openscreen/src/platform/base/ip_address.h82
-rw-r--r--chromium/third_party/openscreen/src/platform/base/ip_address_unittest.cc81
-rw-r--r--chromium/third_party/openscreen/src/platform/base/location.cc7
-rw-r--r--chromium/third_party/openscreen/src/platform/base/socket_state.h2
-rw-r--r--chromium/third_party/openscreen/src/platform/base/tls_connect_options.h2
-rw-r--r--chromium/third_party/openscreen/src/platform/base/tls_credentials.cc2
-rw-r--r--chromium/third_party/openscreen/src/platform/base/tls_credentials.h2
-rw-r--r--chromium/third_party/openscreen/src/platform/base/tls_listen_options.h2
-rw-r--r--chromium/third_party/openscreen/src/platform/base/trace_logging_activation.cc62
-rw-r--r--chromium/third_party/openscreen/src/platform/base/trace_logging_activation.h39
-rw-r--r--chromium/third_party/openscreen/src/platform/base/trace_logging_types.h18
-rw-r--r--chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.cc25
-rw-r--r--chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.h15
-rw-r--r--chromium/third_party/openscreen/src/platform/base/udp_packet.cc6
-rw-r--r--chromium/third_party/openscreen/src/platform/base/udp_packet.h11
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/logging.h2
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/logging_posix.cc10
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/network_interface.cc44
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/network_interface.h26
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/network_interface_linux.cc26
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/network_interface_mac.cc17
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/platform_client_posix.cc129
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/platform_client_posix.h80
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.cc56
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.h27
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.cc2
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.h2
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/socket_address_posix.cc34
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/socket_address_posix.h8
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/socket_address_posix_unittest.cc8
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/socket_handle.h2
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.cc2
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.h2
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.cc61
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.h17
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.cc26
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.h6
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix_unittest.cc7
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/stream_socket.h2
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.cc81
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.h10
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/task_runner.cc52
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/task_runner.h13
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/task_runner_unittest.cc20
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.cc4
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.h2
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/time.cc2
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/time_unittest.cc2
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/timeval_posix.cc2
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/timeval_posix.h2
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/timeval_posix_unittest.cc2
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.cc168
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.h9
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.cc148
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.h40
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.cc143
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.h41
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix_unittest.cc207
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.cc50
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.h32
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/tls_write_buffer_unittest.cc119
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.cc18
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.h4
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.cc8
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.h2
-rw-r--r--chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix_unittest.cc10
-rw-r--r--chromium/third_party/openscreen/src/testing/libfuzzer/BUILD.gn29
-rwxr-xr-xchromium/third_party/openscreen/src/testing/libfuzzer/archive_corpus.py58
-rw-r--r--chromium/third_party/openscreen/src/testing/libfuzzer/fuzzer_test.gni193
-rwxr-xr-xchromium/third_party/openscreen/src/testing/libfuzzer/gen_fuzzer_config.py86
-rw-r--r--chromium/third_party/openscreen/src/testing/util/BUILD.gn17
-rw-r--r--chromium/third_party/openscreen/src/testing/util/read_file.cc34
-rw-r--r--chromium/third_party/openscreen/src/testing/util/read_file.h18
-rw-r--r--chromium/third_party/openscreen/src/third_party/abseil/BUILD.gn2
-rw-r--r--chromium/third_party/openscreen/src/third_party/boringssl/BUILD.gn12
-rw-r--r--chromium/third_party/openscreen/src/third_party/chromium_quic/BUILD.gn34
-rw-r--r--chromium/third_party/openscreen/src/third_party/chromium_quic/build/base/BUILD.gn7
-rw-r--r--chromium/third_party/openscreen/src/third_party/chromium_quic/demo/client.cc3
-rw-r--r--chromium/third_party/openscreen/src/third_party/googletest/BUILD.gn4
-rw-r--r--chromium/third_party/openscreen/src/third_party/libfuzzer/BUILD.gn44
-rw-r--r--chromium/third_party/openscreen/src/third_party/mDNSResponder/BUILD.gn19
-rw-r--r--chromium/third_party/openscreen/src/third_party/mozilla/BUILD.gn14
-rw-r--r--chromium/third_party/openscreen/src/third_party/mozilla/LICENSE.txt65
-rw-r--r--chromium/third_party/openscreen/src/third_party/mozilla/README.md7
-rw-r--r--chromium/third_party/openscreen/src/third_party/mozilla/url_parse.cc858
-rw-r--r--chromium/third_party/openscreen/src/third_party/mozilla/url_parse.h322
-rw-r--r--chromium/third_party/openscreen/src/third_party/mozilla/url_parse_internal.cc87
-rw-r--r--chromium/third_party/openscreen/src/third_party/mozilla/url_parse_internal.h50
-rw-r--r--chromium/third_party/openscreen/src/third_party/protobuf/BUILD.gn2
-rw-r--r--chromium/third_party/openscreen/src/third_party/tinycbor/BUILD.gn7
-rw-r--r--chromium/third_party/openscreen/src/third_party/zlib/BUILD.gn41
-rwxr-xr-xchromium/third_party/openscreen/src/tools/cl-format.sh32
-rwxr-xr-xchromium/third_party/openscreen/src/tools/clang/scripts/update.py1025
-rwxr-xr-xchromium/third_party/openscreen/src/tools/download-clang-update-script.py66
-rwxr-xr-xchromium/third_party/openscreen/src/tools/install-build-tools.sh36
-rw-r--r--chromium/third_party/openscreen/src/util/BUILD.gn27
-rw-r--r--chromium/third_party/openscreen/src/util/alarm.cc27
-rw-r--r--chromium/third_party/openscreen/src/util/alarm.h36
-rw-r--r--chromium/third_party/openscreen/src/util/alarm_unittest.cc81
-rw-r--r--chromium/third_party/openscreen/src/util/big_endian.h67
-rw-r--r--chromium/third_party/openscreen/src/util/crypto/certificate_utils.cc138
-rw-r--r--chromium/third_party/openscreen/src/util/crypto/certificate_utils.h39
-rw-r--r--chromium/third_party/openscreen/src/util/crypto/certificate_utils_unittest.cc25
-rw-r--r--chromium/third_party/openscreen/src/util/crypto/digest_sign.cc31
-rw-r--r--chromium/third_party/openscreen/src/util/crypto/digest_sign.h23
-rw-r--r--chromium/third_party/openscreen/src/util/hashing.h49
-rw-r--r--chromium/third_party/openscreen/src/util/json/json_reader.cc36
-rw-r--r--chromium/third_party/openscreen/src/util/json/json_reader.h33
-rw-r--r--chromium/third_party/openscreen/src/util/json/json_serialization.cc (renamed from chromium/third_party/openscreen/src/util/json/json_writer.cc)48
-rw-r--r--chromium/third_party/openscreen/src/util/json/json_serialization.h25
-rw-r--r--chromium/third_party/openscreen/src/util/json/json_serialization_unittest.cc (renamed from chromium/third_party/openscreen/src/util/json/json_reader_unittest.cc)36
-rw-r--r--chromium/third_party/openscreen/src/util/json/json_value.cc43
-rw-r--r--chromium/third_party/openscreen/src/util/json/json_value.h28
-rw-r--r--chromium/third_party/openscreen/src/util/json/json_value_unittest.cc55
-rw-r--r--chromium/third_party/openscreen/src/util/json/json_writer.h34
-rw-r--r--chromium/third_party/openscreen/src/util/json/json_writer_unittest.cc32
-rw-r--r--chromium/third_party/openscreen/src/util/logging.h20
-rw-r--r--chromium/third_party/openscreen/src/util/operation_loop.h2
-rw-r--r--chromium/third_party/openscreen/src/util/saturate_cast.h82
-rw-r--r--chromium/third_party/openscreen/src/util/saturate_cast_unittest.cc183
-rw-r--r--chromium/third_party/openscreen/src/util/serial_delete_ptr.h33
-rw-r--r--chromium/third_party/openscreen/src/util/serial_delete_ptr_unittest.cc4
-rw-r--r--chromium/third_party/openscreen/src/util/simple_fraction.cc69
-rw-r--r--chromium/third_party/openscreen/src/util/simple_fraction.h45
-rw-r--r--chromium/third_party/openscreen/src/util/simple_fraction_unittest.cc98
-rw-r--r--chromium/third_party/openscreen/src/util/stringprintf.cc21
-rw-r--r--chromium/third_party/openscreen/src/util/stringprintf.h8
-rw-r--r--chromium/third_party/openscreen/src/util/stringprintf_unittest.cc24
-rw-r--r--chromium/third_party/openscreen/src/util/trace_logging.h40
-rw-r--r--chromium/third_party/openscreen/src/util/trace_logging/macro_support.h6
-rw-r--r--chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.cc43
-rw-r--r--chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.h54
-rw-r--r--chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations_unittest.cc22
-rw-r--r--chromium/third_party/openscreen/src/util/trace_logging_unittest.cc83
-rw-r--r--chromium/third_party/openscreen/src/util/weak_ptr.h (renamed from chromium/third_party/openscreen/src/platform/impl/weak_ptr.h)6
-rw-r--r--chromium/third_party/openscreen/src/util/weak_ptr_unittest.cc (renamed from chromium/third_party/openscreen/src/platform/impl/weak_ptr_unittest.cc)4
579 files changed, 31941 insertions, 7057 deletions
diff --git a/chromium/third_party/openscreen/src/BUILD.gn b/chromium/third_party/openscreen/src/BUILD.gn
index 45c89852612..882422618f7 100644
--- a/chromium/third_party/openscreen/src/BUILD.gn
+++ b/chromium/third_party/openscreen/src/BUILD.gn
@@ -5,31 +5,43 @@
import("//build_overrides/build.gni")
import("osp/build/config/services.gni")
+declare_args() {
+ # Set to true to force building the standalone receiver on Mac. It's currently
+ # disabled due to build bot struggles, but works fine on local, recent clang
+ # installations.
+ # TODO(crbug.com/openscreen/86): Remove when the Mac bots have been upgraded.
+ force_build_standalone_receiver = false
+}
+
# All compilable non-test targets in the repository (both executables and
# source_sets).
group("gn_all") {
+ testonly = true
+
deps = [
"cast/common:certificate",
"cast/common:channel",
+ "cast/common:public",
+ "cast/receiver:channel",
"cast/sender:channel",
"cast/streaming:receiver",
"cast/streaming:sender",
+ "discovery:common",
"discovery:dnssd",
"discovery:mdns",
+ "discovery:public",
"osp",
"osp/msgs",
"platform",
"third_party/abseil",
"third_party/boringssl",
"third_party/jsoncpp",
+ "third_party/mozilla",
"third_party/tinycbor",
+ "tools/cddl($host_toolchain)",
"util",
]
- if (current_toolchain == host_toolchain) {
- deps += [ "tools/cddl" ]
- }
-
if (use_mdns_responder) {
deps += [ "osp/impl/discovery/mdns:mdns_demo" ]
}
@@ -37,8 +49,8 @@ group("gn_all") {
if (use_chromium_quic) {
deps += [
"third_party/chromium_quic",
- "third_party/chromium_quic:quic_demo_client",
"third_party/chromium_quic:quic_demo_server",
+ "third_party/chromium_quic:quic_streaming_playback_controller",
]
}
@@ -48,14 +60,21 @@ group("gn_all") {
if (!build_with_chromium) {
deps += [
- "third_party/protobuf:protoc",
+ "third_party/protobuf:protoc($host_toolchain)",
"third_party/zlib",
]
+ if (is_posix) {
+ deps += [ "cast/test:make_crl_tests($host_toolchain)" ]
+ }
+
# TODO(crbug.com/openscreen/86): Build for Mac too once the mac buildbot
# compiler is upgraded.
- if (!is_mac) {
- deps += [ "cast/standalone_receiver:cast_receiver" ]
+ if (!is_mac || force_build_standalone_receiver) {
+ deps += [
+ "cast/standalone_receiver:cast_receiver",
+ "cast/standalone_sender:cast_sender",
+ ]
}
}
}
@@ -64,8 +83,10 @@ source_set("openscreen_unittests_all") {
testonly = true
public_deps = [
"cast/common:unittests",
+ "cast/receiver:unittests",
"cast/sender:unittests",
"cast/streaming:unittests",
+ "cast/test:unittests",
"discovery:unittests",
"osp:unittests",
"osp/msgs:unittests",
@@ -93,3 +114,21 @@ if (!build_with_chromium) {
]
}
}
+
+if (!build_with_chromium && is_posix) {
+ source_set("e2e_tests_all") {
+ testonly = true
+ public_deps = [
+ "cast/common:discovery_e2e_test",
+ "cast/test:e2e_tests",
+ "third_party/googletest:gtest_main",
+ ]
+ }
+
+ executable("e2e_tests") {
+ testonly = true
+ deps = [
+ ":e2e_tests_all",
+ ]
+ }
+}
diff --git a/chromium/third_party/openscreen/src/COMMITTERS b/chromium/third_party/openscreen/src/COMMITTERS
index 48b2a70ffb9..770a6ccd23f 100644
--- a/chromium/third_party/openscreen/src/COMMITTERS
+++ b/chromium/third_party/openscreen/src/COMMITTERS
@@ -5,6 +5,5 @@ btolsch@chromium.org
# Additional reviewers
jopbha@chromium.org
miu@chromium.org
-pthatcher@chromium.org
rwkeane@chromium.org
yakimakha@chromium.org
diff --git a/chromium/third_party/openscreen/src/DEPS b/chromium/third_party/openscreen/src/DEPS
index 9955ada4180..e4a34906c7e 100644
--- a/chromium/third_party/openscreen/src/DEPS
+++ b/chromium/third_party/openscreen/src/DEPS
@@ -8,140 +8,188 @@
# to list the dependency's destination directory.
use_relative_paths = True
+use_relative_hooks = True
vars = {
- 'boringssl_git': 'https://boringssl.googlesource.com',
- 'chromium_git': 'https://chromium.googlesource.com',
+ 'boringssl_git': 'https://boringssl.googlesource.com',
+ 'chromium_git': 'https://chromium.googlesource.com',
- # TODO(jophba): move to googlesource external for github repos.
- 'github': 'https://github.com',
+ # TODO(jophba): move to googlesource external for github repos.
+ 'github': 'https://github.com',
- # NOTE: Strangely enough, this will be overridden by any _parent_ DEPS, so
- # in Chromium it will correctly be True.
- 'build_with_chromium': False,
+ # NOTE: Strangely enough, this will be overridden by any _parent_ DEPS, so
+ # in Chromium it will correctly be True.
+ 'build_with_chromium': False,
- 'gn_version': 'git_revision:0790d3043387c762a6bacb1ae0a9ebe883188ab2',
- 'checkout_chromium_quic_boringssl': False,
+ 'checkout_chromium_quic_boringssl': False,
- # By default, do not check out openscreen/cast. This can be overridden
- # by custom_vars in .gclient.
- 'checkout_openscreen_cast_internal': False
+ # Needed to download additional clang binaries for processing coverage data
+ # (from binaries with GN arg `use_coverage=true`).
+ 'checkout_clang_coverage_tools': False,
}
deps = {
- 'cast/internal': {
- 'url': 'https://chrome-internal.googlesource.com/openscreen/cast.git' +
- '@' + '703984f9d1674c2cfc259904a5a7fba4990cca4b',
- 'condition': 'checkout_openscreen_cast_internal',
- },
-
- 'buildtools': {
- 'url': Var('chromium_git')+ '/chromium/src/buildtools' +
- '@' + '140e4d7c45ffb55ce5dc4d11a0c3938363cd8257',
- 'condition': 'not build_with_chromium',
- },
-
- 'third_party/protobuf/src': {
- 'url': Var('chromium_git') +
- '/external/github.com/protocolbuffers/protobuf.git' +
- '@' + 'd09d649aea36f02c03f8396ba39a8d4db8a607e4', # version 3.10.1
- 'condition': 'not build_with_chromium',
- },
-
- 'third_party/zlib/src': {
- 'url': Var('github') +
- '/madler/zlib.git' +
- '@' + 'cacf7f1d4e3d44d871b605da3b647f07d718623f', # version 1.2.11
- 'condition': 'not build_with_chromium',
- },
-
- 'third_party/jsoncpp/src': {
- 'url': Var('chromium_git') +
- '/external/github.com/open-source-parsers/jsoncpp.git' +
- '@' + '2eb20a938c454411c1d416caeeb2a6511daab5cb', # version 1.9.0
- 'condition': 'not build_with_chromium',
- },
-
- 'third_party/googletest/src': {
- 'url': Var('chromium_git') +
- '/external/github.com/google/googletest.git' +
- '@' + '8697709e0308af4cd5b09dc108480804e5447cf0',
- 'condition': 'not build_with_chromium',
- },
-
- 'third_party/mDNSResponder/src': {
- 'url': Var('github') + '/jevinskie/mDNSResponder.git' +
- '@' + '2942dde61f920fbbf96ff9a3840567ebbe7cb1b6',
- 'condition': 'not build_with_chromium',
- },
-
- 'third_party/boringssl/src': {
- 'url' : Var('boringssl_git') + '/boringssl.git' +
- '@' + '6410e18e9190b6b0c71955119fbf3cae1b9eedb7',
- 'condition': 'not build_with_chromium',
- },
-
- 'third_party/chromium_quic/src': {
- 'url': Var('chromium_git') + '/openscreen/quic.git' +
- '@' + 'b73bd98ac9eaedf01a732b1933f97112cf247d93',
- 'condition': 'not build_with_chromium',
- },
-
- 'third_party/tinycbor/src':
- Var('chromium_git') + '/external/github.com/intel/tinycbor.git' +
- '@' + '755f9ef932f9830a63a712fd2ac971d838b131f1',
-
- 'third_party/abseil/src': {
- 'url': Var('chromium_git') +
- '/external/github.com/abseil/abseil-cpp.git' +
- '@' + '20de2db748ca0471cfb61cb53e813dd12938c12b',
- 'condition': 'not build_with_chromium',
- },
+ # NOTE: This commit hash here references a repository/branch that is a mirror
+ # of the commits to the buildtools directory in the Chromium repository. This
+ # should be regularly updated with the tip of the MIRRORED master branch,
+ # found here:
+ # https://chromium.googlesource.com/chromium/src/buildtools/+/refs/heads/master.
+ 'buildtools': {
+ 'url': Var('chromium_git')+ '/chromium/src/buildtools' +
+ '@' + '8d2132841536523249669813b928e29144d487f9',
+ 'condition': 'not build_with_chromium',
+ },
+
+ 'third_party/protobuf/src': {
+ 'url': Var('chromium_git') +
+ '/external/github.com/protocolbuffers/protobuf.git' +
+ '@' + 'd09d649aea36f02c03f8396ba39a8d4db8a607e4', # version 3.10.1
+ 'condition': 'not build_with_chromium',
+ },
+
+ 'third_party/zlib/src': {
+ 'url': Var('github') +
+ '/madler/zlib.git' +
+ '@' + 'cacf7f1d4e3d44d871b605da3b647f07d718623f', # version 1.2.11
+ 'condition': 'not build_with_chromium',
+ },
+
+ 'third_party/jsoncpp/src': {
+ 'url': Var('chromium_git') +
+ '/external/github.com/open-source-parsers/jsoncpp.git' +
+ '@' + '2eb20a938c454411c1d416caeeb2a6511daab5cb', # version 1.9.0
+ 'condition': 'not build_with_chromium',
+ },
+
+ 'third_party/googletest/src': {
+ 'url': Var('chromium_git') +
+ '/external/github.com/google/googletest.git' +
+ '@' + '8697709e0308af4cd5b09dc108480804e5447cf0',
+ 'condition': 'not build_with_chromium',
+ },
+
+ 'third_party/mDNSResponder/src': {
+ 'url': Var('github') + '/jevinskie/mDNSResponder.git' +
+ '@' + '2942dde61f920fbbf96ff9a3840567ebbe7cb1b6',
+ 'condition': 'not build_with_chromium',
+ },
+
+ 'third_party/boringssl/src': {
+ 'url' : Var('boringssl_git') + '/boringssl.git' +
+ '@' + '6410e18e9190b6b0c71955119fbf3cae1b9eedb7',
+ 'condition': 'not build_with_chromium',
+ },
+
+ 'third_party/chromium_quic/src': {
+ 'url': Var('chromium_git') + '/openscreen/quic.git' +
+ '@' + 'd2363edc9f2a8561c9d02e836262f2d03de2d6e1',
+ 'condition': 'not build_with_chromium',
+ },
+
+ 'third_party/tinycbor/src':
+ Var('chromium_git') + '/external/github.com/intel/tinycbor.git' +
+ '@' + '755f9ef932f9830a63a712fd2ac971d838b131f1',
+
+ 'third_party/abseil/src': {
+ 'url': Var('chromium_git') +
+ '/external/github.com/abseil/abseil-cpp.git' +
+ '@' + '20de2db748ca0471cfb61cb53e813dd12938c12b',
+ 'condition': 'not build_with_chromium',
+ },
+ 'third_party/libfuzzer/src': {
+ 'url': Var('chromium_git') +
+ '/chromium/llvm-project/compiler-rt/lib/fuzzer.git' +
+ '@' + 'debe7d2d1982e540fbd6bd78604bf001753f9e74',
+ 'condition': 'not build_with_chromium',
+ },
}
+hooks = [
+ {
+ 'name': 'clang_update_script',
+ 'pattern': '.',
+ 'condition': 'not build_with_chromium',
+ 'action': [ 'python', 'tools/download-clang-update-script.py',
+ '--output', 'tools/clang/scripts/update.py' ],
+ # NOTE: This file appears in .gitignore, as it is not a part of the
+ # openscreen repo.
+ },
+ {
+ 'name': 'update_clang',
+ 'pattern': '.',
+ 'condition': 'not build_with_chromium',
+ 'action': [ 'python', 'tools/clang/scripts/update.py' ],
+ },
+ {
+ 'name': 'clang_coverage_tools',
+ 'pattern': '.',
+ 'condition': 'not build_with_chromium and checkout_clang_coverage_tools',
+ 'action': ['python', 'tools/clang/scripts/update.py',
+ '--package=coverage_tools'],
+ },
+ {
+ 'name': 'clang_format_linux64',
+ 'pattern': '.',
+ 'action': [ 'download_from_google_storage.py', '--no_resume', '--no_auth',
+ '--bucket', 'chromium-clang-format',
+ '-s', 'buildtools/linux64/clang-format.sha1' ],
+ 'condition': 'host_os == "linux" and not build_with_chromium',
+ },
+ {
+ 'name': 'clang_format_mac',
+ 'pattern': '.',
+ 'action': [ 'download_from_google_storage.py', '--no_resume', '--no_auth',
+ '--bucket', 'chromium-clang-format',
+ '-s', 'buildtools/mac/clang-format.sha1' ],
+ 'condition': 'host_os == "mac" and not build_with_chromium',
+ },
+]
+
recursedeps = [
- 'third_party/chromium_quic/src',
- 'buildtools',
+ 'third_party/chromium_quic/src',
+ 'buildtools',
]
include_rules = [
- '+build/config/features.h',
- '+util',
- '+platform/api',
- '+platform/base',
- '+platform/test',
- '+third_party',
-
- # Don't include abseil from the root so the path can change via include_dirs
- # rules when in Chromium.
- '-third_party/abseil',
-
- # Abseil whitelist.
- '+absl/algorithm/container.h',
- '+absl/base/thread_annotations.h',
- '+absl/hash/hash.h',
- '+absl/strings/ascii.h',
- '+absl/strings/match.h',
- '+absl/strings/numbers.h',
- '+absl/strings/str_cat.h',
- '+absl/strings/str_join.h',
- '+absl/strings/str_split.h',
- '+absl/strings/string_view.h',
- '+absl/strings/substitute.h',
- '+absl/types/optional.h',
- '+absl/types/span.h',
- '+absl/types/variant.h',
-
- # Similar to abseil, don't include boringssl using root path. Instead,
- # explicitly allow 'openssl' where needed.
- '-third_party/boringssl',
-
- # Test framework includes.
- "-third_party/googletest",
- "+gtest",
- "+gmock",
+ '+build/config/features.h',
+ '+util',
+ '+platform/api',
+ '+platform/base',
+ '+platform/test',
+ '+testing/util',
+ '+third_party',
+
+ # Don't include abseil from the root so the path can change via include_dirs
+ # rules when in Chromium.
+ '-third_party/abseil',
+
+ # Abseil whitelist.
+ '+absl/algorithm/container.h',
+ '+absl/base/thread_annotations.h',
+ '+absl/hash/hash.h',
+ '+absl/strings/ascii.h',
+ '+absl/strings/match.h',
+ '+absl/strings/numbers.h',
+ '+absl/strings/str_cat.h',
+ '+absl/strings/str_join.h',
+ '+absl/strings/str_replace.h',
+ '+absl/strings/str_split.h',
+ '+absl/strings/string_view.h',
+ '+absl/strings/substitute.h',
+ '+absl/types/optional.h',
+ '+absl/types/span.h',
+ '+absl/types/variant.h',
+
+ # Similar to abseil, don't include boringssl using root path. Instead,
+ # explicitly allow 'openssl' where needed.
+ '-third_party/boringssl',
+
+ # Test framework includes.
+ "-third_party/googletest",
+ "+gtest",
+ "+gmock",
]
skip_child_includes = [
- 'third_party/chromium_quic',
+ 'third_party/chromium_quic',
]
diff --git a/chromium/third_party/openscreen/src/PRESUBMIT.sh b/chromium/third_party/openscreen/src/PRESUBMIT.sh
index 8181f11550f..6c4e9009f74 100755
--- a/chromium/third_party/openscreen/src/PRESUBMIT.sh
+++ b/chromium/third_party/openscreen/src/PRESUBMIT.sh
@@ -59,11 +59,6 @@ if [[ "$invoker" != 'python' ]]; then
echo "This shouldn't be invoked directly, please use \`git cl presubmit\`."
fi
-# TODO(jophba): check in a better fix for the build bots.
-if command -v clang-format &> /dev/null; then
- tools/install-build-tools.sh &> /dev/null
-fi
-
for f in $(git diff --name-only --diff-filter=d origin/master); do
# Skip third party files, except our custom BUILD.gns
if [[ $f =~ third_party/[^\/]*/src ]]; then
diff --git a/chromium/third_party/openscreen/src/README.md b/chromium/third_party/openscreen/src/README.md
index 17fb84378c8..a03d6d419ab 100644
--- a/chromium/third_party/openscreen/src/README.md
+++ b/chromium/third_party/openscreen/src/README.md
@@ -26,22 +26,28 @@ lint` and `git cl upload.`
## Checking out code
-From the parent directory of where you want the openscreen checkout, configure
-`gclient` and check out openscreen with the following commands:
+From the parent directory of where you want the openscreen checkout (e.g.,
+`~/my_project_dir`), configure `gclient` and check out openscreen with the
+following commands:
```bash
+ cd ~/my_project_dir
gclient config https://chromium.googlesource.com/openscreen
gclient sync
```
-Now, you should have `openscreen/` repository checked out, with all dependencies
-checked out to their appropriate revisions.
+The first `gclient` command will create a default .gclient file in
+`~/my_project_dir` that describes how to pull down the `openscreen` repository.
+The second command creates an `openscreen/` subdirectory, downloads the source
+code, all third-party dependencies, and the toolchain needed to build things;
+and at their appropriate revisions.
## Syncing your local checkout
To update your local checkout from the openscreen master repository, just run
```bash
+ cd ~/my_project_dir/openscreen
git pull
gclient sync
```
@@ -51,24 +57,23 @@ dependencies that have changed.
# Build setup
-## Installing build dependencies
-
-The following tools are required for building:
+The following are the main tools are required for development/builds:
- Build file generator: `gn`
- - Code formatter (optional): `clang-format`
+ - Code formatter: `clang-format`
- Builder: `ninja` ([GitHub releases](https://github.com/ninja-build/ninja/releases))
+ - Compiler/Linker: `clang` (installed by default) or `gcc` (installed by you)
-`clang-format` and `ninja` can be downloaded to `buildtools/<platform>` root by
-running `./tools/install-build-tools.sh`.
-
-`clang-format` is only used for presubmit checks and optionally used on
-generated code from the CDDL tool.
+All of these--except `gcc` as noted above--are automatically downloaded/updated
+for the Linux and Mac environments via `gclient sync` as described above. The
+first two are installed into `buildtools/<platform>/`.
-`gn` will be installed in `buildtools/<platform>/` automatically by `gclient sync`.
+Mac only: XCode must be installed on the system, to link against its frameworks.
-You also need to ensure that you have the compiler and its toolchain dependencies.
-Currently, both Linux and Mac OS X build configurations use clang by default.
+`clang-format` is used for maintaining consistent coding style, but it is not a
+complete replacement for adhering to Chromium/Google C++ style (that's on you!).
+The presubmit script will sanity-check that it has been run on all new/changed
+code.
## Linux clang
@@ -153,6 +158,11 @@ the working directory for the build. So the same could be done as follows:
After editing a file, only `ninja` needs to be rerun, not `gn`. If you have
edited a `BUILD.gn` file, `ninja` will re-run `gn` for you.
+Unless you like to wait longer than necessary for builds to complete, run
+`autoninja` instead of `ninja`, which takes the same command-line arguments.
+This will automatically parallelize the build for your system, depending on
+number of processor cores, RAM, etc.
+
For details on running `demo`, see its [README.md](demo/README.md).
## Building other targets
@@ -174,6 +184,25 @@ the build flags available.
./out/Default/unittests
```
+## Building and running fuzzers
+
+In order to build fuzzers, you need the GN arg `use_libfuzzer=true`. It's also
+recommended to build with `is_asan=true` to catch additional problems. Building
+and running then might look like:
+```bash
+ gn gen out/libfuzzer --args="use_libfuzzer=true is_asan=true is_debug=false"
+ ninja -C out/libfuzzer some_fuzz_target
+ out/libfuzzer/some_fuzz_target <args> <corpus_dir> [additional corpus dirs]
+```
+
+The arguments to the fuzzer binary should be whatever is listed in the GN target
+description (e.g. `-max_len=1500`). These arguments may be automatically
+scraped by Chromium's ClusterFuzz tool when it runs fuzzers, but they are not
+built into the target. You can also look at the file
+`out/libfuzzer/some_fuzz_target.options` for what arguments should be used. The
+`corpus_dir` is listed as `seed_corpus` in the GN definition of the fuzzer
+target.
+
# Continuous build and try jobs
openscreen uses [LUCI builders](https://ci.chromium.org/p/openscreen/builders)
@@ -213,11 +242,13 @@ review tool) and is recommended for pushing patches for review. Once you have
committed changes locally, simply run:
```bash
+ git cl format
git cl upload
```
-This will run our `PRESUBMIT.sh` script to check style, and if it passes, a new
-code review will be posted on `chromium-review.googlesource.com`.
+The first command will will auto-format the code changes. Then, the second
+command runs the `PRESUBMIT.sh` script to check style and, if it passes, a
+newcode review will be posted on `chromium-review.googlesource.com`.
If you make additional commits to your local branch, then running `git cl
upload` again in the same branch will merge those commits into the ongoing
@@ -256,3 +287,76 @@ After your patch has received one or more LGTM commit it by clicking the
`SUBMIT` button (or, confusingly, `COMMIT QUEUE +2`) in Gerrit. This will run
your patch through the builders again before committing to the main openscreen
repository.
+
+<!-- TODO(mfoltz): split up README.md into more manageable files. -->
+## Working with ARM/ARM64/the Raspberry PI
+
+openscreen supports cross compilation for both arm32 and arm64 platforms, by
+using the `gn args` parameter `target_cpu="arm"` or `target_cpu="arm64"`
+respectively. Note that quotes are required around the target arch value.
+
+Setting an arm(64) target_cpu causes GN to pull down a sysroot from openscreen's
+public cloud storage bucket. Google employees may update the sysroots stored
+by requesting access to the Open Screen pantheon project and uploading a new
+tar.xz to the openscreen-sysroots bucket.
+
+NOTE: The "arm" image is taken from Chromium's debian arm image, however it has
+been manually patched to include support for libavcodec and libsdl2. To update
+this image, the new image must be manually patched to include the necessary
+header and library dependencies. Note that if the versions of libavcodec and
+libsdl2 are too out of sync from the copies in the sysroot, compilation will
+succeed, but you may experience issues decoding content.
+
+To install the last known good version of the libavcodec and libsdl packages
+on a Raspberry Pi, you can run the following command:
+
+```bash
+sudo ./cast/standalone_receiver/install_demo_deps_raspian.sh
+```
+
+NOTE: until [Issue 106](http://crbug.com/openscreen/106) is resolved, you may
+experience issues streaming to a Raspberry Pi if multiple network interfaces
+(e.g. WiFi + Ethernet) are enabled. The workaround is to disable either the WiFi
+or ethernet connection.
+
+## Code Coverage
+
+Code coverage can be checked using clang's source-based coverage tools. You
+must use the GN argument `use_coverage=true`. It's recommended to do this in a
+separate output directory since the added instrumentation will affect
+performance and generate an output file every time a binary is run. You can
+read more about this in [clang's
+documentation](http://clang.llvm.org/docs/SourceBasedCodeCoverage.html) but the
+bare minimum steps are also outlined below. You will also need to download the
+pre-built clang coverage tools, which are not downloaded by default. The
+easiest way to do this is to set a custom variable in your `.gclient` file.
+Under the "openscreen" solution, add:
+```python
+ "custom_vars": {
+ "checkout_clang_coverage_tools": True,
+ },
+```
+then run `gclient runhooks`. You can also run the python command from the
+`clang_coverage_tools` hook in `//DEPS` yourself or even download the tools
+manually
+([link](https://storage.googleapis.com/chromium-browser-clang-staging/)).
+
+Once you have your GN directory (we'll call it `out/coverage`) and have
+downloaded the tools, do the following to generate an HTML coverage report:
+```bash
+out/coverage/openscreen_unittests
+third_party/llvm-build/Release+Asserts/bin/llvm-profdata merge -sparse default.profraw -o foo.profdata
+third_party/llvm-build/Release+Asserts/bin/llvm-cov show out/coverage/openscreen_unittests -instr-profile=foo.profdata -format=html -output-dir=<out dir> [filter paths]
+```
+There are a few things to note here:
+ - `default.profraw` is generated by running the instrumented code, but
+ `foo.profdata` can be any path you want.
+ - `<out dir>` should be an empty directory for placing the generated HTML
+ files. You can view the report at `<out dir>/index.html`.
+ - `[filter paths]` is a list of paths to which you want to limit the coverage
+ report. For example, you may want to limit it to cast/ or even
+ cast/streaming/. If this list is empty, all data will be in the report.
+
+The same process can be used to check the coverage of a fuzzer's corpus. Just
+add `-runs=0` to the fuzzer arguments to make sure it only runs the existing
+corpus then exits.
diff --git a/chromium/third_party/openscreen/src/build/config/BUILD.gn b/chromium/third_party/openscreen/src/build/config/BUILD.gn
index 27f29ed5dd2..ec988061c61 100644
--- a/chromium/third_party/openscreen/src/build/config/BUILD.gn
+++ b/chromium/third_party/openscreen/src/build/config/BUILD.gn
@@ -2,6 +2,9 @@
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
+import("//build/config/arm.gni")
+import("//build/config/sysroot.gni")
+
declare_args() {
# Enable OSP_DCHECKs to be compiled in, even if it's not a debug build.
dcheck_always_on = false
@@ -11,8 +14,16 @@ declare_args() {
# Enable thread sanitizer.
is_tsan = false
+
+ # Must be enabled for fuzzing targets.
+ use_libfuzzer = false
+
+ # Enables clang's source-based coverage (requires is_clang=true).
+ use_coverage = false
}
+assert(!use_coverage || is_clang)
+
config("compiler_defaults") {
cflags = []
if (is_posix && !is_mac) {
@@ -35,6 +46,31 @@ config("compiler_defaults") {
}
}
+config("compiler_cpu_abi") {
+ cflags = []
+ ldflags = []
+
+ if (current_cpu == "x64") {
+ # These are explicitly specified in case of cross-compiling.
+ cflags += [ "-m64" ]
+ ldflags += [ "-m64" ]
+ } else if (current_cpu == "x86") {
+ cflags += [ "-m32" ]
+ ldflags += [ "-m32" ]
+ } else if (current_cpu == "arm") {
+ cflags += [
+ "--target=arm-linux-gnueabihf",
+ "-march=$arm_arch",
+ "-mfloat-abi=$arm_float_abi",
+ "-mtune=$arm_tune",
+ ]
+ ldflags += [ "--target=arm-linux-gnueabihf" ]
+ } else if (current_cpu == "arm64") {
+ cflags += [ "--target=aarch64-linux-gnu" ]
+ ldflags += [ "--target=aarch64-linux-gnu" ]
+ }
+}
+
config("no_exceptions") {
# -fno-exceptions causes the compiler to choose the implementation of the STL
# that uses abort() calls instead of throws, as well as issue compile errors
@@ -166,4 +202,54 @@ config("default_sanitizers") {
cflags += [ "-fsanitize=thread" ]
ldflags += [ "-fsanitize=thread" ]
}
+
+ if (use_libfuzzer) {
+ cflags += [ "-fsanitize=fuzzer-no-link" ]
+ if (!is_asan) {
+ ldflags += [ "-fsanitize=address" ]
+ }
+ }
+}
+
+config("default_coverage") {
+ cflags = []
+ ldflags = []
+
+ if (use_coverage) {
+ cflags += [
+ "-fprofile-instr-generate",
+ "-fcoverage-mapping",
+ ]
+ ldflags += [ "-fprofile-instr-generate" ]
+ }
+}
+
+config("sysroot_runtime_libraries") {
+ if (sysroot != "") {
+ # As other sysroot CPU targets get added, they should be checked here.
+ assert(current_cpu == "arm" || current_cpu == "arm64")
+ assert(is_clang)
+ sysroot_path = rebase_path(sysroot, root_build_dir)
+ flags = [ "--sysroot=" + sysroot_path ]
+ hash = exec_script("//build/scripts/install-sysroot.py",
+ [
+ "--print-hash",
+ "$current_cpu",
+ ],
+ "trim string",
+ [ "//build/scripts/sysroots.json" ])
+
+ # GN uses this to know that the sysroot is "dirty"
+ defines = [ "SYSROOT_HASH=$hash" ]
+ ldflags = flags
+ cflags = flags
+
+ ld_paths = exec_script("//build/scripts/sysroot_ld_path.py",
+ [ sysroot_path ],
+ "list lines")
+ foreach(ld_path, ld_paths) {
+ ld_path = rebase_path(ld_path, root_build_dir)
+ ldflags += [ "-L" + ld_path ]
+ }
+ }
}
diff --git a/chromium/third_party/openscreen/src/build/config/BUILDCONFIG.gn b/chromium/third_party/openscreen/src/build/config/BUILDCONFIG.gn
index 359ae15c6fb..4e69143f537 100644
--- a/chromium/third_party/openscreen/src/build/config/BUILDCONFIG.gn
+++ b/chromium/third_party/openscreen/src/build/config/BUILDCONFIG.gn
@@ -68,33 +68,54 @@ declare_args() {
# gcc compiler on Linux instead, set is_gcc to true.
is_gcc = false
clang_base_path = default_clang_base_path
+
+ # This would not normally be set as a build argument, but rather is used as a
+ # default value during the first parse of this config. All other toolchains
+ # that cause this file to be re-parsed will already have this set. For
+ # further explanation, see
+ # https://gn.googlesource.com/gn/+/refs/heads/master/docs/reference.md#toolchain-overview
+ host_toolchain = ""
}
declare_args() {
is_clang = !is_gcc
}
-# We need to ensure that clang is pulled down using the update script. In
-# Chromium, this is done with a gclient hook, but we can just call
-if (is_clang) {
- exec_script("//tools/clang/scripts/update.py")
-}
-
# ==============================================================================
# TOOLCHAIN SETUP
# ==============================================================================
#
-# Here we set the default toolchain. Currently only Mac and POSIX are defined.
-host_toolchain = ""
-if (current_os == "chromeos" || current_os == "linux") {
- host_toolchain = "//build/toolchain/linux:linux"
-} else if (current_os == "mac") {
- host_toolchain = "//build/toolchain/mac:clang"
+# Here we set the host and default toolchains. Currently only Mac and POSIX are
+# defined.
+if (host_toolchain == "") {
+ if (current_os == "chromeos" || current_os == "linux") {
+ if (is_clang) {
+ host_toolchain = "//build/toolchain/linux:clang_$host_cpu"
+ } else {
+ host_toolchain = "//build/toolchain/linux:gcc_$host_cpu"
+ }
+ } else if (current_os == "mac") {
+ host_toolchain = "//build/toolchain/mac:clang"
+ } else {
+ # TODO(miu): Windows, and others.
+ assert(false, "Toolchain for current_os is not defined.")
+ }
+}
+
+_default_toolchain = ""
+if (target_os == "chromeos" || target_os == "linux") {
+ if (is_clang) {
+ _default_toolchain = "//build/toolchain/linux:clang_$target_cpu"
+ } else {
+ _default_toolchain = "//build/toolchain/linux:gcc_$target_cpu"
+ }
+} else if (target_os == "mac") {
+ assert(host_os == "mac", "Cross-compiling on Mac is not supported.")
+ _default_toolchain = "//build/toolchain/mac:clang"
} else {
- # TODO(miu): Windows, and others.
- assert(false, "Toolchain for current_os is not defined.")
+ assert(false, "Toolchain for target_os is not defined.")
}
-set_default_toolchain(host_toolchain)
+set_default_toolchain(_default_toolchain)
# =============================================================================
# OS DEFINITIONS
@@ -132,8 +153,11 @@ _shared_binary_target_configs = [
"//build/config:no_rtti",
"//build/config:symbol_visibility_hidden",
"//build/config:default_sanitizers",
+ "//build/config:default_coverage",
"//build/config:compiler_defaults",
+ "//build/config:compiler_cpu_abi",
"//build/config:default_optimization",
+ "//build/config:sysroot_runtime_libraries",
]
# Apply that default list to the binary target types.
diff --git a/chromium/third_party/openscreen/src/build/config/arm.gni b/chromium/third_party/openscreen/src/build/config/arm.gni
new file mode 100644
index 00000000000..e52b20d14bd
--- /dev/null
+++ b/chromium/third_party/openscreen/src/build/config/arm.gni
@@ -0,0 +1,45 @@
+# Copyright 2019 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+declare_args() {
+ # Version of the ARM processor when compiling on ARM. Ignored on non-ARM
+ # platforms.
+ arm_version = 7
+
+ # The ARM architecture. This will be a string like "armv6" or "armv7-a".
+ # An empty string means to use the default for the arm_version. Getting
+ # a proper list of supported architectures is challenging for clang, but
+ # can be found in the triple.h header in the LLVM source, under the
+ # SubArchType enum.
+ arm_arch = "armv7-a"
+
+ # The ARM floating point hardware. This will be a string like "neon" or
+ # "vfpv3".
+ arm_fpu = "vfpv3-d16"
+
+ # The ARM floating point mode. This is either the string "hard", "soft", or
+ # "softfp".
+ arm_float_abi = "hard"
+
+ # The ARM variant-specific tuning mode. This will be a string like "armv6"
+ # or "cortex-a15". Each cpu-type has a different tuning value.
+ arm_tune = "generic-armv7-a"
+}
+
+declare_args() {
+ # Whether to use the neon FPU instruction set or not. Actual value is set
+ # below, based on the arm_fpu argument.
+ arm_use_neon = arm_fpu == "neon"
+}
+
+if (current_cpu == "arm64") {
+ # arm64 supports only "hard".
+ arm_float_abi = "hard"
+ arm_fpu = "neon"
+ arm_use_neon = true
+}
+
+assert(arm_float_abi == "hard" || arm_float_abi == "soft" ||
+ arm_float_abi == "softfp")
+assert(arm_version == 6 || arm_version == 7 || arm_version == 8)
diff --git a/chromium/third_party/openscreen/src/build/config/external_libraries.gni b/chromium/third_party/openscreen/src/build/config/external_libraries.gni
new file mode 100644
index 00000000000..a451add7e19
--- /dev/null
+++ b/chromium/third_party/openscreen/src/build/config/external_libraries.gni
@@ -0,0 +1,48 @@
+# Copyright 2019 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+declare_args() {
+ # FFMPEG: If installed on the system, set have_ffmpeg to true. This also
+ # requires the FFMPEG headers be installed. On Debian-like systems, this can
+ # be done by running `cast/standalone_receiver/install_demo_deps_debian.sh`
+ # to install both FFMPEG and libSDL.
+ have_ffmpeg = false
+ ffmpeg_libs = [
+ "avcodec",
+ "avformat",
+ "avutil",
+ "swresample",
+ ]
+ ffmpeg_include_dirs = [] # Add only if headers are at non-standard locations.
+ ffmpeg_lib_dirs = [] # Add only if libraries are at non-standard locations.
+
+ # libopus: If installed on the system, set have_libopus to true. This also
+ # requires the libopus headers be installed. For example, on Debian-like
+ # systems, the following should install everything needed:
+ #
+ # sudo apt-get install libopus0 libopus-dev
+ have_libopus = false
+ libopus_libs = [ "opus" ]
+ libopus_include_dirs = [] # Add only if headers are at non-standard locations.
+ libopus_lib_dirs = [] # Add only if libraries are at non-standard locations.
+
+ # libsdl2: If installed on the system, set have_libsdl2 to true. This also
+ # requires the libSDL2 headers be installed. On Debian-like systems, this can
+ # be done by running `cast/standalone_receiver/install_demo_deps_debian.sh`
+ # to install both FFMPEG and libSDL.
+ have_libsdl2 = false
+ libsdl2_libs = [ "SDL2" ]
+ libsdl2_include_dirs = [] # Add only if headers are at non-standard locations.
+ libsdl2_lib_dirs = [] # Add only if libraries are at non-standard locations.
+
+ # libvpx: If installed on the system, set have_libvpx to true. This also
+ # requires the libvpx headers be installed. For example, on Debian-like
+ # systems, the following should install everything needed:
+ #
+ # sudo apt-get install libvpx5 libvpx-dev
+ have_libvpx = false
+ libvpx_libs = [ "vpx" ]
+ libvpx_include_dirs = [] # Add only if headers are at non-standard locations.
+ libvpx_lib_dirs = [] # Add only if libraries are at non-standard locations.
+}
diff --git a/chromium/third_party/openscreen/src/build/config/sysroot.gni b/chromium/third_party/openscreen/src/build/config/sysroot.gni
new file mode 100644
index 00000000000..1434a5d670d
--- /dev/null
+++ b/chromium/third_party/openscreen/src/build/config/sysroot.gni
@@ -0,0 +1,35 @@
+# Copyright 2019 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file
+
+# This header file defines the "sysroot" variable which is the absolute path
+# of the sysroot. If no sysroot applies, the variable will be an empty string.
+
+declare_args() {
+ sysroot = ""
+
+ # The relative path to directory containing sysroot images
+ target_sysroot_dir = "../"
+
+ use_sysroot = current_cpu == "arm" || current_cpu == "arm64"
+}
+
+if (use_sysroot) {
+ # By default build against a sysroot image downloaded from Cloud Storage
+ # during gclient runhooks.
+ if (current_cpu == "arm") {
+ sysroot = "$target_sysroot_dir/debian_sid_arm-sysroot"
+ } else if (current_cpu == "arm64") {
+ sysroot = "$target_sysroot_dir/debian_sid_arm64-sysroot"
+ } else {
+ assert(false, "No linux sysroot for cpu: $target_cpu")
+ }
+ _script_arch = current_cpu
+
+ if (exec_script("//build/scripts/dir_exists.py",
+ [ rebase_path(sysroot) ],
+ "string") != "True") {
+ print("Missing sysroot for $current_cpu, downloading...")
+ exec_script("//build/scripts/install-sysroot.py", [ "$current_cpu" ])
+ }
+}
diff --git a/chromium/third_party/openscreen/src/build/scripts/dir_exists.py b/chromium/third_party/openscreen/src/build/scripts/dir_exists.py
new file mode 100755
index 00000000000..1e633d22b00
--- /dev/null
+++ b/chromium/third_party/openscreen/src/build/scripts/dir_exists.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+
+# Copyright 2019 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""
+Writes True if the argument is a directory.
+"""
+
+from __future__ import print_function
+
+import os.path
+import sys
+
+
+def main():
+ print(is_dir(sys.argv[1]), end='')
+ return 0
+
+
+def is_dir(dir_name):
+ return str(os.path.isdir(dir_name))
+
+
+def DoMain(args):
+ """Hook to be called from gyp without starting a separate python
+ interpreter."""
+ return is_dir(args[0])
+
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/chromium/third_party/openscreen/src/build/scripts/install-sysroot.py b/chromium/third_party/openscreen/src/build/scripts/install-sysroot.py
new file mode 100755
index 00000000000..8bf70b7092d
--- /dev/null
+++ b/chromium/third_party/openscreen/src/build/scripts/install-sysroot.py
@@ -0,0 +1,172 @@
+#!/usr/bin/env python2
+
+# Copyright 2019 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""
+Install Debian sysroots for cross compiling Open Screen.
+"""
+
+# The sysroot is needed to ensure that binaries that get built will run on
+# the oldest stable version of Debian that we currently support.
+# This script can be run manually but is more often run as part of gclient
+# hooks. When run from hooks this script is a no-op on non-linux platforms.
+
+# The sysroot image could be constructed from scratch based on the current state
+# of the Debian archive but for consistency we use a pre-built root image (we
+# don't want upstream changes to Debian to affect the build until we
+# choose to pull them in). The sysroot images are stored in Chrome's common
+# data storage, and the sysroots.json file should be kept in sync with Chrome's
+# copy of it.
+
+from __future__ import print_function
+
+import hashlib
+import json
+import platform
+import argparse
+import os
+import re
+import shutil
+import subprocess
+import sys
+try:
+ # For Python 3.0 and later
+ from urllib.request import urlopen
+except ImportError:
+ # Fall back to Python 2's urllib2
+ from urllib2 import urlopen
+
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+PARENT_DIR = os.path.dirname(SCRIPT_DIR)
+URL_PREFIX = 'https://storage.googleapis.com'
+URL_PATH = 'openscreen-sysroots'
+
+VALID_ARCHS = ('arm', 'arm64')
+
+ARCH_TRANSLATIONS = {
+ 'x64': 'amd64',
+ 'x86': 'i386',
+}
+
+DEFAULT_TARGET_PLATFORM = 'sid'
+
+
+class Error(Exception):
+ pass
+
+
+def GetSha1(filename):
+ """Generates a SHA1 hash for validating download. Done in chunks to avoid
+ excess memory usage."""
+ BLOCKSIZE = 1024 * 1024
+ sha1 = hashlib.sha1()
+ with open(filename, 'rb') as f:
+ chunk = f.read(BLOCKSIZE)
+ while chunk:
+ sha1.update(chunk)
+ chunk = f.read(BLOCKSIZE)
+ return sha1.hexdigest()
+
+
+def GetSysrootDict(target_platform, target_arch):
+ """Gets the sysroot information for a given platform and arch from the sysroots.json
+ file."""
+ if target_arch not in VALID_ARCHS:
+ raise Error('Unknown architecture: %s' % target_arch)
+
+ sysroots_file = os.path.join(SCRIPT_DIR, 'sysroots.json')
+ sysroots = json.load(open(sysroots_file))
+ sysroot_key = '%s_%s' % (target_platform, target_arch)
+ if sysroot_key not in sysroots:
+ raise Error('No sysroot for: %s' % (sysroot_key))
+ return sysroots[sysroot_key]
+
+
+def DownloadFile(url, local_path):
+ """Uses urllib to download a remote file into local_path."""
+ for _ in range(3):
+ try:
+ response = urlopen(url)
+ with open(local_path, "wb") as f:
+ f.write(response.read())
+ break
+ except Exception:
+ pass
+ else:
+ raise Error('Failed to download %s' % url)
+
+def ValidateFile(local_path, expected_sum):
+ """Generates the SHA1 hash of a local file to compare with an expected hashsum."""
+ sha1sum = GetSha1(local_path)
+ if sha1sum != expected_sum:
+ raise Error('Tarball sha1sum is wrong.'
+ 'Expected %s, actual: %s' % (expected_sum, sha1sum))
+
+def InstallSysroot(target_platform, target_arch):
+ """Downloads, validates, unpacks, and installs a sysroot image."""
+ sysroot_dict = GetSysrootDict(target_platform, target_arch)
+ tarball_filename = sysroot_dict['Tarball']
+ tarball_sha1sum = sysroot_dict['Sha1Sum']
+
+ sysroot = os.path.join(PARENT_DIR, sysroot_dict['SysrootDir'])
+
+ url = '%s/%s/%s/%s' % (URL_PREFIX, URL_PATH, tarball_sha1sum,
+ tarball_filename)
+
+ stamp = os.path.join(sysroot, '.stamp')
+ if os.path.exists(stamp):
+ with open(stamp) as s:
+ if s.read() == url:
+ return
+
+ if os.path.isdir(sysroot):
+ shutil.rmtree(sysroot)
+ os.mkdir(sysroot)
+
+ tarball_path = os.path.join(sysroot, tarball_filename)
+ DownloadFile(url, tarball_path)
+ ValidateFile(tarball_path, tarball_sha1sum)
+ subprocess.check_call(['tar', 'xf', tarball_path, '-C', sysroot])
+ os.remove(tarball_path)
+
+ with open(stamp, 'w') as s:
+ s.write(url)
+
+
+def parse_args(args):
+ """Parses the passed in arguments into an object."""
+ p = argparse.ArgumentParser()
+ p.add_argument(
+ 'arch',
+ nargs=1,
+ help='Sysroot architecture: %s' % ', '.join(VALID_ARCHS))
+ p.add_argument(
+ '--print-hash', action="store_true",
+ help='Print the hash of the sysroot for the specified arch.')
+
+ return p.parse_args(args)
+
+
+def main(args):
+ if not (sys.platform.startswith('linux') or sys.platform == 'darwin'):
+ print('Unsupported platform. Only Linux and Mac OS X are supported.')
+ return 1
+
+ parsed_args = parse_args(args)
+ arch = ARCH_TRANSLATIONS.get(parsed_args.arch[0], parsed_args.arch[0])
+
+ if parsed_args.print_hash:
+ print(GetSysrootDict(DEFAULT_TARGET_PLATFORM, arch)['Sha1Sum'])
+
+ InstallSysroot(DEFAULT_TARGET_PLATFORM, arch)
+ return 0
+
+
+if __name__ == '__main__':
+ try:
+ sys.exit(main(sys.argv[1:]))
+ except Error as e:
+ sys.stderr.write('Installing sysroot error: {}\n'.format(e))
+ sys.exit(1)
diff --git a/chromium/third_party/openscreen/src/build/scripts/sysroot_ld_path.py b/chromium/third_party/openscreen/src/build/scripts/sysroot_ld_path.py
new file mode 100755
index 00000000000..8c65861046f
--- /dev/null
+++ b/chromium/third_party/openscreen/src/build/scripts/sysroot_ld_path.py
@@ -0,0 +1,67 @@
+#!/usr/bin/env python
+
+# Copyright 2019 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file
+
+# Replacement for the deprecated sysroot_ld_path.sh implementation in Chrome.
+"""
+Reads etc/ld.so.conf and/or etc/ld.so.conf.d/*.conf and returns the
+appropriate linker flags.
+"""
+
+from __future__ import print_function
+import argparse
+import glob
+import os
+import sys
+
+LD_SO_CONF_REL_PATH = "etc/ld.so.conf"
+LD_SO_CONF_D_REL_PATH = "etc/ld.so.conf.d"
+
+
+def parse_args(args):
+ p = argparse.ArgumentParser(__doc__)
+ p.add_argument('sysroot_path', nargs=1, help='Path to sysroot root folder')
+ return os.path.abspath(p.parse_args(args).sysroot_path[0])
+
+
+def process_entry(sysroot_path, entry):
+ assert (entry.startswith('/'))
+ print(os.path.join(sysroot_path, entry.strip()[1:]))
+
+def process_ld_conf_file(sysroot_path, conf_file_path):
+ with open(conf_file_path, 'r') as f:
+ for line in f.readlines():
+ if line.startswith('#'):
+ continue
+ process_entry(sysroot_path, line)
+
+
+def process_ld_conf_folder(sysroot_path, ld_conf_path):
+ files = glob.glob(os.path.join(ld_conf_path, '*.conf'))
+ for file in files:
+ process_ld_conf_file(sysroot_path, file)
+
+
+def process_ld_conf_files(sysroot_path):
+ conf_path = os.path.join(sysroot_path, LD_SO_CONF_REL_PATH)
+ conf_d_path = os.path.join(sysroot_path, LD_SO_CONF_D_REL_PATH)
+
+ if os.path.isdir(conf_path):
+ process_ld_conf_folder(sysroot_path, conf_path)
+ elif os.path.isdir(conf_d_path):
+ process_ld_conf_folder(sysroot_path, conf_d_path)
+
+
+def main(args):
+ sysroot_path = parse_args(args)
+ process_ld_conf_files(sysroot_path)
+
+
+if __name__ == '__main__':
+ try:
+ sys.exit(main(sys.argv[1:]))
+ except Exception as e:
+ sys.stderr.write(str(e) + '\n')
+ sys.exit(1)
diff --git a/chromium/third_party/openscreen/src/build/scripts/sysroots.json b/chromium/third_party/openscreen/src/build/scripts/sysroots.json
new file mode 100644
index 00000000000..81b98648d0f
--- /dev/null
+++ b/chromium/third_party/openscreen/src/build/scripts/sysroots.json
@@ -0,0 +1,13 @@
+{
+ "sid_arm": {
+ "Sha1Sum": "07ef353ec66ca510a9b24a927db823e44435c764",
+ "SysrootDir": "debian_sid_arm-sysroot",
+ "Tarball": "debian_sid_arm-sysroot-demo-libs.tar.xz"
+ },
+ "sid_arm64": {
+ "Sha1Sum": "2ee3fb715f031df95b079ce6d2c8f5e8ba2cc6e4",
+ "SysrootDir": "debian_sid_arm64-sysroot",
+ "Tarball": "debian_sid_arm64_sysroot.tar.xz"
+ }
+
+}
diff --git a/chromium/third_party/openscreen/src/build/toolchain/linux/BUILD.gn b/chromium/third_party/openscreen/src/build/toolchain/linux/BUILD.gn
index e2205142c51..2767c7c41f7 100644
--- a/chromium/third_party/openscreen/src/build/toolchain/linux/BUILD.gn
+++ b/chromium/third_party/openscreen/src/build/toolchain/linux/BUILD.gn
@@ -2,110 +2,220 @@
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
-toolchain("linux") {
- prefix = rebase_path("$clang_base_path/bin", root_build_dir)
+template("gcc_toolchain") {
+ toolchain(target_name) {
+ assert(defined(invoker.ar), "Caller must define ar command.")
+ assert(defined(invoker.cc), "Caller must define cc command.")
+ assert(defined(invoker.cxx), "Caller must define cxx command.")
+ assert(defined(invoker.ld), "Caller must define ld command.")
+ forward_variables_from(invoker,
+ [
+ "ar",
+ "cc",
+ "cxx",
+ "ld",
+ ])
+
+ toolchain_args = {
+ forward_variables_from(invoker.toolchain_args, "*")
+
+ # The host toolchain needs to be preserved by all secondary toolchains.
+ # For futher explanation, see
+ # https://gn.googlesource.com/gn/+/refs/heads/master/docs/reference.md#toolchain-overview
+ host_toolchain = host_toolchain
+ }
+
+ lib_switch = "-l"
+ lib_dir_switch = "-L"
+
+ object_prefix = "{{source_out_dir}}/{{label_name}}."
+
+ tool("cc") {
+ depfile = "{{output}}.d"
+ command = "$cc -MMD -MF $depfile {{defines}} {{include_dirs}} {{cflags}} {{cflags_c}} -c {{source}} -o {{output}}"
+ depsformat = "gcc"
+ description = "CC {{output}}"
+ outputs = [
+ "$object_prefix{{source_name_part}}.o",
+ ]
+ }
+
+ tool("cxx") {
+ depfile = "{{output}}.d"
+ command = "$cxx -MMD -MF $depfile {{defines}} {{include_dirs}} {{cflags}} {{cflags_cc}} -c {{source}} -o {{output}}"
+ depsformat = "gcc"
+ description = "CXX {{output}}"
+ outputs = [
+ "$object_prefix{{source_name_part}}.o",
+ ]
+ }
+
+ tool("asm") {
+ depfile = "{{output}}.d"
+ command = "$cc -MMD -MF $depfile {{defines}} {{include_dirs}} {{asmflags}} -c {{source}} -o {{output}}"
+ depsformat = "gcc"
+ description = "ASM {{output}}"
+ outputs = [
+ "$object_prefix{{source_name_part}}.o",
+ ]
+ }
+
+ tool("alink") {
+ rspfile = "{{output}}.rsp"
+ command = "rm -f {{output}} && $ar rcs {{output}} @$rspfile"
+ description = "AR {{target_output_name}}{{output_extension}}"
+ rspfile_content = "{{inputs}}"
+ outputs = [
+ "{{output_dir}}/{{target_output_name}}{{output_extension}}",
+ ]
+ default_output_dir = "{{target_out_dir}}"
+ default_output_extension = ".a"
+ output_prefix = "lib"
+ }
+
+ tool("solink") {
+ soname = "{{target_output_name}}{{output_extension}}" # e.g. "libfoo.so".
+ sofile = "{{output_dir}}/$soname"
+ rspfile = soname + ".rsp"
+
+ command =
+ "$ld -shared {{ldflags}} -o $sofile -Wl,-soname=$soname @$rspfile"
+ rspfile_content = "-Wl,--whole-archive {{inputs}} {{solibs}} -Wl,--no-whole-archive {{libs}}"
+
+ description = "SOLINK {{output}}"
+
+ # Use this for {{output_extension}} expansions unless a target manually
+ # overrides it (in which case {{output_extension}} will be what the target
+ # specifies).
+ default_output_extension = ".so"
+
+ # Use this for {{output_dir}} expansions unless a target manually overrides
+ # it (in which case {{output_dir}} will be what the target specifies).
+ default_output_dir = "{{root_out_dir}}"
+
+ outputs = [
+ sofile,
+ ]
+ link_output = sofile
+ depend_output = sofile
+ output_prefix = "lib"
+ }
+
+ tool("link") {
+ outfile = "{{output_dir}}/{{target_output_name}}{{output_extension}}"
+ rspfile = "$outfile.rsp"
+
+ # These extra ldflags allow an executable to search for shared libraries in
+ # the current working directory.
+ additional_executable_ldflags = "-Wl,-rpath=\$ORIGIN/ -Wl,-rpath-link="
+ command = "$ld {{ldflags}} $additional_executable_ldflags -o $outfile -Wl,--start-group @$rspfile {{solibs}} -Wl,--end-group {{libs}}"
+ description = "LINK $outfile"
+ default_output_dir = "{{root_out_dir}}"
+ rspfile_content = "{{inputs}}"
+ outputs = [
+ outfile,
+ ]
+ }
+
+ tool("stamp") {
+ command = "touch {{output}}"
+ description = "STAMP {{output}}"
+ }
+
+ tool("copy") {
+ command = "ln -f {{source}} {{output}} 2>/dev/null || (rm -rf {{output}} && cp -af {{source}} {{output}})"
+ description = "COPY {{source}} {{output}}"
+ }
+ }
+}
- c_command = "$prefix/clang"
- cpp_command = "$prefix/clang++"
+template("clang_toolchain") {
+ prefix = rebase_path("$clang_base_path/bin", root_build_dir)
- if (is_gcc) {
- c_command = "gcc"
- cpp_command = "g++"
+ gcc_toolchain(target_name) {
+ ar = "$prefix/llvm-ar"
+ cc = "$prefix/clang"
+ cxx = "$prefix/clang++"
+ ld = cxx
+ toolchain_args = {
+ forward_variables_from(invoker.toolchain_args, "*")
+ is_clang = true
+ }
}
+}
- tool("cc") {
- depfile = "{{output}}.d"
- command = "$c_command -MMD -MF $depfile {{defines}} {{include_dirs}} {{cflags}} {{cflags_c}} -c {{source}} -o {{output}}"
- depsformat = "gcc"
- description = "CC {{output}}"
- outputs = [
- "{{source_out_dir}}/{{target_output_name}}.{{source_name_part}}.o",
- ]
+clang_toolchain("clang_x64") {
+ toolchain_args = {
+ current_cpu = "x64"
+ current_os = "linux"
}
+}
- tool("cxx") {
- depfile = "{{output}}.d"
- command = "$cpp_command -MMD -MF $depfile {{defines}} {{include_dirs}} {{cflags}} {{cflags_cc}} -c {{source}} -o {{output}}"
- depsformat = "gcc"
- description = "CXX {{output}}"
- outputs = [
- "{{source_out_dir}}/{{target_output_name}}.{{source_name_part}}.o",
- ]
+clang_toolchain("clang_x86") {
+ toolchain_args = {
+ current_cpu = "x86"
+ current_os = "linux"
}
+}
- tool("asm") {
- depfile = "{{output}}.d"
- command = "$c_command -MMD -MF $depfile {{defines}} {{include_dirs}} {{asmflags}} -c {{source}} -o {{output}}"
- depsformat = "gcc"
- description = "ASM {{output}}"
- outputs = [
- "{{source_out_dir}}/{{target_output_name}}.{{source_name_part}}.o",
- ]
+clang_toolchain("clang_arm") {
+ toolchain_args = {
+ current_cpu = "arm"
+ current_os = "linux"
}
+}
- tool("alink") {
- rspfile = "{{output}}.rsp"
- command = "rm -f {{output}} && ar rcs {{output}} @$rspfile"
- description = "AR {{target_output_name}}{{output_extension}}"
- rspfile_content = "{{inputs}}"
- outputs = [
- "{{output_dir}}/{{target_output_name}}{{output_extension}}",
- ]
- default_output_dir = "{{target_out_dir}}"
- default_output_extension = ".a"
- output_prefix = "lib"
+clang_toolchain("clang_arm64") {
+ toolchain_args = {
+ current_cpu = "arm64"
+ current_os = "linux"
}
+}
- tool("solink") {
- soname = "{{target_output_name}}{{output_extension}}" # e.g. "libfoo.so".
- sofile = "{{output_dir}}/$soname"
- rspfile = soname + ".rsp"
-
- command = "$cpp_command -shared {{ldflags}} -o $sofile -Wl,-soname=$soname @$rspfile"
- rspfile_content = "-Wl,--whole-archive {{inputs}} {{solibs}} -Wl,--no-whole-archive {{libs}}"
-
- description = "SOLINK {{output}}"
-
- # Use this for {{output_extension}} expansions unless a target manually
- # overrides it (in which case {{output_extension}} will be what the target
- # specifies).
- default_output_extension = ".so"
-
- # Use this for {{output_dir}} expansions unless a target manually overrides
- # it (in which case {{output_dir}} will be what the target specifies).
- default_output_dir = "{{root_out_dir}}"
-
- outputs = [
- sofile,
- ]
- link_output = sofile
- depend_output = sofile
- output_prefix = "lib"
+gcc_toolchain("gcc_x64") {
+ ar = "ar"
+ cc = "gcc"
+ cxx = "g++"
+ ld = cxx
+ toolchain_args = {
+ current_cpu = "x64"
+ current_os = "linux"
+ is_gcc = true
}
+}
- tool("link") {
- outfile = "{{target_output_name}}{{output_extension}}"
- rspfile = "$outfile.rsp"
-
- # These extra ldflags allow an executable to search for shared libraries in
- # the current working directory.
- additional_executable_ldflags = "-Wl,-rpath=\$ORIGIN/ -Wl,-rpath-link="
- command = "$cpp_command {{ldflags}} $additional_executable_ldflags -o $outfile -Wl,--start-group @$rspfile {{solibs}} -Wl,--end-group {{libs}}"
- description = "LINK $outfile"
- default_output_dir = "{{root_out_dir}}"
- rspfile_content = "{{inputs}}"
- outputs = [
- outfile,
- ]
+gcc_toolchain("gcc_x86") {
+ ar = "ar"
+ cc = "gcc"
+ cxx = "g++"
+ ld = cxx
+ toolchain_args = {
+ current_cpu = "x86"
+ current_os = "linux"
+ is_gcc = true
}
+}
- tool("stamp") {
- command = "touch {{output}}"
- description = "STAMP {{output}}"
+gcc_toolchain("gcc_arm") {
+ ar = "ar"
+ cc = "gcc"
+ cxx = "g++"
+ ld = cxx
+ toolchain_args = {
+ current_cpu = "arm"
+ current_os = "linux"
+ is_gcc = true
}
+}
- tool("copy") {
- command = "ln -f {{source}} {{output}} 2>/dev/null || (rm -rf {{output}} && cp -af {{source}} {{output}})"
- description = "COPY {{source}} {{output}}"
+gcc_toolchain("gcc_arm64") {
+ ar = "ar"
+ cc = "gcc"
+ cxx = "g++"
+ ld = cxx
+ toolchain_args = {
+ current_cpu = "arm64"
+ current_os = "linux"
+ is_gcc = true
}
}
diff --git a/chromium/third_party/openscreen/src/build/toolchain/mac/BUILD.gn b/chromium/third_party/openscreen/src/build/toolchain/mac/BUILD.gn
index f56dd499c7f..341d6a6356e 100644
--- a/chromium/third_party/openscreen/src/build/toolchain/mac/BUILD.gn
+++ b/chromium/third_party/openscreen/src/build/toolchain/mac/BUILD.gn
@@ -6,6 +6,9 @@ toolchain("clang") {
c_command = "clang"
cpp_command = "clang++"
+ lib_switch = "-l"
+ lib_dir_switch = "-L"
+
tool("cc") {
depfile = "{{output}}.d"
command = "$c_command -MMD -MF $depfile {{defines}} {{include_dirs}} {{cflags}} {{cflags_c}} -c {{source}} -o {{output}}"
diff --git a/chromium/third_party/openscreen/src/cast/DEPS b/chromium/third_party/openscreen/src/cast/DEPS
index e8615f96677..4e73fd07bcb 100644
--- a/chromium/third_party/openscreen/src/cast/DEPS
+++ b/chromium/third_party/openscreen/src/cast/DEPS
@@ -4,14 +4,12 @@ include_rules = [
# OSP code is strictly verboten.
'-api',
'-demo',
- '-discovery',
'-go',
'-msgs',
# Intra-libcast dependencies must be explicit.
'-cast',
- # All libcast code can use platform and cast/third_party.
+ # All libcast code can use cast/third_party.
'+cast/third_party',
- '+platform'
]
diff --git a/chromium/third_party/openscreen/src/cast/common/BUILD.gn b/chromium/third_party/openscreen/src/cast/common/BUILD.gn
index 80260eb65b1..f753198c4fe 100644
--- a/chromium/third_party/openscreen/src/cast/common/BUILD.gn
+++ b/chromium/third_party/openscreen/src/cast/common/BUILD.gn
@@ -4,6 +4,7 @@
import("//build_overrides/build.gni")
import("//third_party/protobuf/proto_library.gni")
+import("../../testing/libfuzzer/fuzzer_test.gni")
source_set("certificate") {
sources = [
@@ -13,6 +14,8 @@ source_set("certificate") {
"certificate/cast_cert_validator_internal.h",
"certificate/cast_crl.cc",
"certificate/cast_crl.h",
+ "certificate/cast_trust_store.cc",
+ "certificate/cast_trust_store.h",
"certificate/types.cc",
"certificate/types.h",
]
@@ -32,9 +35,14 @@ source_set("channel") {
sources = [
"channel/cast_socket.cc",
"channel/cast_socket.h",
+ "channel/connection_namespace_handler.cc",
+ "channel/connection_namespace_handler.h",
"channel/message_framer.cc",
"channel/message_framer.h",
+ "channel/message_util.cc",
"channel/message_util.h",
+ "channel/namespace_router.cc",
+ "channel/namespace_router.h",
"channel/virtual_connection.h",
"channel/virtual_connection_manager.cc",
"channel/virtual_connection_manager.h",
@@ -43,29 +51,66 @@ source_set("channel") {
]
deps = [
- "../../util",
"certificate/proto:certificate_proto",
- "channel/proto:channel_proto",
]
public_deps = [
"../../platform",
"../../third_party/abseil",
+ "../../util",
+ "channel/proto:channel_proto",
+ ]
+}
+
+source_set("public") {
+ sources = [
+ "public/service_info.cc",
+ "public/service_info.h",
]
+
+ deps = [
+ "../../discovery:dnssd",
+ "../../discovery:public",
+ "../../platform",
+ "../../third_party/abseil",
+ ]
+}
+
+if (!build_with_chromium) {
+ source_set("discovery_e2e_test") {
+ testonly = true
+
+ sources = [
+ "discovery/e2e_test/tests.cc",
+ ]
+
+ deps = [
+ ":public",
+ "../../discovery:dnssd",
+ "../../discovery:public",
+ "../../third_party/googletest:gtest",
+ ]
+ }
}
source_set("test_helpers") {
testonly = true
+
sources = [
- "certificate/test_helpers.cc",
- "certificate/test_helpers.h",
- "channel/test/fake_cast_socket.h",
- "channel/test/mock_cast_message_handler.h",
+ "certificate/testing/test_helpers.cc",
+ "certificate/testing/test_helpers.h",
+ "channel/testing/fake_cast_socket.h",
+ "channel/testing/mock_cast_message_handler.h",
+ "channel/testing/mock_socket_error_handler.h",
+ "public/testing/discovery_utils.cc",
+ "public/testing/discovery_utils.h",
]
public_deps = [
":certificate",
":channel",
+ ":public",
"../../platform:test",
+ "../../third_party/abseil",
"../../third_party/boringssl",
"../../third_party/googletest:gmock",
]
@@ -81,16 +126,21 @@ source_set("unittests") {
"certificate/cast_cert_validator_unittest.cc",
"certificate/cast_crl_unittest.cc",
"channel/cast_socket_unittest.cc",
+ "channel/connection_namespace_handler_unittest.cc",
"channel/message_framer_unittest.cc",
+ "channel/namespace_router_unittest.cc",
"channel/virtual_connection_manager_unittest.cc",
"channel/virtual_connection_router_unittest.cc",
+ "public/service_info_unittest.cc",
]
deps = [
":certificate",
":channel",
+ ":public",
":test_helpers",
"../../platform",
+ "../../testing/util",
"../../third_party/boringssl",
"../../third_party/googletest:gmock",
"../../third_party/googletest:gtest",
@@ -98,4 +148,22 @@ source_set("unittests") {
"certificate/proto:certificate_unittest_proto",
"channel/proto:channel_proto",
]
+
+ data = [
+ "../../test/data/cast/common/certificate",
+ ]
+}
+
+openscreen_fuzzer_test("message_framer_fuzzer") {
+ sources = [
+ "channel/message_framer_fuzzer.cc",
+ ]
+ deps = [
+ ":channel",
+ ]
+
+ seed_corpus = "channel/message_framer_fuzzer_seeds"
+
+ # NOTE: 65536 is max _body_ size.
+ libfuzzer_options = [ "max_len=65600" ]
}
diff --git a/chromium/third_party/openscreen/src/cast/common/DEPS b/chromium/third_party/openscreen/src/cast/common/DEPS
index e1023950518..c31b1081666 100644
--- a/chromium/third_party/openscreen/src/cast/common/DEPS
+++ b/chromium/third_party/openscreen/src/cast/common/DEPS
@@ -2,5 +2,7 @@
include_rules = [
# libcast common code must depend on neither the sender nor the receiver.
- '+cast/common'
+ '+cast/common',
+ '+discovery/common',
+ '+discovery/public',
]
diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.cc b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.cc
index 6fe9582140f..6d2a12e7674 100644
--- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.cc
+++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.cc
@@ -17,27 +17,13 @@
#include "cast/common/certificate/cast_cert_validator_internal.h"
#include "cast/common/certificate/cast_crl.h"
+#include "cast/common/certificate/cast_trust_store.h"
+#include "util/logging.h"
+namespace openscreen {
namespace cast {
-namespace certificate {
namespace {
-using CastCertError = openscreen::Error::Code;
-
-// -------------------------------------------------------------------------
-// Cast trust anchors.
-// -------------------------------------------------------------------------
-
-// There are two trusted roots for Cast certificate chains:
-//
-// (1) CN=Cast Root CA (kCastRootCaDer)
-// (2) CN=Eureka Root CA (kEurekaRootCaDer)
-//
-// These constants are defined by the files included next:
-
-#include "cast/common/certificate/cast_root_ca_cert_der-inc.h"
-#include "cast/common/certificate/eureka_root_ca_der-inc.h"
-
// Returns the OID for the Audio-Only Cast policy
// (1.3.6.1.4.1.11129.2.5.2) in DER form.
const ConstDataSpan& AudioOnlyPolicyOid() {
@@ -49,8 +35,8 @@ const ConstDataSpan& AudioOnlyPolicyOid() {
class CertVerificationContextImpl final : public CertVerificationContext {
public:
- CertVerificationContextImpl(bssl::UniquePtr<EVP_PKEY>&& cert,
- std::string&& common_name)
+ CertVerificationContextImpl(bssl::UniquePtr<EVP_PKEY> cert,
+ std::string common_name)
: public_key_{std::move(cert)}, common_name_(std::move(common_name)) {}
~CertVerificationContextImpl() override = default;
@@ -143,54 +129,31 @@ CastDeviceCertPolicy GetAudioPolicy(const std::vector<X509*>& path) {
} // namespace
-class CastTrustStore {
- public:
- // Singleton for the Cast trust store for legacy networkingPrivate use.
- static CastTrustStore* GetInstance() {
- static CastTrustStore* store = new CastTrustStore();
- return store;
- }
-
- CastTrustStore() {
- trust_store_.certs.emplace_back(MakeTrustAnchor(kCastRootCaDer));
- trust_store_.certs.emplace_back(MakeTrustAnchor(kEurekaRootCaDer));
- }
- ~CastTrustStore() = default;
-
- TrustStore* trust_store() { return &trust_store_; }
-
- private:
- TrustStore trust_store_;
- OSP_DISALLOW_COPY_AND_ASSIGN(CastTrustStore);
-};
-
-openscreen::Error VerifyDeviceCert(
- const std::vector<std::string>& der_certs,
- const DateTime& time,
- std::unique_ptr<CertVerificationContext>* context,
- CastDeviceCertPolicy* policy,
- const CastCRL* crl,
- CRLPolicy crl_policy,
- TrustStore* trust_store) {
+Error VerifyDeviceCert(const std::vector<std::string>& der_certs,
+ const DateTime& time,
+ std::unique_ptr<CertVerificationContext>* context,
+ CastDeviceCertPolicy* policy,
+ const CastCRL* crl,
+ CRLPolicy crl_policy,
+ TrustStore* trust_store) {
if (!trust_store) {
trust_store = CastTrustStore::GetInstance()->trust_store();
}
// Fail early if CRL is required but not provided.
if (!crl && crl_policy == CRLPolicy::kCrlRequired) {
- return CastCertError::kErrCrlInvalid;
+ return Error::Code::kErrCrlInvalid;
}
CertificatePathResult result_path = {};
- openscreen::Error error =
- FindCertificatePath(der_certs, time, &result_path, trust_store);
+ Error error = FindCertificatePath(der_certs, time, &result_path, trust_store);
if (!error.ok()) {
return error;
}
if (crl_policy == CRLPolicy::kCrlRequired &&
!crl->CheckRevocation(result_path.path, time)) {
- return CastCertError::kErrCertsRevoked;
+ return Error::Code::kErrCertsRevoked;
}
*policy = GetAudioPolicy(result_path.path);
@@ -203,7 +166,7 @@ openscreen::Error VerifyDeviceCert(
int len = X509_NAME_get_text_by_NID(target_subject, NID_commonName,
&common_name[0], common_name.size());
if (len == 0) {
- return CastCertError::kErrCertsRestrictions;
+ return Error::Code::kErrCertsRestrictions;
}
common_name.resize(len);
@@ -211,8 +174,8 @@ openscreen::Error VerifyDeviceCert(
bssl::UniquePtr<EVP_PKEY>{X509_get_pubkey(result_path.target_cert.get())},
std::move(common_name)));
- return CastCertError::kNone;
+ return Error::Code::kNone;
}
-} // namespace certificate
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.h b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.h
index 36f5eeff356..c20e42db69a 100644
--- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.h
+++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.h
@@ -13,8 +13,8 @@
#include "platform/base/error.h"
#include "platform/base/macros.h"
+namespace openscreen {
namespace cast {
-namespace certificate {
class CastCRL;
@@ -70,7 +70,7 @@ class CertVerificationContext {
OSP_DISALLOW_COPY_AND_ASSIGN(CertVerificationContext);
};
-// Verifies a cast device certficate given a chain of DER-encoded certificates.
+// Verifies a cast device certificate given a chain of DER-encoded certificates.
//
// Inputs:
//
@@ -95,16 +95,15 @@ class CertVerificationContext {
//
// Outputs:
//
-// Returns openscreen::Error::Code::kNone on success. Otherwise, the
-// corresponding openscreen::Error::Code. On success, the output parameters are
-// filled with more details:
+// Returns Error::Code::kNone on success. Otherwise, the corresponding
+// Error::Code. On success, the output parameters are filled with more details:
//
// * |context| is filled with an object that can be used to verify signatures
// using the device certificate's public key, as well as to extract other
// properties from the device certificate (Common Name).
// * |policy| is filled with an indication of the device certificate's policy
// (i.e. is it for audio-only devices or is it unrestricted?)
-[[nodiscard]] openscreen::Error VerifyDeviceCert(
+[[nodiscard]] Error VerifyDeviceCert(
const std::vector<std::string>& der_certs,
const DateTime& time,
std::unique_ptr<CertVerificationContext>* context,
@@ -113,7 +112,7 @@ class CertVerificationContext {
CRLPolicy crl_policy,
TrustStore* trust_store = nullptr);
-} // namespace certificate
} // namespace cast
+} // namespace openscreen
#endif // CAST_COMMON_CERTIFICATE_CAST_CERT_VALIDATOR_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.cc b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.cc
index 331e71259a2..7e43e02c1fa 100644
--- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.cc
+++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.cc
@@ -15,14 +15,12 @@
#include "cast/common/certificate/types.h"
#include "util/logging.h"
+namespace openscreen {
namespace cast {
-namespace certificate {
namespace {
constexpr static int32_t kMinRsaModulusLengthBits = 2048;
-using CastCertError = openscreen::Error::Code;
-
// Stores intermediate state while attempting to find a valid certificate chain
// from a set of trusted certificates to a target certificate. Together, a
// sequence of these forms a certificate chain to be verified as well as a stack
@@ -63,16 +61,16 @@ uint8_t ParseAsn1TimeDoubleDigit(ASN1_GENERALIZEDTIME* time, int index) {
return (time->data[index] - '0') * 10 + (time->data[index + 1] - '0');
}
-CastCertError VerifyCertTime(X509* cert, const DateTime& time) {
+Error::Code VerifyCertTime(X509* cert, const DateTime& time) {
DateTime not_before;
DateTime not_after;
if (!GetCertValidTimeRange(cert, &not_before, &not_after)) {
- return CastCertError::kErrCertsVerifyGeneric;
+ return Error::Code::kErrCertsVerifyGeneric;
}
if ((time < not_before) || (not_after < time)) {
- return CastCertError::kErrCertsDateInvalid;
+ return Error::Code::kErrCertsDateInvalid;
}
- return CastCertError::kNone;
+ return Error::Code::kNone;
}
bool VerifyPublicKeyLength(EVP_PKEY* public_key) {
@@ -94,27 +92,27 @@ bssl::UniquePtr<ASN1_BIT_STRING> GetKeyUsage(X509* cert) {
return bssl::UniquePtr<ASN1_BIT_STRING>{key_usage_bit_string};
}
-CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path,
- uint32_t step_index,
- const DateTime& time) {
+Error::Code VerifyCertificateChain(const std::vector<CertPathStep>& path,
+ uint32_t step_index,
+ const DateTime& time) {
// Default max path length is the number of intermediate certificates.
int max_pathlen = path.size() - 2;
std::vector<NAME_CONSTRAINTS*> path_name_constraints;
- CastCertError error = CastCertError::kNone;
+ Error::Code error = Error::Code::kNone;
uint32_t i = step_index;
for (; i < path.size() - 1; ++i) {
X509* subject = path[i + 1].cert;
X509* issuer = path[i].cert;
bool is_root = (i == step_index);
if (!is_root) {
- if ((error = VerifyCertTime(issuer, time)) != CastCertError::kNone) {
+ if ((error = VerifyCertTime(issuer, time)) != Error::Code::kNone) {
return error;
}
if (X509_NAME_cmp(X509_get_subject_name(issuer),
X509_get_issuer_name(issuer)) != 0) {
if (max_pathlen == 0) {
- return CastCertError::kErrCertsPathlen;
+ return Error::Code::kErrCertsPathlen;
}
--max_pathlen;
} else {
@@ -129,7 +127,7 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path,
const int bit =
ASN1_BIT_STRING_get_bit(key_usage.get(), KeyUsageBits::kKeyCertSign);
if (bit == 0) {
- return CastCertError::kErrCertsVerifyGeneric;
+ return Error::Code::kErrCertsVerifyGeneric;
}
}
@@ -138,7 +136,7 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path,
const int basic_constraints_index =
X509_get_ext_by_NID(issuer, NID_basic_constraints, -1);
if (basic_constraints_index == -1) {
- return CastCertError::kErrCertsVerifyGeneric;
+ return Error::Code::kErrCertsVerifyGeneric;
}
X509_EXTENSION* const basic_constraints_extension =
X509_get_ext(issuer, basic_constraints_index);
@@ -147,16 +145,16 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path,
X509V3_EXT_d2i(basic_constraints_extension))};
if (!basic_constraints || !basic_constraints->ca) {
- return CastCertError::kErrCertsVerifyGeneric;
+ return Error::Code::kErrCertsVerifyGeneric;
}
if (basic_constraints->pathlen) {
if (basic_constraints->pathlen->length != 1) {
- return CastCertError::kErrCertsVerifyGeneric;
+ return Error::Code::kErrCertsVerifyGeneric;
} else {
const int pathlen = *basic_constraints->pathlen->data;
if (pathlen < 0) {
- return CastCertError::kErrCertsVerifyGeneric;
+ return Error::Code::kErrCertsVerifyGeneric;
}
if (pathlen < max_pathlen) {
max_pathlen = pathlen;
@@ -165,12 +163,12 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path,
}
if (X509_ALGOR_cmp(issuer->sig_alg, issuer->cert_info->signature) != 0) {
- return CastCertError::kErrCertsVerifyGeneric;
+ return Error::Code::kErrCertsVerifyGeneric;
}
bssl::UniquePtr<EVP_PKEY> public_key{X509_get_pubkey(issuer)};
if (!VerifyPublicKeyLength(public_key.get())) {
- return CastCertError::kErrCertsVerifyGeneric;
+ return Error::Code::kErrCertsVerifyGeneric;
}
// NOTE: (!self-issued || target) -> verify name constraints. Target case
@@ -179,7 +177,7 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path,
if (!is_self_issued) {
for (NAME_CONSTRAINTS* name_constraints : path_name_constraints) {
if (NAME_CONSTRAINTS_check(subject, name_constraints) != X509_V_OK) {
- return CastCertError::kErrCertsVerifyGeneric;
+ return Error::Code::kErrCertsVerifyGeneric;
}
}
}
@@ -195,7 +193,7 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path,
issuer->nc = nc;
path_name_constraints.push_back(nc);
} else {
- return CastCertError::kErrCertsVerifyGeneric;
+ return Error::Code::kErrCertsVerifyGeneric;
}
}
}
@@ -220,12 +218,12 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path,
((OBJ_cmp(policy_mapping->issuerDomainPolicy, any_policy) == 0) ||
(OBJ_cmp(policy_mapping->subjectDomainPolicy, any_policy) == 0));
if (either_matches) {
- error = CastCertError::kErrCertsVerifyGeneric;
+ error = Error::Code::kErrCertsVerifyGeneric;
break;
}
}
sk_POLICY_MAPPING_free(policy_mappings);
- if (error != CastCertError::kNone) {
+ if (error != Error::Code::kNone) {
return error;
}
}
@@ -238,7 +236,7 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path,
const int nid = OBJ_obj2nid(extension->object);
if (nid != NID_name_constraints && nid != NID_basic_constraints &&
nid != NID_key_usage) {
- return CastCertError::kErrCertsVerifyGeneric;
+ return Error::Code::kErrCertsVerifyGeneric;
}
}
}
@@ -259,7 +257,7 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path,
digest = EVP_sha512();
break;
default:
- return CastCertError::kErrCertsVerifyGeneric;
+ return Error::Code::kErrCertsVerifyGeneric;
}
if (!VerifySignedData(
digest, public_key.get(),
@@ -267,14 +265,14 @@ CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path,
static_cast<uint32_t>(subject->cert_info->enc.len)},
{subject->signature->data,
static_cast<uint32_t>(subject->signature->length)})) {
- return CastCertError::kErrCertsVerifyGeneric;
+ return Error::Code::kErrCertsVerifyGeneric;
}
}
// NOTE: Other half of ((!self-issued || target) -> check name constraints).
for (NAME_CONSTRAINTS* name_constraints : path_name_constraints) {
if (NAME_CONSTRAINTS_check(path.back().cert, name_constraints) !=
X509_V_OK) {
- return CastCertError::kErrCertsVerifyGeneric;
+ return Error::Code::kErrCertsVerifyGeneric;
}
}
return error;
@@ -370,12 +368,12 @@ bool VerifySignedData(const EVP_MD* digest,
data.data, data.length) == 1);
}
-openscreen::Error FindCertificatePath(const std::vector<std::string>& der_certs,
- const DateTime& time,
- CertificatePathResult* result_path,
- TrustStore* trust_store) {
+Error FindCertificatePath(const std::vector<std::string>& der_certs,
+ const DateTime& time,
+ CertificatePathResult* result_path,
+ TrustStore* trust_store) {
if (der_certs.empty()) {
- return CastCertError::kErrCertsMissing;
+ return Error::Code::kErrCertsMissing;
}
bssl::UniquePtr<X509>& target_cert = result_path->target_cert;
@@ -383,36 +381,36 @@ openscreen::Error FindCertificatePath(const std::vector<std::string>& der_certs,
result_path->intermediate_certs;
target_cert.reset(ParseX509Der(der_certs[0]));
if (!target_cert) {
- return CastCertError::kErrCertsParse;
+ return Error::Code::kErrCertsParse;
}
for (size_t i = 1; i < der_certs.size(); ++i) {
intermediate_certs.emplace_back(ParseX509Der(der_certs[i]));
if (!intermediate_certs.back()) {
- return CastCertError::kErrCertsParse;
+ return Error::Code::kErrCertsParse;
}
}
// Basic checks on the target certificate.
- CastCertError error = VerifyCertTime(target_cert.get(), time);
- if (error != CastCertError::kNone) {
+ Error::Code error = VerifyCertTime(target_cert.get(), time);
+ if (error != Error::Code::kNone) {
return error;
}
bssl::UniquePtr<EVP_PKEY> public_key{X509_get_pubkey(target_cert.get())};
if (!VerifyPublicKeyLength(public_key.get())) {
- return CastCertError::kErrCertsVerifyGeneric;
+ return Error::Code::kErrCertsVerifyGeneric;
}
if (X509_ALGOR_cmp(target_cert.get()->sig_alg,
target_cert.get()->cert_info->signature) != 0) {
- return CastCertError::kErrCertsVerifyGeneric;
+ return Error::Code::kErrCertsVerifyGeneric;
}
bssl::UniquePtr<ASN1_BIT_STRING> key_usage = GetKeyUsage(target_cert.get());
if (!key_usage) {
- return CastCertError::kErrCertsRestrictions;
+ return Error::Code::kErrCertsRestrictions;
}
int bit =
ASN1_BIT_STRING_get_bit(key_usage.get(), KeyUsageBits::kDigitalSignature);
if (bit == 0) {
- return CastCertError::kErrCertsRestrictions;
+ return Error::Code::kErrCertsRestrictions;
}
X509* path_head = target_cert.get();
@@ -442,7 +440,7 @@ openscreen::Error FindCertificatePath(const std::vector<std::string>& der_certs,
// returned is whatever the last error was from the last path tried.
uint32_t trust_store_index = 0;
uint32_t intermediate_cert_index = 0;
- CastCertError last_error = CastCertError::kNone;
+ Error::Code last_error = Error::Code::kNone;
for (;;) {
X509_NAME* target_issuer_name = X509_get_issuer_name(path_head);
@@ -486,8 +484,8 @@ openscreen::Error FindCertificatePath(const std::vector<std::string>& der_certs,
if (!next_issuer) {
if (path_index == first_index) {
// There are no more paths to try. Ensure an error is returned.
- if (last_error == CastCertError::kNone) {
- return CastCertError::kErrCertsVerifyGeneric;
+ if (last_error == Error::Code::kNone) {
+ return Error::Code::kErrCertsVerifyGeneric;
}
return last_error;
} else {
@@ -500,7 +498,7 @@ openscreen::Error FindCertificatePath(const std::vector<std::string>& der_certs,
if (path_cert_in_trust_store) {
last_error = VerifyCertificateChain(path, path_index, time);
- if (last_error != CastCertError::kNone) {
+ if (last_error != Error::Code::kNone) {
CertPathStep& last_step = path[path_index++];
trust_store_index = last_step.trust_store_index;
intermediate_cert_index = last_step.intermediate_cert_index;
@@ -517,8 +515,8 @@ openscreen::Error FindCertificatePath(const std::vector<std::string>& der_certs,
result_path->path.push_back(path[i].cert);
}
- return CastCertError::kNone;
+ return Error::Code::kNone;
}
-} // namespace certificate
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.h b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.h
index 1c127e72948..f8424b6d1c0 100644
--- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.h
+++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.h
@@ -11,8 +11,8 @@
#include "platform/base/error.h"
+namespace openscreen {
namespace cast {
-namespace certificate {
struct TrustStore {
std::vector<bssl::UniquePtr<X509>> certs;
@@ -26,6 +26,11 @@ bssl::UniquePtr<X509> MakeTrustAnchor(const uint8_t (&data)[N]) {
return bssl::UniquePtr<X509>{d2i_X509(nullptr, &dptr, N)};
}
+inline bssl::UniquePtr<X509> MakeTrustAnchor(const std::vector<uint8_t>& data) {
+ const uint8_t* dptr = data.data();
+ return bssl::UniquePtr<X509>{d2i_X509(nullptr, &dptr, data.size())};
+}
+
struct ConstDataSpan;
struct DateTime;
@@ -47,12 +52,12 @@ struct CertificatePathResult {
std::vector<X509*> path;
};
-openscreen::Error FindCertificatePath(const std::vector<std::string>& der_certs,
- const DateTime& time,
- CertificatePathResult* result_path,
- TrustStore* trust_store);
+Error FindCertificatePath(const std::vector<std::string>& der_certs,
+ const DateTime& time,
+ CertificatePathResult* result_path,
+ TrustStore* trust_store);
-} // namespace certificate
} // namespace cast
+} // namespace openscreen
#endif // CAST_COMMON_CERTIFICATE_CAST_CERT_VALIDATOR_INTERNAL_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_unittest.cc b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_unittest.cc
index c29de9cc6b7..41700a507cf 100644
--- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_unittest.cc
@@ -8,16 +8,14 @@
#include <string.h>
#include "cast/common/certificate/cast_cert_validator_internal.h"
-#include "cast/common/certificate/test_helpers.h"
+#include "cast/common/certificate/testing/test_helpers.h"
#include "gtest/gtest.h"
#include "openssl/pem.h"
+namespace openscreen {
namespace cast {
-namespace certificate {
namespace {
-using CastCertError = openscreen::Error::Code;
-
enum TrustStoreDependency {
// Uses the built-in trust store for Cast. This is how certificates are
// verified in production.
@@ -45,7 +43,7 @@ enum TrustStoreDependency {
// * |optional_signed_data_file_name| - optional path to a PEM file containing
// a valid signature generated by the device certificate.
//
-void RunTest(CastCertError expected_result,
+void RunTest(Error::Code expected_result,
const std::string& expected_common_name,
CastDeviceCertPolicy expected_policy,
const std::string& certs_file_name,
@@ -82,12 +80,11 @@ void RunTest(CastCertError expected_result,
std::unique_ptr<CertVerificationContext> context;
CastDeviceCertPolicy policy;
- openscreen::Error result =
- VerifyDeviceCert(certs, time, &context, &policy, nullptr,
- CRLPolicy::kCrlOptional, trust_store);
+ Error result = VerifyDeviceCert(certs, time, &context, &policy, nullptr,
+ CRLPolicy::kCrlOptional, trust_store);
ASSERT_EQ(expected_result, result.code());
- if (expected_result != CastCertError::kNone)
+ if (expected_result != Error::Code::kNone)
return;
EXPECT_EQ(expected_policy, policy);
@@ -166,7 +163,7 @@ DateTime MarchFirst2037() {
// Chains to trust anchor:
// Eureka Root CA (built-in trust store)
TEST(VerifyCastDeviceCertTest, ChromecastGen1) {
- RunTest(CastCertError::kNone, "2ZZBG9 FA8FCA3EF91A",
+ RunTest(Error::Code::kNone, "2ZZBG9 FA8FCA3EF91A",
CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/chromecast_gen1.pem", AprilFirst2016(),
TRUST_STORE_BUILTIN,
@@ -181,7 +178,7 @@ TEST(VerifyCastDeviceCertTest, ChromecastGen1) {
// Chains to trust anchor:
// Cast Root CA (built-in trust store)
TEST(VerifyCastDeviceCertTest, ChromecastGen1Reissue) {
- RunTest(CastCertError::kNone, "2ZZBG9 FA8FCA3EF91A",
+ RunTest(Error::Code::kNone, "2ZZBG9 FA8FCA3EF91A",
CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/chromecast_gen1_reissue.pem",
AprilFirst2016(), TRUST_STORE_BUILTIN,
@@ -196,7 +193,7 @@ TEST(VerifyCastDeviceCertTest, ChromecastGen1Reissue) {
// Chains to trust anchor:
// Cast Root CA (built-in trust store)
TEST(VerifyCastDeviceCertTest, ChromecastGen2) {
- RunTest(CastCertError::kNone, "3ZZAK6 FA8FCA3F0D35",
+ RunTest(Error::Code::kNone, "3ZZAK6 FA8FCA3F0D35",
CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/chromecast_gen2.pem", AprilFirst2016(),
TRUST_STORE_BUILTIN, "");
@@ -211,7 +208,7 @@ TEST(VerifyCastDeviceCertTest, ChromecastGen2) {
// Chains to trust anchor:
// Cast Root CA (built-in trust store)
TEST(VerifyCastDeviceCertTest, Fugu) {
- RunTest(CastCertError::kNone, "-6394818897508095075",
+ RunTest(Error::Code::kNone, "-6394818897508095075",
CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/fugu.pem", AprilFirst2016(),
TRUST_STORE_BUILTIN, "");
@@ -226,7 +223,7 @@ TEST(VerifyCastDeviceCertTest, Fugu) {
//
// This is invalid because it does not chain to a trust anchor.
TEST(VerifyCastDeviceCertTest, Unchained) {
- RunTest(CastCertError::kErrCertsVerifyGeneric, "",
+ RunTest(Error::Code::kErrCertsVerifyGeneric, "",
CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/unchained.pem", AprilFirst2016(),
TRUST_STORE_BUILTIN, "");
@@ -243,7 +240,7 @@ TEST(VerifyCastDeviceCertTest, Unchained) {
// trust anchors after all) it fails the test as it is not a *device
// certificate*.
TEST(VerifyCastDeviceCertTest, CastRootCa) {
- RunTest(CastCertError::kErrCertsRestrictions, "",
+ RunTest(Error::Code::kErrCertsRestrictions, "",
CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/cast_root_ca.pem", AprilFirst2016(),
TRUST_STORE_BUILTIN, "");
@@ -260,7 +257,7 @@ TEST(VerifyCastDeviceCertTest, CastRootCa) {
// This device certificate has a policy that means it is valid only for audio
// devices.
TEST(VerifyCastDeviceCertTest, ChromecastAudio) {
- RunTest(CastCertError::kNone, "4ZZDZJ FA8FCA7EFE3C",
+ RunTest(Error::Code::kNone, "4ZZDZJ FA8FCA7EFE3C",
CastDeviceCertPolicy::kAudioOnly,
TEST_DATA_PREFIX "certificates/chromecast_audio.pem",
AprilFirst2016(), TRUST_STORE_BUILTIN, "");
@@ -278,7 +275,7 @@ TEST(VerifyCastDeviceCertTest, ChromecastAudio) {
// This device certificate has a policy that means it is valid only for audio
// devices.
TEST(VerifyCastDeviceCertTest, MtkAudioDev) {
- RunTest(CastCertError::kNone, "MediaTek Audio Dev Test",
+ RunTest(Error::Code::kNone, "MediaTek Audio Dev Test",
CastDeviceCertPolicy::kAudioOnly,
TEST_DATA_PREFIX "certificates/mtk_audio_dev.pem", JanuaryFirst2015(),
TRUST_STORE_BUILTIN, "");
@@ -292,7 +289,7 @@ TEST(VerifyCastDeviceCertTest, MtkAudioDev) {
// Chains to trust anchor:
// Cast Root CA (built-in trust store)
TEST(VerifyCastDeviceCertTest, Vizio) {
- RunTest(CastCertError::kNone, "9V0000VB FA8FCA784D01",
+ RunTest(Error::Code::kNone, "9V0000VB FA8FCA784D01",
CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/vizio.pem", AprilFirst2016(),
TRUST_STORE_BUILTIN, "");
@@ -305,17 +302,17 @@ TEST(VerifyCastDeviceCertTest, ChromecastGen2InvalidTime) {
// Control test - certificate should be valid at some time otherwise
// this test is pointless.
- RunTest(CastCertError::kNone, "3ZZAK6 FA8FCA3F0D35",
+ RunTest(Error::Code::kNone, "3ZZAK6 FA8FCA3F0D35",
CastDeviceCertPolicy::kUnrestricted, kCertsFile, AprilFirst2016(),
TRUST_STORE_BUILTIN, "");
// Use a time before notBefore.
- RunTest(CastCertError::kErrCertsDateInvalid, "",
+ RunTest(Error::Code::kErrCertsDateInvalid, "",
CastDeviceCertPolicy::kUnrestricted, kCertsFile, JanuaryFirst2015(),
TRUST_STORE_BUILTIN, "");
// Use a time after notAfter.
- RunTest(CastCertError::kErrCertsDateInvalid, "",
+ RunTest(Error::Code::kErrCertsDateInvalid, "",
CastDeviceCertPolicy::kUnrestricted, kCertsFile, MarchFirst2037(),
TRUST_STORE_BUILTIN, "");
}
@@ -332,7 +329,7 @@ TEST(VerifyCastDeviceCertTest, ChromecastGen2InvalidTime) {
// This device certificate has a policy that means it is valid only for audio
// devices.
TEST(VerifyCastDeviceCertTest, AudioRefDevTestChain3) {
- RunTest(CastCertError::kNone, "Audio Reference Dev Test",
+ RunTest(Error::Code::kNone, "Audio Reference Dev Test",
CastDeviceCertPolicy::kAudioOnly,
TEST_DATA_PREFIX "certificates/audio_ref_dev_test_chain_3.pem",
AprilFirst2016(), TRUST_STORE_BUILTIN,
@@ -359,7 +356,7 @@ TEST(VerifyCastDeviceCertTest, AudioRefDevTestChain3) {
// This device certificate has a policy that means it is valid only for audio
// devices.
TEST(VerifyCastDeviceCertTest, IntermediateSerialNumberTooLong) {
- RunTest(CastCertError::kNone, "8C579B806FFC8A9DFFFF F8:8F:CA:6B:E6:DA",
+ RunTest(Error::Code::kNone, "8C579B806FFC8A9DFFFF F8:8F:CA:6B:E6:DA",
CastDeviceCertPolicy::AUDIO_ONLY,
"certificates/intermediate_serialnumber_toolong.pem",
AprilFirst2016(), TRUST_STORE_BUILTIN, "");
@@ -378,8 +375,7 @@ TEST(VerifyCastDeviceCertTest, IntermediateSerialNumberTooLong) {
TEST(VerifyCastDeviceCertTest, ExpiredTrustAnchor) {
// The root certificate is only valid in 2015, so validating with a time in
// 2016 means it is expired.
- RunTest(CastCertError::kNone, "CastDevice",
- CastDeviceCertPolicy::kUnrestricted,
+ RunTest(Error::Code::kNone, "CastDevice", CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/expired_root.pem", AprilFirst2016(),
TRUST_STORE_FROM_TEST_FILE, "");
}
@@ -399,7 +395,7 @@ TEST(VerifyCastDeviceCertTest, ExpiredTrustAnchor) {
// Root (provided by test data; has pathlen=1 constraint)
TEST(VerifyCastDeviceCertTest, ViolatesPathlenTrustAnchorConstraint) {
// Test that the chain verification fails due to the pathlen constraint.
- RunTest(CastCertError::kErrCertsPathlen, "Target",
+ RunTest(Error::Code::kErrCertsPathlen, "Target",
CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/violates_root_pathlen_constraint.pem",
AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, "");
@@ -411,7 +407,7 @@ TEST(VerifyCastDeviceCertTest, ViolatesPathlenTrustAnchorConstraint) {
// Intermediate: policies={anyPolicy}
// Leaf: policies={anyPolicy}
TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafAnypolicy) {
- RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted,
+ RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX
"certificates/policies_ica_anypolicy_leaf_anypolicy.pem",
AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, "");
@@ -423,7 +419,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafAnypolicy) {
// Intermediate: policies={anyPolicy}
// Leaf: policies={audioOnly}
TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafAudioonly) {
- RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly,
+ RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly,
TEST_DATA_PREFIX
"certificates/policies_ica_anypolicy_leaf_audioonly.pem",
AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, "");
@@ -435,7 +431,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafAudioonly) {
// Intermediate: policies={anyPolicy}
// Leaf: policies={foo}
TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafFoo) {
- RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted,
+ RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/policies_ica_anypolicy_leaf_foo.pem",
AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, "");
}
@@ -446,7 +442,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafFoo) {
// Intermediate: policies={anyPolicy}
// Leaf: policies={}
TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafNone) {
- RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted,
+ RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/policies_ica_anypolicy_leaf_none.pem",
AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, "");
}
@@ -457,7 +453,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafNone) {
// Intermediate: policies={audioOnly}
// Leaf: policies={anyPolicy}
TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafAnypolicy) {
- RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly,
+ RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly,
TEST_DATA_PREFIX
"certificates/policies_ica_audioonly_leaf_anypolicy.pem",
AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, "");
@@ -469,7 +465,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafAnypolicy) {
// Intermediate: policies={audioOnly}
// Leaf: policies={audioOnly}
TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafAudioonly) {
- RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly,
+ RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly,
TEST_DATA_PREFIX
"certificates/policies_ica_audioonly_leaf_audioonly.pem",
AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, "");
@@ -481,7 +477,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafAudioonly) {
// Intermediate: policies={audioOnly}
// Leaf: policies={foo}
TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafFoo) {
- RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly,
+ RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly,
TEST_DATA_PREFIX "certificates/policies_ica_audioonly_leaf_foo.pem",
AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, "");
}
@@ -492,7 +488,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafFoo) {
// Intermediate: policies={audioOnly}
// Leaf: policies={}
TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafNone) {
- RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly,
+ RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly,
TEST_DATA_PREFIX "certificates/policies_ica_audioonly_leaf_none.pem",
AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, "");
}
@@ -503,7 +499,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafNone) {
// Intermediate: policies={}
// Leaf: policies={anyPolicy}
TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafAnypolicy) {
- RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted,
+ RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/policies_ica_none_leaf_anypolicy.pem",
AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, "");
}
@@ -514,7 +510,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafAnypolicy) {
// Intermediate: policies={}
// Leaf: policies={audioOnly}
TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafAudioonly) {
- RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly,
+ RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly,
TEST_DATA_PREFIX "certificates/policies_ica_none_leaf_audioonly.pem",
AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, "");
}
@@ -525,7 +521,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafAudioonly) {
// Intermediate: policies={}
// Leaf: policies={foo}
TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafFoo) {
- RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted,
+ RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/policies_ica_none_leaf_foo.pem",
AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, "");
}
@@ -536,7 +532,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafFoo) {
// Intermediate: policies={}
// Leaf: policies={}
TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafNone) {
- RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted,
+ RunTest(Error::Code::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/policies_ica_none_leaf_none.pem",
AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, "");
}
@@ -545,7 +541,7 @@ TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafNone) {
// 1024-bit RSA key. Verification should fail since the target's key is
// too weak.
TEST(VerifyCastDeviceCertTest, DeviceCertHas1024BitRsaKey) {
- RunTest(CastCertError::kErrCertsVerifyGeneric, "RSA 1024 Device Cert",
+ RunTest(Error::Code::kErrCertsVerifyGeneric, "RSA 1024 Device Cert",
CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/rsa1024_device_cert.pem",
AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, "");
@@ -555,7 +551,7 @@ TEST(VerifyCastDeviceCertTest, DeviceCertHas1024BitRsaKey) {
// 2048-bit RSA key, and then verifying signed data (both SHA1 and SHA256)
// for it.
TEST(VerifyCastDeviceCertTest, DeviceCertHas2048BitRsaKey) {
- RunTest(CastCertError::kNone, "RSA 2048 Device Cert",
+ RunTest(Error::Code::kNone, "RSA 2048 Device Cert",
CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/rsa2048_device_cert.pem",
AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE,
@@ -566,7 +562,7 @@ TEST(VerifyCastDeviceCertTest, DeviceCertHas2048BitRsaKey) {
// nameConstraints extension but the leaf certificate is still permitted under
// these constraints.
TEST(VerifyCastDeviceCertTest, NameConstraintsObeyed) {
- RunTest(CastCertError::kNone, "Device", CastDeviceCertPolicy::kUnrestricted,
+ RunTest(Error::Code::kNone, "Device", CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/nc.pem", AprilFirst2020(),
TRUST_STORE_FROM_TEST_FILE, "");
}
@@ -575,12 +571,12 @@ TEST(VerifyCastDeviceCertTest, NameConstraintsObeyed) {
// nameConstraints extension and the leaf certificate is not permitted under
// these constraints.
TEST(VerifyCastDeviceCertTest, NameConstraintsViolated) {
- RunTest(CastCertError::kErrCertsVerifyGeneric, "Device",
+ RunTest(Error::Code::kErrCertsVerifyGeneric, "Device",
CastDeviceCertPolicy::kUnrestricted,
TEST_DATA_PREFIX "certificates/nc_fail.pem", AprilFirst2020(),
TRUST_STORE_FROM_TEST_FILE, "");
}
} // namespace
-} // namespace certificate
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.cc b/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.cc
index 5935e1e8e75..41f05cad4fb 100644
--- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.cc
+++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.cc
@@ -11,13 +11,13 @@
#include "absl/strings/string_view.h"
#include "cast/common/certificate/cast_cert_validator_internal.h"
-#include "cast/common/certificate/proto/revocation.pb.h"
#include "platform/base/macros.h"
+#include "util/crypto/certificate_utils.h"
#include "util/crypto/sha2.h"
#include "util/logging.h"
+namespace openscreen {
namespace cast {
-namespace certificate {
namespace {
enum CrlVersion {
@@ -76,7 +76,7 @@ bool VerifyCRL(const Crl& crl,
TrustStore* trust_store,
DateTime* overall_not_after) {
CertificatePathResult result_path = {};
- openscreen::Error error =
+ Error error =
FindCertificatePath({crl.signer_cert()}, time, &result_path, trust_store);
if (!error.ok()) {
return false;
@@ -134,39 +134,12 @@ bool VerifyCRL(const Crl& crl,
return true;
}
-std::string GetSpkiTlv(X509* cert) {
- int len = i2d_X509_PUBKEY(cert->cert_info->key, nullptr);
- if (len <= 0) {
- return {};
- }
- std::string x(len, 0);
- uint8_t* data = reinterpret_cast<uint8_t*>(&x[0]);
- if (!i2d_X509_PUBKEY(cert->cert_info->key, &data)) {
- return {};
- }
- size_t actual_size = data - reinterpret_cast<uint8_t*>(&x[0]);
- OSP_DCHECK_EQ(actual_size, x.size());
- x.resize(actual_size);
- return x;
-}
-
-bool ParseDerUint64(ASN1_INTEGER* asn1int, uint64_t* result) {
- if (asn1int->length > 8 || asn1int->length == 0) {
- return false;
- }
- *result = 0;
- for (int i = 0; i < asn1int->length; ++i) {
- *result = (*result << 8) | asn1int->data[i];
- }
- return true;
-}
-
} // namespace
CastCRL::CastCRL(const TbsCrl& tbs_crl, const DateTime& overall_not_after) {
// Parse the validity information.
- // Assume ConvertTimeSeconds will succeed. Successful call to VerifyCRL
- // means that these calls were successful.
+ // Assume DateTimeFromSeconds will succeed. Successful call to VerifyCRL means
+ // that these calls were successful.
DateTimeFromSeconds(tbs_crl.not_before_seconds(), &not_before_);
DateTimeFromSeconds(tbs_crl.not_after_seconds(), &not_after_);
if (overall_not_after < not_after_) {
@@ -210,8 +183,7 @@ bool CastCRL::CheckRevocation(const std::vector<X509*>& trusted_chain,
return false;
}
- openscreen::ErrorOr<std::string> spki_hash =
- openscreen::SHA256HashString(spki_tlv);
+ ErrorOr<std::string> spki_hash = SHA256HashString(spki_tlv);
if (spki_hash.is_error() ||
(revoked_hashes_.find(spki_hash.value()) != revoked_hashes_.end())) {
return false;
@@ -226,10 +198,12 @@ bool CastCRL::CheckRevocation(const std::vector<X509*>& trusted_chain,
// Only Google generated device certificates will be revoked by range.
// These will always be less than 64 bits in length.
- if (!ParseDerUint64(subordinate->cert_info->serialNumber,
- &serial_number)) {
+ ErrorOr<uint64_t> maybe_serial =
+ ParseDerUint64(subordinate->cert_info->serialNumber);
+ if (!maybe_serial) {
continue;
}
+ serial_number = maybe_serial.value();
for (const auto& revoked_serial : issuer_iter->second) {
if (revoked_serial.first_serial <= serial_number &&
revoked_serial.last_serial >= serial_number) {
@@ -273,5 +247,5 @@ std::unique_ptr<CastCRL> ParseAndVerifyCRL(const std::string& crl_proto,
return nullptr;
}
-} // namespace certificate
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.h b/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.h
index 51b8fa5169c..420aa38e14d 100644
--- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.h
+++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl.h
@@ -17,8 +17,15 @@
#include "cast/common/certificate/proto/revocation.pb.h"
#include "platform/base/macros.h"
+namespace openscreen {
namespace cast {
-namespace certificate {
+
+// TODO(crbug.com/openscreen/90): Remove these after Chromium is migrated to
+// openscreen::cast
+using CrlBundle = ::cast::certificate::CrlBundle;
+using Crl = ::cast::certificate::Crl;
+using TbsCrl = ::cast::certificate::TbsCrl;
+using SerialNumberRange = ::cast::certificate::SerialNumberRange;
// This class represents the certificate revocation list information parsed from
// the binary in a protobuf message.
@@ -81,7 +88,7 @@ std::unique_ptr<CastCRL> ParseAndVerifyCRL(const std::string& crl_proto,
const DateTime& time,
TrustStore* trust_store = nullptr);
-} // namespace certificate
} // namespace cast
+} // namespace openscreen
#endif // CAST_COMMON_CERTIFICATE_CAST_CRL_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl_unittest.cc b/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl_unittest.cc
index 4e3a94d8d8c..81c0030f854 100644
--- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_crl_unittest.cc
@@ -7,15 +7,21 @@
#include "cast/common/certificate/cast_cert_validator.h"
#include "cast/common/certificate/cast_cert_validator_internal.h"
#include "cast/common/certificate/proto/test_suite.pb.h"
-#include "cast/common/certificate/test_helpers.h"
+#include "cast/common/certificate/testing/test_helpers.h"
#include "gtest/gtest.h"
+#include "testing/util/read_file.h"
#include "util/logging.h"
+namespace openscreen {
namespace cast {
-namespace certificate {
-namespace {
-using CastCertError = openscreen::Error::Code;
+// TODO(crbug.com/openscreen/90): Remove these after Chromium is migrated to
+// openscreen::cast
+using DeviceCertTestSuite = ::cast::certificate::DeviceCertTestSuite;
+using VerificationResult = ::cast::certificate::VerificationResult;
+using DeviceCertTest = ::cast::certificate::DeviceCertTest;
+
+namespace {
// Indicates the expected result of test step's verification.
enum TestStepResult {
@@ -31,10 +37,9 @@ bool TestVerifyCertificate(TestStepResult expected_result,
TrustStore* cast_trust_store) {
std::unique_ptr<CertVerificationContext> context;
CastDeviceCertPolicy policy;
- openscreen::Error result =
- VerifyDeviceCert(der_certs, time, &context, &policy, nullptr,
- CRLPolicy::kCrlOptional, cast_trust_store);
- bool success = (result.code() == CastCertError::kNone) ==
+ Error result = VerifyDeviceCert(der_certs, time, &context, &policy, nullptr,
+ CRLPolicy::kCrlOptional, cast_trust_store);
+ bool success = (result.code() == Error::Code::kNone) ==
(expected_result == kResultSuccess);
EXPECT_TRUE(success);
return success;
@@ -60,7 +65,7 @@ bool TestVerifyCRL(TestStepResult expected_result,
// The provided CRL is verified at |crl_time|.
// If |crl_required| is set, then a valid Cast CRL must be provided.
// Otherwise, a missing CRL is be ignored.
-bool TestVerifyRevocation(CastCertError expected_result,
+bool TestVerifyRevocation(Error::Code expected_result,
const std::vector<std::string>& der_certs,
const std::string& crl_bundle,
const DateTime& crl_time,
@@ -78,9 +83,8 @@ bool TestVerifyRevocation(CastCertError expected_result,
CastDeviceCertPolicy policy;
CRLPolicy crl_policy =
crl_required ? CRLPolicy::kCrlRequired : CRLPolicy::kCrlOptional;
- openscreen::Error result =
- VerifyDeviceCert(der_certs, cert_time, &context, &policy, crl.get(),
- crl_policy, cast_trust_store);
+ Error result = VerifyDeviceCert(der_certs, cert_time, &context, &policy,
+ crl.get(), crl_policy, cast_trust_store);
EXPECT_EQ(expected_result, result.code());
return expected_result == result.code();
}
@@ -118,47 +122,47 @@ bool RunTest(const DeviceCertTest& test_case) {
std::string crl_bundle = test_case.crl_bundle();
switch (test_case.expected_result()) {
- case PATH_VERIFICATION_FAILED:
+ case ::cast::certificate::PATH_VERIFICATION_FAILED:
return TestVerifyCertificate(kResultFail, der_cert_path,
cert_verification_time,
cast_trust_store.get());
- case CRL_VERIFICATION_FAILED:
+ case ::cast::certificate::CRL_VERIFICATION_FAILED:
return TestVerifyCRL(kResultFail, crl_bundle, crl_verification_time,
crl_trust_store.get());
- case REVOCATION_CHECK_FAILED_WITHOUT_CRL:
+ case ::cast::certificate::REVOCATION_CHECK_FAILED_WITHOUT_CRL:
return TestVerifyCertificate(kResultSuccess, der_cert_path,
cert_verification_time,
cast_trust_store.get()) &&
TestVerifyCRL(kResultFail, crl_bundle, crl_verification_time,
crl_trust_store.get()) &&
TestVerifyRevocation(
- CastCertError::kErrCrlInvalid, der_cert_path, crl_bundle,
+ Error::Code::kErrCrlInvalid, der_cert_path, crl_bundle,
crl_verification_time, cert_verification_time, true,
cast_trust_store.get(), crl_trust_store.get());
- case CRL_EXPIRED_AFTER_INITIAL_VERIFICATION: // fallthrough
- case REVOCATION_CHECK_FAILED:
+ case ::cast::certificate::
+ CRL_EXPIRED_AFTER_INITIAL_VERIFICATION: // fallthrough
+ case ::cast::certificate::REVOCATION_CHECK_FAILED:
return TestVerifyCertificate(kResultSuccess, der_cert_path,
cert_verification_time,
cast_trust_store.get()) &&
TestVerifyCRL(kResultSuccess, crl_bundle, crl_verification_time,
crl_trust_store.get()) &&
TestVerifyRevocation(
- CastCertError::kErrCertsRevoked, der_cert_path, crl_bundle,
+ Error::Code::kErrCertsRevoked, der_cert_path, crl_bundle,
crl_verification_time, cert_verification_time, true,
cast_trust_store.get(), crl_trust_store.get());
- case SUCCESS:
+ case ::cast::certificate::SUCCESS:
return (crl_bundle.empty() ||
TestVerifyCRL(kResultSuccess, crl_bundle, crl_verification_time,
crl_trust_store.get())) &&
TestVerifyCertificate(kResultSuccess, der_cert_path,
cert_verification_time,
cast_trust_store.get()) &&
- TestVerifyRevocation(CastCertError::kNone, der_cert_path,
- crl_bundle, crl_verification_time,
- cert_verification_time, !crl_bundle.empty(),
- cast_trust_store.get(),
+ TestVerifyRevocation(Error::Code::kNone, der_cert_path, crl_bundle,
+ crl_verification_time, cert_verification_time,
+ !crl_bundle.empty(), cast_trust_store.get(),
crl_trust_store.get());
- case UNSPECIFIED:
+ case ::cast::certificate::UNSPECIFIED:
return false;
}
return false;
@@ -169,8 +173,7 @@ bool RunTest(const DeviceCertTest& test_case) {
// To see the description of the test, execute the test.
// These tests are generated by a test generator in google3.
void RunTestSuite(const std::string& test_suite_file_name) {
- std::string testsuite_raw =
- testing::ReadEntireFileToString(test_suite_file_name);
+ std::string testsuite_raw = ReadEntireFileToString(test_suite_file_name);
ASSERT_FALSE(testsuite_raw.empty());
DeviceCertTestSuite test_suite;
ASSERT_TRUE(test_suite.ParseFromString(testsuite_raw));
@@ -191,5 +194,5 @@ TEST(CastCertificateTest, TestSuite1) {
}
} // namespace
-} // namespace certificate
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.cc b/chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.cc
new file mode 100644
index 00000000000..8c9e5e24b63
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.cc
@@ -0,0 +1,66 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/common/certificate/cast_trust_store.h"
+
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+namespace {
+
+// -------------------------------------------------------------------------
+// Cast trust anchors.
+// -------------------------------------------------------------------------
+
+// There are two trusted roots for Cast certificate chains:
+//
+// (1) CN=Cast Root CA (kCastRootCaDer)
+// (2) CN=Eureka Root CA (kEurekaRootCaDer)
+//
+// These constants are defined by the files included next:
+
+#include "cast/common/certificate/cast_root_ca_cert_der-inc.h"
+#include "cast/common/certificate/eureka_root_ca_der-inc.h"
+
+} // namespace
+
+// static
+CastTrustStore* CastTrustStore::GetInstance() {
+ if (!store_) {
+ store_ = new CastTrustStore();
+ }
+ return store_;
+}
+
+// static
+void CastTrustStore::ResetInstance() {
+ delete store_;
+ store_ = nullptr;
+}
+
+// static
+CastTrustStore* CastTrustStore::CreateInstanceForTest(
+ const std::vector<uint8_t>& trust_anchor_der) {
+ OSP_DCHECK(!store_);
+ store_ = new CastTrustStore(trust_anchor_der);
+ return store_;
+}
+
+CastTrustStore::CastTrustStore() {
+ trust_store_.certs.emplace_back(MakeTrustAnchor(kCastRootCaDer));
+ trust_store_.certs.emplace_back(MakeTrustAnchor(kEurekaRootCaDer));
+}
+
+CastTrustStore::CastTrustStore(const std::vector<uint8_t>& trust_anchor_der) {
+ trust_store_.certs.emplace_back(MakeTrustAnchor(trust_anchor_der));
+}
+
+CastTrustStore::~CastTrustStore() = default;
+
+// static
+CastTrustStore* CastTrustStore::store_ = nullptr;
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.h b/chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.h
new file mode 100644
index 00000000000..8aac9d3905b
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.h
@@ -0,0 +1,39 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_COMMON_CERTIFICATE_CAST_TRUST_STORE_H_
+#define CAST_COMMON_CERTIFICATE_CAST_TRUST_STORE_H_
+
+#include <vector>
+
+#include "cast/common/certificate/cast_cert_validator_internal.h"
+
+namespace openscreen {
+namespace cast {
+
+class CastTrustStore {
+ public:
+ static CastTrustStore* GetInstance();
+ static void ResetInstance();
+
+ static CastTrustStore* CreateInstanceForTest(
+ const std::vector<uint8_t>& trust_anchor_der);
+
+ CastTrustStore();
+ CastTrustStore(const std::vector<uint8_t>& trust_anchor_der);
+ CastTrustStore(const CastTrustStore&) = delete;
+ ~CastTrustStore();
+ CastTrustStore& operator=(const CastTrustStore&) = delete;
+
+ TrustStore* trust_store() { return &trust_store_; }
+
+ private:
+ static CastTrustStore* store_;
+ TrustStore trust_store_;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_COMMON_CERTIFICATE_CAST_TRUST_STORE_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/test_helpers.cc b/chromium/third_party/openscreen/src/cast/common/certificate/test_helpers.cc
deleted file mode 100644
index 41aef74feef..00000000000
--- a/chromium/third_party/openscreen/src/cast/common/certificate/test_helpers.cc
+++ /dev/null
@@ -1,130 +0,0 @@
-// Copyright 2019 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "cast/common/certificate/test_helpers.h"
-
-#include <openssl/pem.h>
-#include <stdio.h>
-#include <string.h>
-
-#include "util/logging.h"
-
-namespace cast {
-namespace certificate {
-namespace testing {
-
-std::string ReadEntireFileToString(const std::string& filename) {
- FILE* file = fopen(filename.c_str(), "r");
- if (file == nullptr) {
- return {};
- }
- fseek(file, 0, SEEK_END);
- long file_size = ftell(file);
- fseek(file, 0, SEEK_SET);
- std::string contents(file_size, 0);
- int bytes_read = 0;
- while (bytes_read < file_size) {
- size_t ret = fread(&contents[bytes_read], 1, file_size - bytes_read, file);
- if (ret == 0 && ferror(file)) {
- return {};
- } else {
- bytes_read += ret;
- }
- }
- fclose(file);
-
- return contents;
-}
-
-std::vector<std::string> ReadCertificatesFromPemFile(
- const std::string& filename) {
- FILE* fp = fopen(filename.c_str(), "r");
- if (!fp) {
- return {};
- }
- std::vector<std::string> certs;
-#define STRCMP_LITERAL(s, l) strncmp(s, l, sizeof(l))
- for (;;) {
- char* name;
- char* header;
- unsigned char* data;
- long length;
- if (PEM_read(fp, &name, &header, &data, &length) == 1) {
- if (STRCMP_LITERAL(name, "CERTIFICATE") == 0) {
- certs.emplace_back((char*)data, length);
- }
- OPENSSL_free(name);
- OPENSSL_free(header);
- OPENSSL_free(data);
- } else {
- break;
- }
- }
- fclose(fp);
- return certs;
-}
-
-SignatureTestData::SignatureTestData()
- : message{nullptr, 0}, sha1{nullptr, 0}, sha256{nullptr, 0} {}
-
-SignatureTestData::~SignatureTestData() {
- OPENSSL_free(const_cast<uint8_t*>(message.data));
- OPENSSL_free(const_cast<uint8_t*>(sha1.data));
- OPENSSL_free(const_cast<uint8_t*>(sha256.data));
-}
-
-SignatureTestData ReadSignatureTestData(const std::string& filename) {
- FILE* fp = fopen(filename.c_str(), "r");
- OSP_DCHECK(fp);
- SignatureTestData result = {};
- for (;;) {
- char* name;
- char* header;
- unsigned char* data;
- long length;
- if (PEM_read(fp, &name, &header, &data, &length) == 1) {
- if (strcmp(name, "MESSAGE") == 0) {
- OSP_DCHECK(!result.message.data);
- result.message.data = data;
- result.message.length = length;
- } else if (strcmp(name, "SIGNATURE SHA1") == 0) {
- OSP_DCHECK(!result.sha1.data);
- result.sha1.data = data;
- result.sha1.length = length;
- } else if (strcmp(name, "SIGNATURE SHA256") == 0) {
- OSP_DCHECK(!result.sha256.data);
- result.sha256.data = data;
- result.sha256.length = length;
- } else {
- OPENSSL_free(data);
- }
- OPENSSL_free(name);
- OPENSSL_free(header);
- } else {
- break;
- }
- }
- OSP_DCHECK(result.message.data);
- OSP_DCHECK(result.sha1.data);
- OSP_DCHECK(result.sha256.data);
-
- return result;
-}
-
-std::unique_ptr<TrustStore> CreateTrustStoreFromPemFile(
- const std::string& filename) {
- std::unique_ptr<TrustStore> store = std::make_unique<TrustStore>();
-
- std::vector<std::string> certs =
- testing::ReadCertificatesFromPemFile(filename);
- for (const auto& der_cert : certs) {
- const uint8_t* data = (const uint8_t*)der_cert.data();
- store->certs.emplace_back(d2i_X509(nullptr, &data, der_cert.size()));
- }
- return store;
-}
-
-} // namespace testing
-} // namespace certificate
-} // namespace cast
diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/testing/test_helpers.cc b/chromium/third_party/openscreen/src/cast/common/certificate/testing/test_helpers.cc
new file mode 100644
index 00000000000..eb9bac2fb92
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/certificate/testing/test_helpers.cc
@@ -0,0 +1,130 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/common/certificate/testing/test_helpers.h"
+
+#include <openssl/bytestring.h>
+#include <openssl/pem.h>
+#include <openssl/rsa.h>
+#include <stdio.h>
+#include <string.h>
+
+#include "absl/strings/match.h"
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+namespace testing {
+
+std::vector<std::string> ReadCertificatesFromPemFile(
+ absl::string_view filename) {
+ FILE* fp = fopen(filename.data(), "r");
+ if (!fp) {
+ return {};
+ }
+ std::vector<std::string> certs;
+ char* name;
+ char* header;
+ unsigned char* data;
+ long length;
+ while (PEM_read(fp, &name, &header, &data, &length) == 1) {
+ if (absl::StartsWith(name, "CERTIFICATE")) {
+ certs.emplace_back((char*)data, length);
+ }
+ OPENSSL_free(name);
+ OPENSSL_free(header);
+ OPENSSL_free(data);
+ }
+ fclose(fp);
+ return certs;
+}
+
+bssl::UniquePtr<EVP_PKEY> ReadKeyFromPemFile(absl::string_view filename) {
+ FILE* fp = fopen(filename.data(), "r");
+ if (!fp) {
+ return nullptr;
+ }
+ bssl::UniquePtr<EVP_PKEY> pkey;
+ char* name;
+ char* header;
+ unsigned char* data;
+ long length;
+ while (PEM_read(fp, &name, &header, &data, &length) == 1) {
+ if (absl::StartsWith(name, "RSA PRIVATE KEY")) {
+ OSP_DCHECK(!pkey);
+ CBS cbs;
+ CBS_init(&cbs, data, length);
+ RSA* rsa = RSA_parse_private_key(&cbs);
+ if (rsa) {
+ pkey.reset(EVP_PKEY_new());
+ EVP_PKEY_assign_RSA(pkey.get(), rsa);
+ }
+ }
+ OPENSSL_free(name);
+ OPENSSL_free(header);
+ OPENSSL_free(data);
+ }
+ fclose(fp);
+ return pkey;
+}
+
+SignatureTestData::SignatureTestData()
+ : message{nullptr, 0}, sha1{nullptr, 0}, sha256{nullptr, 0} {}
+
+SignatureTestData::~SignatureTestData() {
+ OPENSSL_free(const_cast<uint8_t*>(message.data));
+ OPENSSL_free(const_cast<uint8_t*>(sha1.data));
+ OPENSSL_free(const_cast<uint8_t*>(sha256.data));
+}
+
+SignatureTestData ReadSignatureTestData(absl::string_view filename) {
+ FILE* fp = fopen(filename.data(), "r");
+ OSP_DCHECK(fp);
+ SignatureTestData result = {};
+ char* name;
+ char* header;
+ unsigned char* data;
+ long length;
+ while (PEM_read(fp, &name, &header, &data, &length) == 1) {
+ if (strcmp(name, "MESSAGE") == 0) {
+ OSP_DCHECK(!result.message.data);
+ result.message.data = data;
+ result.message.length = length;
+ } else if (strcmp(name, "SIGNATURE SHA1") == 0) {
+ OSP_DCHECK(!result.sha1.data);
+ result.sha1.data = data;
+ result.sha1.length = length;
+ } else if (strcmp(name, "SIGNATURE SHA256") == 0) {
+ OSP_DCHECK(!result.sha256.data);
+ result.sha256.data = data;
+ result.sha256.length = length;
+ } else {
+ OPENSSL_free(data);
+ }
+ OPENSSL_free(name);
+ OPENSSL_free(header);
+ }
+ OSP_DCHECK(result.message.data);
+ OSP_DCHECK(result.sha1.data);
+ OSP_DCHECK(result.sha256.data);
+
+ return result;
+}
+
+std::unique_ptr<TrustStore> CreateTrustStoreFromPemFile(
+ absl::string_view filename) {
+ std::unique_ptr<TrustStore> store = std::make_unique<TrustStore>();
+
+ std::vector<std::string> certs =
+ testing::ReadCertificatesFromPemFile(filename);
+ for (const auto& der_cert : certs) {
+ const uint8_t* data = (const uint8_t*)der_cert.data();
+ store->certs.emplace_back(d2i_X509(nullptr, &data, der_cert.size()));
+ }
+ return store;
+}
+
+} // namespace testing
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/test_helpers.h b/chromium/third_party/openscreen/src/cast/common/certificate/testing/test_helpers.h
index ac3a136d65f..c1ff9a25f78 100644
--- a/chromium/third_party/openscreen/src/cast/common/certificate/test_helpers.h
+++ b/chromium/third_party/openscreen/src/cast/common/certificate/testing/test_helpers.h
@@ -2,22 +2,25 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#ifndef CAST_COMMON_CERTIFICATE_TEST_HELPERS_H_
-#define CAST_COMMON_CERTIFICATE_TEST_HELPERS_H_
+#ifndef CAST_COMMON_CERTIFICATE_TESTING_TEST_HELPERS_H_
+#define CAST_COMMON_CERTIFICATE_TESTING_TEST_HELPERS_H_
+
+#include <openssl/evp.h>
#include <string>
#include <vector>
+#include "absl/strings/string_view.h"
#include "cast/common/certificate/cast_cert_validator_internal.h"
#include "cast/common/certificate/types.h"
+namespace openscreen {
namespace cast {
-namespace certificate {
namespace testing {
-std::string ReadEntireFileToString(const std::string& filename);
std::vector<std::string> ReadCertificatesFromPemFile(
- const std::string& filename);
+ absl::string_view filename);
+bssl::UniquePtr<EVP_PKEY> ReadKeyFromPemFile(absl::string_view filename);
class SignatureTestData {
public:
@@ -29,13 +32,13 @@ class SignatureTestData {
ConstDataSpan sha256;
};
-SignatureTestData ReadSignatureTestData(const std::string& filename);
+SignatureTestData ReadSignatureTestData(absl::string_view filename);
std::unique_ptr<TrustStore> CreateTrustStoreFromPemFile(
- const std::string& filename);
+ absl::string_view filename);
} // namespace testing
-} // namespace certificate
} // namespace cast
+} // namespace openscreen
-#endif // CAST_COMMON_CERTIFICATE_TEST_HELPERS_H_
+#endif // CAST_COMMON_CERTIFICATE_TESTING_TEST_HELPERS_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/types.cc b/chromium/third_party/openscreen/src/cast/common/certificate/types.cc
index 297fbffca43..507a033ec1c 100644
--- a/chromium/third_party/openscreen/src/cast/common/certificate/types.cc
+++ b/chromium/third_party/openscreen/src/cast/common/certificate/types.cc
@@ -6,8 +6,8 @@
#include "util/logging.h"
+namespace openscreen {
namespace cast {
-namespace certificate {
bool operator<(const DateTime& a, const DateTime& b) {
if (a.year < b.year) {
@@ -66,5 +66,22 @@ bool DateTimeFromSeconds(uint64_t seconds, DateTime* time) {
return true;
}
-} // namespace certificate
+static_assert(sizeof(time_t) >= 4, "Can't avoid overflow with < 32-bits");
+
+std::chrono::seconds DateTimeToSeconds(const DateTime& time) {
+ OSP_DCHECK_GE(time.month, 1);
+ OSP_DCHECK_GE(time.year, 1900);
+ // NOTE: Guard against overflow if time_t is 32-bit.
+ OSP_DCHECK(sizeof(time_t) >= 8 || time.year < 2038) << time.year;
+ struct tm tm = {};
+ tm.tm_sec = time.second;
+ tm.tm_min = time.minute;
+ tm.tm_hour = time.hour;
+ tm.tm_mday = time.day;
+ tm.tm_mon = time.month - 1;
+ tm.tm_year = time.year - 1900;
+ return std::chrono::seconds(mktime(&tm));
+}
+
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/types.h b/chromium/third_party/openscreen/src/cast/common/certificate/types.h
index 62fe45ed5d4..c3cf7efe273 100644
--- a/chromium/third_party/openscreen/src/cast/common/certificate/types.h
+++ b/chromium/third_party/openscreen/src/cast/common/certificate/types.h
@@ -7,8 +7,10 @@
#include <stdint.h>
+#include <chrono>
+
+namespace openscreen {
namespace cast {
-namespace certificate {
struct ConstDataSpan {
const uint8_t* data;
@@ -28,7 +30,10 @@ bool operator<(const DateTime& a, const DateTime& b);
bool operator>(const DateTime& a, const DateTime& b);
bool DateTimeFromSeconds(uint64_t seconds, DateTime* time);
-} // namespace certificate
+// |time| is assumed to be valid.
+std::chrono::seconds DateTimeToSeconds(const DateTime& time);
+
} // namespace cast
+} // namespace openscreen
#endif // CAST_COMMON_CERTIFICATE_TYPES_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/BUILD.gn b/chromium/third_party/openscreen/src/cast/common/channel/BUILD.gn
deleted file mode 100644
index 4a086190a90..00000000000
--- a/chromium/third_party/openscreen/src/cast/common/channel/BUILD.gn
+++ /dev/null
@@ -1,62 +0,0 @@
-# Copyright 2019 The Chromium Authors. All rights reserved.
-# Use of this source code is governed by a BSD-style license that can be
-# found in the LICENSE file.
-
-source_set("channel") {
- sources = [
- "cast_message_handler.h",
- "cast_socket.cc",
- "cast_socket.h",
- "message_framer.cc",
- "message_framer.h",
- "message_util.h",
- "virtual_connection.h",
- "virtual_connection_manager.cc",
- "virtual_connection_manager.h",
- ]
-
- public_deps = [
- "../../../platform",
- "../../../third_party/abseil",
- ]
-
- deps = [
- "../../../util",
- "proto",
- ]
-}
-
-source_set("test") {
- testonly = true
- sources = [
- "test/fake_cast_socket.h",
- "test/mock_cast_message_handler.h",
- ]
-
- public_deps = [
- ":channel",
- "../../../platform:test",
- "../../../third_party/googletest:gmock",
- ]
-}
-
-source_set("unittests") {
- testonly = true
- sources = [
- "cast_socket_unittest.cc",
- "message_framer_unittest.cc",
- "virtual_connection_manager_unittest.cc",
- ]
-
- deps = [
- ":channel",
- ":test",
- "../../../platform",
- "../../../platform:test",
- "../../../third_party/googletest:gmock",
- "../../../third_party/googletest:gtest",
- "../../../util",
- "../../common/certificate/proto:unittest_proto",
- "proto",
- ]
-}
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/cast_message_handler.h b/chromium/third_party/openscreen/src/cast/common/channel/cast_message_handler.h
index 9754b0c80d8..cd0d13e690d 100644
--- a/chromium/third_party/openscreen/src/cast/common/channel/cast_message_handler.h
+++ b/chromium/third_party/openscreen/src/cast/common/channel/cast_message_handler.h
@@ -5,11 +5,12 @@
#ifndef CAST_COMMON_CHANNEL_CAST_MESSAGE_HANDLER_H_
#define CAST_COMMON_CHANNEL_CAST_MESSAGE_HANDLER_H_
+#include "cast/common/channel/proto/cast_channel.pb.h"
+
+namespace openscreen {
namespace cast {
-namespace channel {
class CastSocket;
-class CastMessage;
class VirtualConnectionRouter;
class CastMessageHandler {
@@ -18,10 +19,10 @@ class CastMessageHandler {
virtual void OnMessage(VirtualConnectionRouter* router,
CastSocket* socket,
- CastMessage&& message) = 0;
+ ::cast::channel::CastMessage message) = 0;
};
-} // namespace channel
} // namespace cast
+} // namespace openscreen
#endif // CAST_COMMON_CHANNEL_CAST_MESSAGE_HANDLER_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/cast_socket.cc b/chromium/third_party/openscreen/src/cast/common/channel/cast_socket.cc
index 5b1c9408be9..ba7996584ac 100644
--- a/chromium/third_party/openscreen/src/cast/common/channel/cast_socket.cc
+++ b/chromium/third_party/openscreen/src/cast/common/channel/cast_socket.cc
@@ -4,29 +4,20 @@
#include "cast/common/channel/cast_socket.h"
-#include <atomic>
-
#include "cast/common/channel/message_framer.h"
#include "util/logging.h"
+namespace openscreen {
namespace cast {
-namespace channel {
+using ::cast::channel::CastMessage;
using message_serialization::DeserializeResult;
-using openscreen::ErrorOr;
-using openscreen::platform::TlsConnection;
-
-uint32_t GetNextSocketId() {
- static std::atomic<uint32_t> id(1);
- return id++;
-}
CastSocket::CastSocket(std::unique_ptr<TlsConnection> connection,
- Client* client,
- uint32_t socket_id)
- : client_(client),
- connection_(std::move(connection)),
- socket_id_(socket_id) {
+ Client* client)
+ : connection_(std::move(connection)),
+ client_(client),
+ socket_id_(g_next_socket_id_++) {
OSP_DCHECK(client);
connection_->SetClient(this);
}
@@ -46,12 +37,9 @@ Error CastSocket::SendMessage(const CastMessage& message) {
return out.error();
}
- if (state_ == State::kBlocked) {
- message_queue_.emplace_back(std::move(out.value()));
- return Error::Code::kNone;
+ if (!connection_->Send(out.value().data(), out.value().size())) {
+ return Error::Code::kAgain;
}
-
- connection_->Write(out.value().data(), out.value().size());
return Error::Code::kNone;
}
@@ -60,27 +48,20 @@ void CastSocket::SetClient(Client* client) {
client_ = client;
}
-void CastSocket::OnWriteBlocked(TlsConnection* connection) {
- if (state_ == State::kOpen) {
- state_ = State::kBlocked;
+std::array<uint8_t, 2> CastSocket::GetSanitizedIpAddress() {
+ IPEndpoint remote = connection_->GetRemoteEndpoint();
+ std::array<uint8_t, 2> result;
+ uint8_t bytes[16];
+ if (remote.address.IsV4()) {
+ remote.address.CopyToV4(bytes);
+ result[0] = bytes[2];
+ result[1] = bytes[3];
+ } else {
+ remote.address.CopyToV6(bytes);
+ result[0] = bytes[14];
+ result[1] = bytes[15];
}
-}
-
-void CastSocket::OnWriteUnblocked(TlsConnection* connection) {
- if (state_ != State::kBlocked) {
- return;
- }
- state_ = State::kOpen;
-
- // Attempt to write all messages that have been queued-up while the socket was
- // blocked. Stop if the socket becomes blocked again, or an error occurs.
- auto it = message_queue_.begin();
- for (const auto end = message_queue_.end();
- it != end && state_ == State::kOpen; ++it) {
- // The following Write() could transition |state_| to kBlocked or kError.
- connection_->Write(it->data(), it->size());
- }
- message_queue_.erase(message_queue_.begin(), it);
+ return result;
}
void CastSocket::OnError(TlsConnection* connection, Error error) {
@@ -101,5 +82,7 @@ void CastSocket::OnRead(TlsConnection* connection, std::vector<uint8_t> block) {
client_->OnMessage(this, std::move(message_or_error.value().message));
}
-} // namespace channel
+int CastSocket::g_next_socket_id_ = 1;
+
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/cast_socket.h b/chromium/third_party/openscreen/src/cast/common/channel/cast_socket.h
index 6bd099b2f85..550aa76ed21 100644
--- a/chromium/third_party/openscreen/src/cast/common/channel/cast_socket.h
+++ b/chromium/third_party/openscreen/src/cast/common/channel/cast_socket.h
@@ -5,19 +5,15 @@
#ifndef CAST_COMMON_CHANNEL_CAST_SOCKET_H_
#define CAST_COMMON_CHANNEL_CAST_SOCKET_H_
+#include <array>
+#include <memory>
#include <vector>
+#include "cast/common/channel/proto/cast_channel.pb.h"
#include "platform/api/tls_connection.h"
+namespace openscreen {
namespace cast {
-namespace channel {
-
-using openscreen::Error;
-using TlsConnection = openscreen::platform::TlsConnection;
-
-class CastMessage;
-
-uint32_t GetNextSocketId();
// Represents a simple message-oriented socket for communicating with the Cast
// V2 protocol. It isn't thread-safe, so it should only be used on the same
@@ -31,46 +27,49 @@ class CastSocket : public TlsConnection::Client {
// Called when a terminal error on |socket| has occurred.
virtual void OnError(CastSocket* socket, Error error) = 0;
- virtual void OnMessage(CastSocket* socket, CastMessage message) = 0;
+ virtual void OnMessage(CastSocket* socket,
+ ::cast::channel::CastMessage message) = 0;
};
- CastSocket(std::unique_ptr<TlsConnection> connection,
- Client* client,
- uint32_t socket_id);
+ CastSocket(std::unique_ptr<TlsConnection> connection, Client* client);
~CastSocket();
// Sends |message| immediately unless the underlying TLS connection is
// write-blocked, in which case |message| will be queued. An error will be
// returned if |message| cannot be serialized for any reason, even while
// write-blocked.
- Error SendMessage(const CastMessage& message);
+ [[nodiscard]] Error SendMessage(const ::cast::channel::CastMessage& message);
void SetClient(Client* client);
- uint32_t socket_id() const { return socket_id_; }
+ std::array<uint8_t, 2> GetSanitizedIpAddress();
+
+ int socket_id() const { return socket_id_; }
+
+ void set_audio_only(bool audio_only) { audio_only_ = audio_only; }
+ bool audio_only() const { return audio_only_; }
// TlsConnection::Client overrides.
- void OnWriteBlocked(TlsConnection* connection) override;
- void OnWriteUnblocked(TlsConnection* connection) override;
void OnError(TlsConnection* connection, Error error) override;
void OnRead(TlsConnection* connection, std::vector<uint8_t> block) override;
private:
- enum class State {
- kOpen,
- kBlocked,
- kError,
+ enum class State : bool {
+ kOpen = true,
+ kError = false,
};
- Client* client_; // May never be null.
+ static int g_next_socket_id_;
+
const std::unique_ptr<TlsConnection> connection_;
+ Client* client_; // May never be null.
+ const int socket_id_;
+ bool audio_only_ = false;
std::vector<uint8_t> read_buffer_;
- const uint32_t socket_id_;
State state_ = State::kOpen;
- std::vector<std::vector<uint8_t>> message_queue_;
};
-} // namespace channel
} // namespace cast
+} // namespace openscreen
#endif // CAST_COMMON_CHANNEL_CAST_SOCKET_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/cast_socket_unittest.cc b/chromium/third_party/openscreen/src/cast/common/channel/cast_socket_unittest.cc
index 35b157422e9..44776384305 100644
--- a/chromium/third_party/openscreen/src/cast/common/channel/cast_socket_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/common/channel/cast_socket_unittest.cc
@@ -6,18 +6,20 @@
#include "cast/common/channel/message_framer.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
-#include "cast/common/channel/test/fake_cast_socket.h"
+#include "cast/common/channel/testing/fake_cast_socket.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+namespace openscreen {
namespace cast {
-namespace channel {
+
+using ::cast::channel::CastMessage;
+
namespace {
using ::testing::_;
using ::testing::Invoke;
-
-using openscreen::ErrorOr;
+using ::testing::Return;
class CastSocketTest : public ::testing::Test {
public:
@@ -47,16 +49,36 @@ class CastSocketTest : public ::testing::Test {
} // namespace
TEST_F(CastSocketTest, SendMessage) {
- EXPECT_CALL(connection(), Write(_, _))
+ EXPECT_CALL(connection(), Send(_, _))
.WillOnce(Invoke([this](const void* data, size_t len) {
EXPECT_EQ(
frame_serial_,
std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data),
reinterpret_cast<const uint8_t*>(data) + len));
+ return true;
}));
ASSERT_TRUE(socket().SendMessage(message_).ok());
}
+TEST_F(CastSocketTest, SendMessageEventuallyBlocks) {
+ EXPECT_CALL(connection(), Send(_, _))
+ .Times(3)
+ .WillRepeatedly(Invoke([this](const void* data, size_t len) {
+ EXPECT_EQ(
+ frame_serial_,
+ std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data),
+ reinterpret_cast<const uint8_t*>(data) + len));
+ return true;
+ }))
+ .RetiresOnSaturation();
+ ASSERT_TRUE(socket().SendMessage(message_).ok());
+ ASSERT_TRUE(socket().SendMessage(message_).ok());
+ ASSERT_TRUE(socket().SendMessage(message_).ok());
+
+ EXPECT_CALL(connection(), Send(_, _)).WillOnce(Return(false));
+ ASSERT_EQ(socket().SendMessage(message_).code(), Error::Code::kAgain);
+}
+
TEST_F(CastSocketTest, ReadCompleteMessage) {
const uint8_t* data = frame_serial_.data();
EXPECT_CALL(mock_client(), OnMessage(_, _))
@@ -99,43 +121,19 @@ TEST_F(CastSocketTest, ReadChunkedMessage) {
data + double_message.size()));
}
-TEST_F(CastSocketTest, SendMessageWhileBlocked) {
- connection().OnWriteBlocked();
- EXPECT_CALL(connection(), Write(_, _)).Times(0);
- ASSERT_TRUE(socket().SendMessage(message_).ok());
-
- EXPECT_CALL(connection(), Write(_, _))
- .WillOnce(Invoke([this](const void* data, size_t len) {
- EXPECT_EQ(
- frame_serial_,
- std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data),
- reinterpret_cast<const uint8_t*>(data) + len));
- }));
- connection().OnWriteUnblocked();
-
- EXPECT_CALL(connection(), Write(_, _)).Times(0);
- connection().OnWriteBlocked();
- connection().OnWriteUnblocked();
-}
-
-TEST_F(CastSocketTest, ErrorWhileEmptyingQueue) {
- connection().OnWriteBlocked();
- EXPECT_CALL(connection(), Write(_, _)).Times(0);
- ASSERT_TRUE(socket().SendMessage(message_).ok());
-
- EXPECT_CALL(connection(), Write(_, _))
- .WillOnce(Invoke([this](const void* data, size_t len) {
- EXPECT_EQ(
- frame_serial_,
- std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data),
- reinterpret_cast<const uint8_t*>(data) + len));
- connection().OnError(Error::Code::kUnknownError);
- }));
- connection().OnWriteUnblocked();
-
- EXPECT_CALL(connection(), Write(_, _)).Times(0);
- ASSERT_FALSE(socket().SendMessage(message_).ok());
+TEST_F(CastSocketTest, SanitizedAddress) {
+ std::array<uint8_t, 2> result1 = socket().GetSanitizedIpAddress();
+ EXPECT_EQ(result1[0], 1u);
+ EXPECT_EQ(result1[1], 9u);
+
+ FakeCastSocket v6_socket(IPEndpoint{{1, 2, 3, 4}, 1025},
+ IPEndpoint{{0x1819, 0x1a1b, 0x1c1d, 0x1e1f, 0x207b,
+ 0x7c7d, 0x7e7f, 0x8081},
+ 4321});
+ std::array<uint8_t, 2> result2 = v6_socket.socket.GetSanitizedIpAddress();
+ EXPECT_EQ(result2[0], 128);
+ EXPECT_EQ(result2[1], 129);
}
-} // namespace channel
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler.cc b/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler.cc
new file mode 100644
index 00000000000..40b2b84038b
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler.cc
@@ -0,0 +1,259 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/common/channel/connection_namespace_handler.h"
+
+#include <type_traits>
+
+#include "absl/types/optional.h"
+#include "cast/common/channel/cast_socket.h"
+#include "cast/common/channel/message_util.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
+#include "cast/common/channel/virtual_connection.h"
+#include "cast/common/channel/virtual_connection_manager.h"
+#include "cast/common/channel/virtual_connection_router.h"
+#include "util/json/json_serialization.h"
+#include "util/json/json_value.h"
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+
+using ::cast::channel::CastMessage;
+using ::cast::channel::CastMessage_PayloadType;
+
+namespace {
+
+bool IsValidProtocolVersion(int version) {
+ return ::cast::channel::CastMessage_ProtocolVersion_IsValid(version);
+}
+
+absl::optional<int> FindMaxProtocolVersion(const Json::Value* version,
+ const Json::Value* version_list) {
+ using ArrayIndex = Json::Value::ArrayIndex;
+ static_assert(std::is_integral<ArrayIndex>::value,
+ "Assuming ArrayIndex is integral");
+ absl::optional<int> max_version;
+ if (version_list && version_list->isArray()) {
+ max_version = ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0;
+ for (auto it = version_list->begin(), end = version_list->end(); it != end;
+ ++it) {
+ if (it->isInt()) {
+ int version_int = it->asInt();
+ if (IsValidProtocolVersion(version_int) && version_int > *max_version) {
+ max_version = version_int;
+ }
+ }
+ }
+ }
+ if (version && version->isInt()) {
+ int version_int = version->asInt();
+ if (IsValidProtocolVersion(version_int)) {
+ if (!max_version) {
+ max_version = ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0;
+ }
+ if (version_int > max_version) {
+ max_version = version_int;
+ }
+ }
+ }
+ return max_version;
+}
+
+VirtualConnection::CloseReason GetCloseReason(
+ const Json::Value& parsed_message) {
+ VirtualConnection::CloseReason reason =
+ VirtualConnection::CloseReason::kClosedByPeer;
+ absl::optional<int> reason_code = MaybeGetInt(
+ parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyReasonCode));
+ if (reason_code) {
+ int code = reason_code.value();
+ if (code >= VirtualConnection::CloseReason::kFirstReason &&
+ code <= VirtualConnection::CloseReason::kLastReason) {
+ reason = static_cast<VirtualConnection::CloseReason>(code);
+ }
+ }
+ return reason;
+}
+
+} // namespace
+
+ConnectionNamespaceHandler::ConnectionNamespaceHandler(
+ VirtualConnectionManager* vc_manager,
+ VirtualConnectionPolicy* vc_policy)
+ : vc_manager_(vc_manager), vc_policy_(vc_policy) {
+ OSP_DCHECK(vc_manager);
+ OSP_DCHECK(vc_policy);
+}
+
+ConnectionNamespaceHandler::~ConnectionNamespaceHandler() = default;
+
+void ConnectionNamespaceHandler::OnMessage(VirtualConnectionRouter* router,
+ CastSocket* socket,
+ CastMessage message) {
+ if (message.payload_type() !=
+ CastMessage_PayloadType::CastMessage_PayloadType_STRING) {
+ return;
+ }
+ ErrorOr<Json::Value> result = json::Parse(message.payload_utf8());
+ if (result.is_error()) {
+ return;
+ }
+
+ Json::Value& value = result.value();
+ if (!value.isObject()) {
+ return;
+ }
+
+ absl::optional<absl::string_view> type =
+ MaybeGetString(value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType));
+ if (!type) {
+ // TODO(btolsch): Some of these paths should have error reporting. One
+ // possibility is to pass errors back through |router| so higher-level code
+ // can decide whether to show an error to the user, stop talking to a
+ // particular device, etc.
+ return;
+ }
+
+ absl::string_view type_str = type.value();
+ if (type_str == kMessageTypeConnect) {
+ HandleConnect(router, socket, std::move(message), std::move(value));
+ } else if (type_str == kMessageTypeClose) {
+ HandleClose(router, socket, std::move(message), std::move(value));
+ } else {
+ // NOTE: Unknown message type so ignore it.
+ // TODO(btolsch): Should be included in future error reporting.
+ }
+}
+
+void ConnectionNamespaceHandler::HandleConnect(VirtualConnectionRouter* router,
+ CastSocket* socket,
+ CastMessage message,
+ Json::Value parsed_message) {
+ if (message.destination_id() == kBroadcastId ||
+ message.source_id() == kBroadcastId) {
+ return;
+ }
+
+ VirtualConnection virtual_conn{std::move(message.destination_id()),
+ std::move(message.source_id()),
+ socket->socket_id()};
+ if (!vc_policy_->IsConnectionAllowed(virtual_conn)) {
+ SendClose(router, std::move(virtual_conn));
+ return;
+ }
+
+ absl::optional<int> maybe_conn_type = MaybeGetInt(
+ parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyConnType));
+ VirtualConnection::Type conn_type = VirtualConnection::Type::kStrong;
+ if (maybe_conn_type) {
+ int int_type = maybe_conn_type.value();
+ if (int_type < static_cast<int>(VirtualConnection::Type::kMinValue) ||
+ int_type > static_cast<int>(VirtualConnection::Type::kMaxValue)) {
+ SendClose(router, std::move(virtual_conn));
+ return;
+ }
+ conn_type = static_cast<VirtualConnection::Type>(int_type);
+ }
+
+ VirtualConnection::AssociatedData data;
+
+ data.type = conn_type;
+
+ absl::optional<absl::string_view> user_agent = MaybeGetString(
+ parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyUserAgent));
+ if (user_agent) {
+ data.user_agent = std::string(user_agent.value());
+ }
+
+ const Json::Value* sender_info_value = parsed_message.find(
+ JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeySenderInfo));
+ if (!sender_info_value || !sender_info_value->isObject()) {
+ // TODO(btolsch): Should this be guessed from user agent?
+ OSP_DVLOG << "No sender info from protocol.";
+ }
+
+ const Json::Value* version_value = parsed_message.find(
+ JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyProtocolVersion));
+ const Json::Value* version_list_value = parsed_message.find(
+ JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyProtocolVersionList));
+ absl::optional<int> negotiated_version =
+ FindMaxProtocolVersion(version_value, version_list_value);
+ if (negotiated_version) {
+ data.max_protocol_version = static_cast<VirtualConnection::ProtocolVersion>(
+ negotiated_version.value());
+ } else {
+ data.max_protocol_version = VirtualConnection::ProtocolVersion::kV2_1_0;
+ }
+
+ data.ip_fragment = socket->GetSanitizedIpAddress();
+
+ OSP_DVLOG << "Connection opened: " << virtual_conn.local_id << ", "
+ << virtual_conn.peer_id << ", " << virtual_conn.socket_id;
+
+ // NOTE: Only send a response for senders that actually sent a version. This
+ // maintains compatibility with older senders that don't send a version and
+ // don't expect a response.
+ if (negotiated_version) {
+ SendConnectedResponse(router, virtual_conn, negotiated_version.value());
+ }
+
+ vc_manager_->AddConnection(std::move(virtual_conn), std::move(data));
+}
+
+void ConnectionNamespaceHandler::HandleClose(VirtualConnectionRouter* router,
+ CastSocket* socket,
+ CastMessage message,
+ Json::Value parsed_message) {
+ VirtualConnection virtual_conn{std::move(message.destination_id()),
+ std::move(message.source_id()),
+ socket->socket_id()};
+ if (!vc_manager_->GetConnectionData(virtual_conn)) {
+ return;
+ }
+
+ VirtualConnection::CloseReason reason = GetCloseReason(parsed_message);
+
+ OSP_DVLOG << "Connection closed (reason: " << reason
+ << "): " << virtual_conn.local_id << ", " << virtual_conn.peer_id
+ << ", " << virtual_conn.socket_id;
+ vc_manager_->RemoveConnection(virtual_conn, reason);
+}
+
+void ConnectionNamespaceHandler::SendClose(VirtualConnectionRouter* router,
+ VirtualConnection virtual_conn) {
+ Json::Value close_message(Json::ValueType::objectValue);
+ close_message[kMessageKeyType] = kMessageTypeClose;
+
+ ErrorOr<std::string> result = json::Stringify(close_message);
+ if (result.is_error()) {
+ return;
+ }
+
+ router->SendMessage(
+ std::move(virtual_conn),
+ MakeSimpleUTF8Message(kConnectionNamespace, std::move(result.value())));
+}
+
+void ConnectionNamespaceHandler::SendConnectedResponse(
+ VirtualConnectionRouter* router,
+ const VirtualConnection& virtual_conn,
+ int max_protocol_version) {
+ Json::Value connected_message(Json::ValueType::objectValue);
+ connected_message[kMessageKeyType] = kMessageTypeConnected;
+ connected_message[kMessageKeyProtocolVersion] =
+ static_cast<int>(max_protocol_version);
+
+ ErrorOr<std::string> result = json::Stringify(connected_message);
+ if (result.is_error()) {
+ return;
+ }
+
+ router->SendMessage(
+ virtual_conn,
+ MakeSimpleUTF8Message(kConnectionNamespace, std::move(result.value())));
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler.h b/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler.h
new file mode 100644
index 00000000000..5307e896893
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler.h
@@ -0,0 +1,64 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_COMMON_CHANNEL_CONNECTION_NAMESPACE_HANDLER_H_
+#define CAST_COMMON_CHANNEL_CONNECTION_NAMESPACE_HANDLER_H_
+
+#include "cast/common/channel/cast_message_handler.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
+#include "util/json/json_serialization.h"
+
+namespace openscreen {
+namespace cast {
+
+struct VirtualConnection;
+class VirtualConnectionManager;
+class VirtualConnectionRouter;
+
+// Handles CastMessages in the connection namespace by opening and closing
+// VirtualConnections on the socket on which the messages were received.
+class ConnectionNamespaceHandler final : public CastMessageHandler {
+ public:
+ class VirtualConnectionPolicy {
+ public:
+ virtual ~VirtualConnectionPolicy() = default;
+
+ virtual bool IsConnectionAllowed(
+ const VirtualConnection& virtual_conn) const = 0;
+ };
+
+ // Both |vc_manager| and |vc_policy| should outlive this object.
+ ConnectionNamespaceHandler(VirtualConnectionManager* vc_manager,
+ VirtualConnectionPolicy* vc_policy);
+ ~ConnectionNamespaceHandler() override;
+
+ // CastMessageHandler overrides.
+ void OnMessage(VirtualConnectionRouter* router,
+ CastSocket* socket,
+ ::cast::channel::CastMessage message) override;
+
+ private:
+ void HandleConnect(VirtualConnectionRouter* router,
+ CastSocket* socket,
+ ::cast::channel::CastMessage message,
+ Json::Value parsed_message);
+ void HandleClose(VirtualConnectionRouter* router,
+ CastSocket* socket,
+ ::cast::channel::CastMessage message,
+ Json::Value parsed_message);
+
+ void SendClose(VirtualConnectionRouter* router,
+ VirtualConnection virtual_conn);
+ void SendConnectedResponse(VirtualConnectionRouter* router,
+ const VirtualConnection& virtual_conn,
+ int max_protocol_version);
+
+ VirtualConnectionManager* const vc_manager_;
+ VirtualConnectionPolicy* const vc_policy_;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_COMMON_CHANNEL_CONNECTION_NAMESPACE_HANDLER_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler_unittest.cc b/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler_unittest.cc
new file mode 100644
index 00000000000..0ff0763b093
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/connection_namespace_handler_unittest.cc
@@ -0,0 +1,226 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/common/channel/connection_namespace_handler.h"
+
+#include "cast/common/channel/cast_socket.h"
+#include "cast/common/channel/message_util.h"
+#include "cast/common/channel/testing/fake_cast_socket.h"
+#include "cast/common/channel/testing/mock_socket_error_handler.h"
+#include "cast/common/channel/virtual_connection.h"
+#include "cast/common/channel/virtual_connection_manager.h"
+#include "cast/common/channel/virtual_connection_router.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "util/json/json_serialization.h"
+#include "util/json/json_value.h"
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+namespace {
+
+using ::testing::_;
+using ::testing::Invoke;
+using ::testing::NiceMock;
+
+using ::cast::channel::CastMessage;
+using ::cast::channel::CastMessage_ProtocolVersion;
+
+class MockVirtualConnectionPolicy
+ : public ConnectionNamespaceHandler::VirtualConnectionPolicy {
+ public:
+ ~MockVirtualConnectionPolicy() override = default;
+
+ MOCK_METHOD(bool,
+ IsConnectionAllowed,
+ (const VirtualConnection& virtual_conn),
+ (const, override));
+};
+
+CastMessage MakeVersionedConnectMessage(
+ const std::string& source_id,
+ const std::string& destination_id,
+ absl::optional<CastMessage_ProtocolVersion> version,
+ std::vector<CastMessage_ProtocolVersion> version_list) {
+ CastMessage connect_message = MakeConnectMessage(source_id, destination_id);
+ Json::Value message(Json::ValueType::objectValue);
+ message[kMessageKeyType] = kMessageTypeConnect;
+ if (version) {
+ message[kMessageKeyProtocolVersion] = version.value();
+ }
+ if (!version_list.empty()) {
+ Json::Value list(Json::ValueType::arrayValue);
+ for (CastMessage_ProtocolVersion v : version_list) {
+ list.append(v);
+ }
+ message[kMessageKeyProtocolVersionList] = std::move(list);
+ }
+ ErrorOr<std::string> result = json::Stringify(message);
+ OSP_DCHECK(result);
+ connect_message.set_payload_utf8(std::move(result.value()));
+ return connect_message;
+}
+
+void VerifyConnectionMessage(const CastMessage& message,
+ const std::string& source_id,
+ const std::string& destination_id) {
+ EXPECT_EQ(message.source_id(), source_id);
+ EXPECT_EQ(message.destination_id(), destination_id);
+ EXPECT_EQ(message.namespace_(), kConnectionNamespace);
+ ASSERT_EQ(message.payload_type(),
+ ::cast::channel::CastMessage_PayloadType_STRING);
+}
+
+Json::Value ParseConnectionMessage(const CastMessage& message) {
+ ErrorOr<Json::Value> result = json::Parse(message.payload_utf8());
+ OSP_CHECK(result) << message.payload_utf8();
+ return result.value();
+}
+
+} // namespace
+
+class ConnectionNamespaceHandlerTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ socket_ = fake_cast_socket_pair_.socket.get();
+ router_.TakeSocket(&mock_error_handler_,
+ std::move(fake_cast_socket_pair_.socket));
+
+ ON_CALL(vc_policy_, IsConnectionAllowed(_))
+ .WillByDefault(
+ Invoke([](const VirtualConnection& virtual_conn) { return true; }));
+ }
+
+ protected:
+ void ExpectCloseMessage(MockCastSocketClient* mock_client,
+ const std::string& source_id,
+ const std::string& destination_id) {
+ EXPECT_CALL(*mock_client, OnMessage(_, _))
+ .WillOnce(Invoke([&source_id, &destination_id](CastSocket* socket,
+ CastMessage message) {
+ VerifyConnectionMessage(message, source_id, destination_id);
+ Json::Value value = ParseConnectionMessage(message);
+ absl::optional<absl::string_view> type = MaybeGetString(
+ value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType));
+ ASSERT_TRUE(type) << message.payload_utf8();
+ EXPECT_EQ(type.value(), kMessageTypeClose) << message.payload_utf8();
+ }));
+ }
+
+ void ExpectConnectedMessage(
+ MockCastSocketClient* mock_client,
+ const std::string& source_id,
+ const std::string& destination_id,
+ absl::optional<CastMessage_ProtocolVersion> version = absl::nullopt) {
+ EXPECT_CALL(*mock_client, OnMessage(_, _))
+ .WillOnce(Invoke([&source_id, &destination_id, version](
+ CastSocket* socket, CastMessage message) {
+ VerifyConnectionMessage(message, source_id, destination_id);
+ Json::Value value = ParseConnectionMessage(message);
+ absl::optional<absl::string_view> type = MaybeGetString(
+ value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType));
+ ASSERT_TRUE(type) << message.payload_utf8();
+ EXPECT_EQ(type.value(), kMessageTypeConnected)
+ << message.payload_utf8();
+ if (version) {
+ absl::optional<int> message_version = MaybeGetInt(
+ value,
+ JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyProtocolVersion));
+ ASSERT_TRUE(message_version) << message.payload_utf8();
+ EXPECT_EQ(message_version.value(), version.value());
+ }
+ }));
+ }
+
+ FakeCastSocketPair fake_cast_socket_pair_;
+ MockSocketErrorHandler mock_error_handler_;
+ CastSocket* socket_;
+
+ NiceMock<MockVirtualConnectionPolicy> vc_policy_;
+ VirtualConnectionManager vc_manager_;
+ VirtualConnectionRouter router_{&vc_manager_};
+ ConnectionNamespaceHandler connection_namespace_handler_{&vc_manager_,
+ &vc_policy_};
+
+ const std::string sender_id_{"sender-5678"};
+ const std::string receiver_id_{"receiver-3245"};
+};
+
+TEST_F(ConnectionNamespaceHandlerTest, Connect) {
+ connection_namespace_handler_.OnMessage(
+ &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_));
+ EXPECT_TRUE(vc_manager_.GetConnectionData(
+ VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
+
+ EXPECT_CALL(fake_cast_socket_pair_.mock_peer_client, OnMessage(_, _))
+ .Times(0);
+}
+
+TEST_F(ConnectionNamespaceHandlerTest, PolicyDeniesConnection) {
+ EXPECT_CALL(vc_policy_, IsConnectionAllowed(_))
+ .WillOnce(
+ Invoke([](const VirtualConnection& virtual_conn) { return false; }));
+ ExpectCloseMessage(&fake_cast_socket_pair_.mock_peer_client, receiver_id_,
+ sender_id_);
+ connection_namespace_handler_.OnMessage(
+ &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_));
+ EXPECT_FALSE(vc_manager_.GetConnectionData(
+ VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
+}
+
+TEST_F(ConnectionNamespaceHandlerTest, ConnectWithVersion) {
+ ExpectConnectedMessage(
+ &fake_cast_socket_pair_.mock_peer_client, receiver_id_, sender_id_,
+ ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2);
+ connection_namespace_handler_.OnMessage(
+ &router_, socket_,
+ MakeVersionedConnectMessage(
+ sender_id_, receiver_id_,
+ ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2, {}));
+ EXPECT_TRUE(vc_manager_.GetConnectionData(
+ VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
+}
+
+TEST_F(ConnectionNamespaceHandlerTest, ConnectWithVersionList) {
+ ExpectConnectedMessage(
+ &fake_cast_socket_pair_.mock_peer_client, receiver_id_, sender_id_,
+ ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3);
+ connection_namespace_handler_.OnMessage(
+ &router_, socket_,
+ MakeVersionedConnectMessage(
+ sender_id_, receiver_id_,
+ ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2,
+ {::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3,
+ ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0}));
+ EXPECT_TRUE(vc_manager_.GetConnectionData(
+ VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
+}
+
+TEST_F(ConnectionNamespaceHandlerTest, Close) {
+ connection_namespace_handler_.OnMessage(
+ &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_));
+ EXPECT_TRUE(vc_manager_.GetConnectionData(
+ VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
+
+ connection_namespace_handler_.OnMessage(
+ &router_, socket_, MakeCloseMessage(sender_id_, receiver_id_));
+ EXPECT_FALSE(vc_manager_.GetConnectionData(
+ VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
+}
+
+TEST_F(ConnectionNamespaceHandlerTest, CloseUnknown) {
+ connection_namespace_handler_.OnMessage(
+ &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_));
+ EXPECT_TRUE(vc_manager_.GetConnectionData(
+ VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
+
+ connection_namespace_handler_.OnMessage(
+ &router_, socket_, MakeCloseMessage(sender_id_ + "098", receiver_id_));
+ EXPECT_TRUE(vc_manager_.GetConnectionData(
+ VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer.cc b/chromium/third_party/openscreen/src/cast/common/channel/message_framer.cc
index 2494a6ea67c..a8406539d3b 100644
--- a/chromium/third_party/openscreen/src/cast/common/channel/message_framer.cc
+++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer.cc
@@ -13,12 +13,10 @@
#include "util/big_endian.h"
#include "util/logging.h"
+namespace openscreen {
namespace cast {
-namespace channel {
namespace message_serialization {
-using openscreen::Error;
-
namespace {
static constexpr size_t kHeaderSize = sizeof(uint32_t);
@@ -28,26 +26,26 @@ static constexpr size_t kMaxBodySize = 65536;
} // namespace
-ErrorOr<std::vector<uint8_t>> Serialize(const CastMessage& message) {
+ErrorOr<std::vector<uint8_t>> Serialize(
+ const ::cast::channel::CastMessage& message) {
const size_t message_size = message.ByteSizeLong();
if (message_size > kMaxBodySize || message_size == 0) {
return Error::Code::kCastV2InvalidMessage;
}
std::vector<uint8_t> out(message_size + kHeaderSize, 0);
- openscreen::WriteBigEndian<uint32_t>(message_size, out.data());
+ WriteBigEndian<uint32_t>(message_size, out.data());
if (!message.SerializeToArray(&out[kHeaderSize], message_size)) {
return Error::Code::kCastV2InvalidMessage;
}
return out;
}
-ErrorOr<DeserializeResult> TryDeserialize(absl::Span<uint8_t> input) {
+ErrorOr<DeserializeResult> TryDeserialize(absl::Span<const uint8_t> input) {
if (input.size() < kHeaderSize) {
return Error::Code::kInsufficientBuffer;
}
- const uint32_t message_size =
- openscreen::ReadBigEndian<uint32_t>(input.data());
+ const uint32_t message_size = ReadBigEndian<uint32_t>(input.data());
if (message_size > kMaxBodySize) {
return Error::Code::kCastV2InvalidMessage;
}
@@ -67,5 +65,5 @@ ErrorOr<DeserializeResult> TryDeserialize(absl::Span<uint8_t> input) {
}
} // namespace message_serialization
-} // namespace channel
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer.h b/chromium/third_party/openscreen/src/cast/common/channel/message_framer.h
index c092487cb1d..3de1cb7594d 100644
--- a/chromium/third_party/openscreen/src/cast/common/channel/message_framer.h
+++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer.h
@@ -15,18 +15,17 @@
#include "cast/common/channel/proto/cast_channel.pb.h"
#include "platform/base/error.h"
+namespace openscreen {
namespace cast {
-namespace channel {
namespace message_serialization {
-using openscreen::ErrorOr;
-
// Serializes |message_proto| into |message_data|.
// Returns true if the message was serialized successfully, false otherwise.
-ErrorOr<std::vector<uint8_t>> Serialize(const CastMessage& message);
+ErrorOr<std::vector<uint8_t>> Serialize(
+ const ::cast::channel::CastMessage& message);
struct DeserializeResult {
- CastMessage message;
+ ::cast::channel::CastMessage message;
size_t length;
};
@@ -34,10 +33,10 @@ struct DeserializeResult {
// read. Returns a parsed CastMessage if a message was received in its
// entirety, and an error otherwise. The result also contains the number of
// bytes consumed from |input| when a parse succeeds.
-ErrorOr<DeserializeResult> TryDeserialize(absl::Span<uint8_t> input);
+ErrorOr<DeserializeResult> TryDeserialize(absl::Span<const uint8_t> input);
} // namespace message_serialization
-} // namespace channel
} // namespace cast
+} // namespace openscreen
#endif // CAST_COMMON_CHANNEL_MESSAGE_FRAMER_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer.cc b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer.cc
new file mode 100644
index 00000000000..53678e748a0
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer.cc
@@ -0,0 +1,14 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <cstddef>
+#include <cstdint>
+
+#include "cast/common/channel/message_framer.h"
+
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
+ openscreen::cast::message_serialization::TryDeserialize(
+ absl::Span<const uint8_t>(data, size));
+ return 0;
+}
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/03d4b4028b559489768e2cccd6015c907f70a2c0 b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/03d4b4028b559489768e2cccd6015c907f70a2c0
new file mode 100644
index 00000000000..41fd475902d
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/03d4b4028b559489768e2cccd6015c907f70a2c0
Binary files differ
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/333be5dfffb2c6eeadf31be2dc219ef841c99ea0 b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/333be5dfffb2c6eeadf31be2dc219ef841c99ea0
new file mode 100644
index 00000000000..ab09cd27a30
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/333be5dfffb2c6eeadf31be2dc219ef841c99ea0
Binary files differ
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/b03aaebaa88ca4f4b8d63c7a63fc55ba402cfbb4 b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/b03aaebaa88ca4f4b8d63c7a63fc55ba402cfbb4
new file mode 100644
index 00000000000..fc53faf34ba
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/b03aaebaa88ca4f4b8d63c7a63fc55ba402cfbb4
Binary files differ
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_len1 b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_len1
new file mode 100644
index 00000000000..05f1e12bd3e
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_len1
Binary files differ
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_len2 b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_len2
new file mode 100644
index 00000000000..6f745b25c75
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_len2
Binary files differ
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_proto b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_proto
new file mode 100644
index 00000000000..7dd4315ed6a
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/bad_proto
Binary files differ
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/cf93596ce5bbb0d4c91f3ee493e01f0674d36c0c b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/cf93596ce5bbb0d4c91f3ee493e01f0674d36c0c
new file mode 100644
index 00000000000..445ded6c8b4
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/cf93596ce5bbb0d4c91f3ee493e01f0674d36c0c
Binary files differ
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/e14c401475d86e0f279691c168c7122ceb77c2c6 b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/e14c401475d86e0f279691c168c7122ceb77c2c6
new file mode 100644
index 00000000000..4e15a3d7d4f
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/e14c401475d86e0f279691c168c7122ceb77c2c6
Binary files differ
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/e9b451d1575019d52e0e072ce5b22a2418d237c7 b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/e9b451d1575019d52e0e072ce5b22a2418d237c7
new file mode 100644
index 00000000000..5dc9591749f
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_fuzzer_seeds/e9b451d1575019d52e0e072ce5b22a2418d237c7
Binary files differ
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_unittest.cc b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_unittest.cc
index ae2ff33e698..e70459e67ec 100644
--- a/chromium/third_party/openscreen/src/cast/common/channel/message_framer_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/common/channel/message_framer_unittest.cc
@@ -14,11 +14,11 @@
#include "util/big_endian.h"
#include "util/std_util.h"
+namespace openscreen {
namespace cast {
-namespace channel {
namespace message_serialization {
-using openscreen::Error;
+using ::cast::channel::CastMessage;
namespace {
@@ -149,5 +149,5 @@ TEST_F(CastFramerTest, TestUnparsableBodyProto) {
}
} // namespace message_serialization
-} // namespace channel
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_util.cc b/chromium/third_party/openscreen/src/cast/common/channel/message_util.cc
new file mode 100644
index 00000000000..ee9a91f1309
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/message_util.cc
@@ -0,0 +1,71 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/common/channel/message_util.h"
+
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+namespace {
+
+using ::cast::channel::CastMessage;
+
+CastMessage MakeConnectionMessage(const std::string& source_id,
+ const std::string& destination_id) {
+ CastMessage connect_message;
+ connect_message.set_protocol_version(kDefaultOutgoingMessageVersion);
+ connect_message.set_source_id(source_id);
+ connect_message.set_destination_id(destination_id);
+ connect_message.set_namespace_(kConnectionNamespace);
+ return connect_message;
+}
+
+} // namespace
+
+std::string ToString(AppAvailabilityResult availability) {
+ switch (availability) {
+ case AppAvailabilityResult::kAvailable:
+ return "Available";
+ case AppAvailabilityResult::kUnavailable:
+ return "Unavailable";
+ case AppAvailabilityResult::kUnknown:
+ return "Unknown";
+ default:
+ OSP_NOTREACHED();
+ return "bad value";
+ }
+}
+
+CastMessage MakeSimpleUTF8Message(const std::string& namespace_,
+ std::string payload) {
+ CastMessage message;
+ message.set_protocol_version(kDefaultOutgoingMessageVersion);
+ message.set_namespace_(namespace_);
+ message.set_payload_type(::cast::channel::CastMessage_PayloadType_STRING);
+ message.set_payload_utf8(std::move(payload));
+ return message;
+}
+
+CastMessage MakeConnectMessage(const std::string& source_id,
+ const std::string& destination_id) {
+ CastMessage connect_message =
+ MakeConnectionMessage(source_id, destination_id);
+ connect_message.set_payload_type(
+ ::cast::channel::CastMessage_PayloadType_STRING);
+ connect_message.set_payload_utf8(R"!({"type": "CONNECT"})!");
+ return connect_message;
+}
+
+CastMessage MakeCloseMessage(const std::string& source_id,
+ const std::string& destination_id) {
+ CastMessage close_message = MakeConnectionMessage(source_id, destination_id);
+ close_message.set_payload_type(
+ ::cast::channel::CastMessage_PayloadType_STRING);
+ close_message.set_payload_utf8(R"!({"type": "CLOSE"})!");
+ return close_message;
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/message_util.h b/chromium/third_party/openscreen/src/cast/common/channel/message_util.h
index 990fd55c5f8..f1ba2ead412 100644
--- a/chromium/third_party/openscreen/src/cast/common/channel/message_util.h
+++ b/chromium/third_party/openscreen/src/cast/common/channel/message_util.h
@@ -8,8 +8,8 @@
#include "absl/strings/string_view.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
+namespace openscreen {
namespace cast {
-namespace channel {
// Reserved message namespaces for internal messages.
static constexpr char kCastInternalNamespacePrefix[] =
@@ -32,7 +32,153 @@ static constexpr char kMediaNamespace[] = "urn:x-cast:com.google.cast.media";
static constexpr char kPlatformSenderId[] = "sender-0";
static constexpr char kPlatformReceiverId[] = "receiver-0";
-inline bool IsAuthMessage(const CastMessage& message) {
+static constexpr char kBroadcastId[] = "*";
+
+static constexpr ::cast::channel::CastMessage_ProtocolVersion
+ kDefaultOutgoingMessageVersion =
+ ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0;
+
+// JSON message key strings.
+static constexpr char kMessageKeyType[] = "type";
+static constexpr char kMessageKeyConnType[] = "connType";
+static constexpr char kMessageKeyUserAgent[] = "userAgent";
+static constexpr char kMessageKeySenderInfo[] = "senderInfo";
+static constexpr char kMessageKeyProtocolVersion[] = "protocolVersion";
+static constexpr char kMessageKeyProtocolVersionList[] = "protocolVersionList";
+static constexpr char kMessageKeyReasonCode[] = "reasonCode";
+static constexpr char kMessageKeyAppId[] = "appId";
+static constexpr char kMessageKeyRequestId[] = "requestId";
+static constexpr char kMessageKeyAvailability[] = "availability";
+
+// JSON message field values.
+static constexpr char kMessageTypeConnect[] = "CONNECT";
+static constexpr char kMessageTypeClose[] = "CLOSE";
+static constexpr char kMessageTypeConnected[] = "CONNECTED";
+static constexpr char kMessageValueAppAvailable[] = "APP_AVAILABLE";
+static constexpr char kMessageValueAppUnavailable[] = "APP_UNAVAILABLE";
+
+// TODO(crbug.com/openscreen/111): Add validation that each message type is
+// received on the correct namespace. This will probably involve creating a
+// data structure for mapping between type and namespace.
+enum class CastMessageType {
+ // Heartbeat messages.
+ kPing,
+ kPong,
+
+ // RPC control/status messages used by Media Remoting. These occur at high
+ // frequency, up to dozens per second at times, and should not be logged.
+ kRpc,
+
+ kGetAppAvailability,
+ kGetStatus,
+
+ // Virtual connection request.
+ kConnect,
+
+ // Close virtual connection.
+ kCloseConnection,
+
+ // Application broadcast / precache.
+ kBroadcast,
+
+ // Session launch request.
+ kLaunch,
+
+ // Session stop request.
+ kStop,
+
+ kReceiverStatus,
+ kMediaStatus,
+
+ // Error from receiver.
+ kLaunchError,
+
+ kOffer,
+ kAnswer,
+ kCapabilitiesResponse,
+ kStatusResponse,
+
+ // The following values are part of the protocol but are not currently used.
+ kMultizoneStatus,
+ kInvalidPlayerState,
+ kLoadFailed,
+ kLoadCancelled,
+ kInvalidRequest,
+ kPresentation,
+ kGetCapabilities,
+
+ kOther, // Add new types above |kOther|.
+ kMaxValue = kOther,
+};
+
+enum class AppAvailabilityResult {
+ kAvailable,
+ kUnavailable,
+ kUnknown,
+};
+
+std::string ToString(AppAvailabilityResult availability);
+
+// TODO(crbug.com/openscreen/111): When this and/or other enums need the
+// string->enum mapping, import EnumTable from Chromium's
+// //components/cast_channel/enum_table.h.
+inline constexpr const char* CastMessageTypeToString(CastMessageType type) {
+ switch (type) {
+ case CastMessageType::kPing:
+ return "PING";
+ case CastMessageType::kPong:
+ return "PONG";
+ case CastMessageType::kRpc:
+ return "RPC";
+ case CastMessageType::kGetAppAvailability:
+ return "GET_APP_AVAILABILITY";
+ case CastMessageType::kGetStatus:
+ return "GET_STATUS";
+ case CastMessageType::kConnect:
+ return "CONNECT";
+ case CastMessageType::kCloseConnection:
+ return "CLOSE";
+ case CastMessageType::kBroadcast:
+ return "APPLICATION_BROADCAST";
+ case CastMessageType::kLaunch:
+ return "LAUNCH";
+ case CastMessageType::kStop:
+ return "STOP";
+ case CastMessageType::kReceiverStatus:
+ return "RECEIVER_STATUS";
+ case CastMessageType::kMediaStatus:
+ return "MEDIA_STATUS";
+ case CastMessageType::kLaunchError:
+ return "LAUNCH_ERROR";
+ case CastMessageType::kOffer:
+ return "OFFER";
+ case CastMessageType::kAnswer:
+ return "ANSWER";
+ case CastMessageType::kCapabilitiesResponse:
+ return "CAPABILITIES_RESPONSE";
+ case CastMessageType::kStatusResponse:
+ return "STATUS_RESPONSE";
+ case CastMessageType::kMultizoneStatus:
+ return "MULTIZONE_STATUS";
+ case CastMessageType::kInvalidPlayerState:
+ return "INVALID_PLAYER_STATE";
+ case CastMessageType::kLoadFailed:
+ return "LOAD_FAILED";
+ case CastMessageType::kLoadCancelled:
+ return "LOAD_CANCELLED";
+ case CastMessageType::kInvalidRequest:
+ return "INVALID_REQUEST";
+ case CastMessageType::kPresentation:
+ return "PRESENTATION";
+ case CastMessageType::kGetCapabilities:
+ return "GET_CAPABILITIES";
+ case CastMessageType::kOther:
+ default:
+ return "OTHER";
+ }
+}
+
+inline bool IsAuthMessage(const ::cast::channel::CastMessage& message) {
return message.namespace_() == kAuthNamespace;
}
@@ -41,7 +187,18 @@ inline bool IsTransportNamespace(absl::string_view namespace_) {
(namespace_.find_first_of(kTransportNamespacePrefix) == 0);
}
-} // namespace channel
+::cast::channel::CastMessage MakeSimpleUTF8Message(
+ const std::string& namespace_,
+ std::string payload);
+
+::cast::channel::CastMessage MakeConnectMessage(
+ const std::string& source_id,
+ const std::string& destination_id);
+::cast::channel::CastMessage MakeCloseMessage(
+ const std::string& source_id,
+ const std::string& destination_id);
+
} // namespace cast
+} // namespace openscreen
#endif // CAST_COMMON_CHANNEL_MESSAGE_UTIL_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/namespace_router.cc b/chromium/third_party/openscreen/src/cast/common/channel/namespace_router.cc
new file mode 100644
index 00000000000..041d8f3f0df
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/namespace_router.cc
@@ -0,0 +1,35 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/common/channel/namespace_router.h"
+
+#include "cast/common/channel/proto/cast_channel.pb.h"
+
+namespace openscreen {
+namespace cast {
+
+NamespaceRouter::NamespaceRouter() = default;
+NamespaceRouter::~NamespaceRouter() = default;
+
+void NamespaceRouter::AddNamespaceHandler(std::string namespace_,
+ CastMessageHandler* handler) {
+ handlers_.emplace(std::move(namespace_), handler);
+}
+
+void NamespaceRouter::RemoveNamespaceHandler(const std::string& namespace_) {
+ handlers_.erase(namespace_);
+}
+
+void NamespaceRouter::OnMessage(VirtualConnectionRouter* router,
+ CastSocket* socket,
+ ::cast::channel::CastMessage message) {
+ const std::string& ns = message.namespace_();
+ auto it = handlers_.find(ns);
+ if (it != handlers_.end()) {
+ it->second->OnMessage(router, socket, std::move(message));
+ }
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/namespace_router.h b/chromium/third_party/openscreen/src/cast/common/channel/namespace_router.h
new file mode 100644
index 00000000000..0b6b581c0e1
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/namespace_router.h
@@ -0,0 +1,37 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_COMMON_CHANNEL_NAMESPACE_ROUTER_H_
+#define CAST_COMMON_CHANNEL_NAMESPACE_ROUTER_H_
+
+#include <map>
+#include <string>
+
+#include "cast/common/channel/cast_message_handler.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
+
+namespace openscreen {
+namespace cast {
+
+class NamespaceRouter final : public CastMessageHandler {
+ public:
+ NamespaceRouter();
+ ~NamespaceRouter() override;
+
+ void AddNamespaceHandler(std::string namespace_, CastMessageHandler* handler);
+ void RemoveNamespaceHandler(const std::string& namespace_);
+
+ // CastMessageHandler overrides.
+ void OnMessage(VirtualConnectionRouter* router,
+ CastSocket* socket,
+ ::cast::channel::CastMessage message) override;
+
+ private:
+ std::map<std::string /* namespace */, CastMessageHandler*> handlers_;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_COMMON_CHANNEL_NAMESPACE_ROUTER_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/namespace_router_unittest.cc b/chromium/third_party/openscreen/src/cast/common/channel/namespace_router_unittest.cc
new file mode 100644
index 00000000000..96907c081c0
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/namespace_router_unittest.cc
@@ -0,0 +1,98 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/common/channel/namespace_router.h"
+
+#include "cast/common/channel/cast_message_handler.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
+#include "cast/common/channel/testing/fake_cast_socket.h"
+#include "cast/common/channel/testing/mock_cast_message_handler.h"
+#include "cast/common/channel/virtual_connection_manager.h"
+#include "cast/common/channel/virtual_connection_router.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace openscreen {
+namespace cast {
+namespace {
+
+using ::cast::channel::CastMessage;
+using ::testing::_;
+using ::testing::Invoke;
+
+class NamespaceRouterTest : public ::testing::Test {
+ public:
+ protected:
+ CastSocket* socket() { return &fake_socket_.socket; }
+
+ FakeCastSocket fake_socket_;
+ VirtualConnectionManager vc_manager_;
+ VirtualConnectionRouter vc_router_{&vc_manager_};
+ NamespaceRouter router_;
+};
+
+} // namespace
+
+TEST_F(NamespaceRouterTest, NoHandlersNoop) {
+ CastMessage message;
+ message.set_namespace_("anzrfcnpr");
+ router_.OnMessage(&vc_router_, socket(), std::move(message));
+}
+
+TEST_F(NamespaceRouterTest, MultipleHandlers) {
+ MockCastMessageHandler media_handler;
+ MockCastMessageHandler auth_handler;
+ MockCastMessageHandler connection_handler;
+
+ router_.AddNamespaceHandler("media", &media_handler);
+ router_.AddNamespaceHandler("auth", &auth_handler);
+ router_.AddNamespaceHandler("connection", &connection_handler);
+
+ EXPECT_CALL(media_handler, OnMessage(_, _, _)).Times(0);
+ EXPECT_CALL(auth_handler, OnMessage(_, _, _))
+ .WillOnce(Invoke([](VirtualConnectionRouter* router, CastSocket*,
+ CastMessage message) {
+ EXPECT_EQ(message.namespace_(), "auth");
+ }));
+ EXPECT_CALL(connection_handler, OnMessage(_, _, _))
+ .WillOnce(Invoke([](VirtualConnectionRouter* router, CastSocket*,
+ CastMessage message) {
+ EXPECT_EQ(message.namespace_(), "connection");
+ }));
+
+ CastMessage auth_message;
+ auth_message.set_namespace_("auth");
+ router_.OnMessage(&vc_router_, socket(), std::move(auth_message));
+
+ CastMessage connection_message;
+ connection_message.set_namespace_("connection");
+ router_.OnMessage(&vc_router_, socket(), std::move(connection_message));
+}
+
+TEST_F(NamespaceRouterTest, RemoveHandler) {
+ MockCastMessageHandler handler1;
+ MockCastMessageHandler handler2;
+
+ router_.AddNamespaceHandler("one", &handler1);
+ router_.AddNamespaceHandler("two", &handler2);
+
+ router_.RemoveNamespaceHandler("one");
+
+ EXPECT_CALL(handler1, OnMessage(_, _, _)).Times(0);
+ EXPECT_CALL(handler2, OnMessage(_, _, _))
+ .WillOnce(Invoke(
+ [](VirtualConnectionRouter* router, CastSocket* socket,
+ CastMessage message) { EXPECT_EQ("two", message.namespace_()); }));
+
+ CastMessage message1;
+ message1.set_namespace_("one");
+ router_.OnMessage(&vc_router_, socket(), std::move(message1));
+
+ CastMessage message2;
+ message2.set_namespace_("two");
+ router_.OnMessage(&vc_router_, socket(), std::move(message2));
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/proto/authority_keys.proto b/chromium/third_party/openscreen/src/cast/common/channel/proto/authority_keys.proto
index 5689e364a1f..1b99b8ee604 100644
--- a/chromium/third_party/openscreen/src/cast/common/channel/proto/authority_keys.proto
+++ b/chromium/third_party/openscreen/src/cast/common/channel/proto/authority_keys.proto
@@ -6,7 +6,11 @@ syntax = "proto2";
option optimize_for = LITE_RUNTIME;
-package cast_channel.proto;
+// TODO(crbug.com/openscreen/90): Rename to openscreen.cast, to update to the
+// current namespacing of the library. Also, this file should probably be moved
+// to the public directory. And, all of this will have to be coordinated with a
+// DEPS roll in Chromium (since Chromium code depends on this).
+package cast.channel;
message AuthorityKeys {
message Key {
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/proto/cast_channel.proto b/chromium/third_party/openscreen/src/cast/common/channel/proto/cast_channel.proto
index 57c7b3f3fb7..58fb6c4dc6a 100644
--- a/chromium/third_party/openscreen/src/cast/common/channel/proto/cast_channel.proto
+++ b/chromium/third_party/openscreen/src/cast/common/channel/proto/cast_channel.proto
@@ -6,12 +6,21 @@ syntax = "proto2";
option optimize_for = LITE_RUNTIME;
+// TODO(crbug.com/openscreen/90): Rename to openscreen.cast, to update to the
+// current namespacing of the library. Also, this file should probably be moved
+// to the public directory. And, all of this will have to be coordinated with a
+// DEPS roll in Chromium (since Chromium code depends on this).
package cast.channel;
message CastMessage {
// Always pass a version of the protocol for future compatibility
// requirements.
- enum ProtocolVersion { CASTV2_1_0 = 0; }
+ enum ProtocolVersion {
+ CASTV2_1_0 = 0;
+ CASTV2_1_1 = 1; // message chunking support (deprecated).
+ CASTV2_1_2 = 2; // reworked message chunking.
+ CASTV2_1_3 = 3; // binary payload over utf8.
+ }
required ProtocolVersion protocol_version = 1;
// source and destination ids identify the origin and destination of the
@@ -49,6 +58,20 @@ message CastMessage {
// will always be set.
optional string payload_utf8 = 6;
optional bytes payload_binary = 7;
+
+ // --- Begin new 1.1 fields.
+
+ // Flag indicating whether there are more chunks to follow for this message.
+ // If the flag is false or is not present, then this is the last (or only)
+ // chunk of the message.
+ optional bool continued = 8;
+
+ // If this is a chunk of a larger message, and the remaining length of the
+ // message payload (the sum of the lengths of the payloads of the remaining
+ // chunks) is known, this field will indicate that length. For a given
+ // chunked message, this field should either be present in all of the chunks,
+ // or in none of them.
+ optional uint32 remaining_length = 9;
}
enum SignatureAlgorithm {
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/testing/fake_cast_socket.h b/chromium/third_party/openscreen/src/cast/common/channel/testing/fake_cast_socket.h
new file mode 100644
index 00000000000..5373620bf03
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/testing/fake_cast_socket.h
@@ -0,0 +1,108 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_COMMON_CHANNEL_TESTING_FAKE_CAST_SOCKET_H_
+#define CAST_COMMON_CHANNEL_TESTING_FAKE_CAST_SOCKET_H_
+
+#include <memory>
+
+#include "cast/common/channel/cast_socket.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
+#include "gmock/gmock.h"
+#include "platform/test/mock_tls_connection.h"
+
+namespace openscreen {
+namespace cast {
+
+class MockCastSocketClient final : public CastSocket::Client {
+ public:
+ ~MockCastSocketClient() override = default;
+
+ MOCK_METHOD(void, OnError, (CastSocket * socket, Error error), (override));
+ MOCK_METHOD(void,
+ OnMessage,
+ (CastSocket * socket, ::cast::channel::CastMessage message),
+ (override));
+};
+
+struct FakeCastSocket {
+ FakeCastSocket()
+ : FakeCastSocket({{10, 0, 1, 7}, 1234}, {{10, 0, 1, 9}, 4321}) {}
+ FakeCastSocket(const IPEndpoint& local_endpoint,
+ const IPEndpoint& remote_endpoint)
+ : local_endpoint(local_endpoint),
+ remote_endpoint(remote_endpoint),
+ moved_connection(std::make_unique<MockTlsConnection>(local_endpoint,
+ remote_endpoint)),
+ connection(moved_connection.get()),
+ socket(std::move(moved_connection), &mock_client) {}
+
+ IPEndpoint local_endpoint;
+ IPEndpoint remote_endpoint;
+ std::unique_ptr<MockTlsConnection> moved_connection;
+ MockTlsConnection* connection;
+ MockCastSocketClient mock_client;
+ CastSocket socket;
+};
+
+// Two FakeCastSockets that are piped together via their MockTlsConnection
+// read/write methods. Calling SendMessage on |socket| will result in an
+// OnMessage callback on |mock_peer_client| and vice versa for |peer_socket| and
+// |mock_client|.
+struct FakeCastSocketPair {
+ FakeCastSocketPair()
+ : FakeCastSocketPair({{10, 0, 1, 7}, 1234}, {{10, 0, 1, 9}, 4321}) {}
+
+ FakeCastSocketPair(const IPEndpoint& local_endpoint,
+ const IPEndpoint& remote_endpoint)
+ : local_endpoint(local_endpoint), remote_endpoint(remote_endpoint) {
+ using ::testing::_;
+ using ::testing::Invoke;
+
+ auto moved_connection =
+ std::make_unique<::testing::NiceMock<MockTlsConnection>>(
+ local_endpoint, remote_endpoint);
+ connection = moved_connection.get();
+ socket =
+ std::make_unique<CastSocket>(std::move(moved_connection), &mock_client);
+
+ auto moved_peer = std::make_unique<::testing::NiceMock<MockTlsConnection>>(
+ remote_endpoint, local_endpoint);
+ peer_connection = moved_peer.get();
+ peer_socket =
+ std::make_unique<CastSocket>(std::move(moved_peer), &mock_peer_client);
+
+ ON_CALL(*connection, Send(_, _))
+ .WillByDefault(Invoke([this](const void* data, size_t len) {
+ peer_connection->OnRead(std::vector<uint8_t>(
+ reinterpret_cast<const uint8_t*>(data),
+ reinterpret_cast<const uint8_t*>(data) + len));
+ return true;
+ }));
+ ON_CALL(*peer_connection, Send(_, _))
+ .WillByDefault(Invoke([this](const void* data, size_t len) {
+ connection->OnRead(std::vector<uint8_t>(
+ reinterpret_cast<const uint8_t*>(data),
+ reinterpret_cast<const uint8_t*>(data) + len));
+ return true;
+ }));
+ }
+ ~FakeCastSocketPair() = default;
+
+ IPEndpoint local_endpoint;
+ IPEndpoint remote_endpoint;
+
+ ::testing::NiceMock<MockTlsConnection>* connection;
+ MockCastSocketClient mock_client;
+ std::unique_ptr<CastSocket> socket;
+
+ ::testing::NiceMock<MockTlsConnection>* peer_connection;
+ MockCastSocketClient mock_peer_client;
+ std::unique_ptr<CastSocket> peer_socket;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_COMMON_CHANNEL_TESTING_FAKE_CAST_SOCKET_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/testing/mock_cast_message_handler.h b/chromium/third_party/openscreen/src/cast/common/channel/testing/mock_cast_message_handler.h
new file mode 100644
index 00000000000..48abefdc90f
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/testing/mock_cast_message_handler.h
@@ -0,0 +1,28 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_COMMON_CHANNEL_TESTING_MOCK_CAST_MESSAGE_HANDLER_H_
+#define CAST_COMMON_CHANNEL_TESTING_MOCK_CAST_MESSAGE_HANDLER_H_
+
+#include "cast/common/channel/cast_message_handler.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
+#include "gmock/gmock.h"
+
+namespace openscreen {
+namespace cast {
+
+class MockCastMessageHandler final : public CastMessageHandler {
+ public:
+ MOCK_METHOD(void,
+ OnMessage,
+ (VirtualConnectionRouter * router,
+ CastSocket* socket,
+ ::cast::channel::CastMessage message),
+ (override));
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_COMMON_CHANNEL_TESTING_MOCK_CAST_MESSAGE_HANDLER_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/testing/mock_socket_error_handler.h b/chromium/third_party/openscreen/src/cast/common/channel/testing/mock_socket_error_handler.h
new file mode 100644
index 00000000000..60544b75192
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/channel/testing/mock_socket_error_handler.h
@@ -0,0 +1,25 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_COMMON_CHANNEL_TESTING_MOCK_SOCKET_ERROR_HANDLER_H_
+#define CAST_COMMON_CHANNEL_TESTING_MOCK_SOCKET_ERROR_HANDLER_H_
+
+#include "cast/common/channel/virtual_connection_router.h"
+#include "gmock/gmock.h"
+#include "platform/base/error.h"
+
+namespace openscreen {
+namespace cast {
+
+class MockSocketErrorHandler
+ : public VirtualConnectionRouter::SocketErrorHandler {
+ public:
+ MOCK_METHOD(void, OnClose, (CastSocket * socket), (override));
+ MOCK_METHOD(void, OnError, (CastSocket * socket, Error error), (override));
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_COMMON_CHANNEL_TESTING_MOCK_SOCKET_ERROR_HANDLER_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection.h b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection.h
index 0117bf51598..04f3ba06ab4 100644
--- a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection.h
+++ b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection.h
@@ -9,8 +9,8 @@
#include <cstdint>
#include <string>
+namespace openscreen {
namespace cast {
-namespace channel {
// Transport system on top of CastSocket that allows routing messages over a
// single socket to different virtual endpoints (e.g. system messages vs.
@@ -35,16 +35,23 @@ struct VirtualConnection {
// - Receiver app can only send broadcast messages over an invisible
// connection.
kInvisible,
+
+ kMinValue = kStrong,
+ kMaxValue = kInvisible,
};
// Cast V2 protocol version constants. Must be in sync with
// proto/cast_channel.proto.
enum class ProtocolVersion {
kV2_1_0,
+ kV2_1_1,
+ kV2_1_2,
+ kV2_1_3,
};
enum CloseReason {
kUnknown,
+ kFirstReason = kUnknown,
// Underlying socket has been closed by peer. This happens when Cast sender
// closed transport connection normally without graceful virtual connection
@@ -68,6 +75,7 @@ struct VirtualConnection {
// The virtual connection has been closed by the peer gracefully.
kClosedByPeer,
+ kLastReason = kClosedByPeer,
};
struct AssociatedData {
@@ -91,7 +99,7 @@ struct VirtualConnection {
// app on the device.
std::string local_id;
std::string peer_id;
- uint32_t socket_id;
+ int socket_id;
};
inline bool operator==(const VirtualConnection& a, const VirtualConnection& b) {
@@ -103,7 +111,7 @@ inline bool operator!=(const VirtualConnection& a, const VirtualConnection& b) {
return !(a == b);
}
-} // namespace channel
} // namespace cast
+} // namespace openscreen
#endif // CAST_COMMON_CHANNEL_VIRTUAL_CONNECTION_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.cc b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.cc
index 8bac25e203a..86e06706a55 100644
--- a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.cc
+++ b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.cc
@@ -6,8 +6,8 @@
#include <type_traits>
+namespace openscreen {
namespace cast {
-namespace channel {
VirtualConnectionManager::VirtualConnectionManager() = default;
@@ -15,7 +15,7 @@ VirtualConnectionManager::~VirtualConnectionManager() = default;
void VirtualConnectionManager::AddConnection(
VirtualConnection virtual_connection,
- VirtualConnection::AssociatedData&& associated_data) {
+ VirtualConnection::AssociatedData associated_data) {
auto& socket_map = connections_[virtual_connection.socket_id];
auto local_entries = socket_map.equal_range(virtual_connection.local_id);
auto it = std::find_if(
@@ -81,7 +81,7 @@ size_t VirtualConnectionManager::RemoveConnectionsByLocalId(
}
size_t VirtualConnectionManager::RemoveConnectionsBySocketId(
- uint32_t socket_id,
+ int socket_id,
VirtualConnection::CloseReason reason) {
auto entry = connections_.find(socket_id);
if (entry == connections_.end()) {
@@ -115,5 +115,5 @@ VirtualConnectionManager::GetConnectionData(
return absl::nullopt;
}
-} // namespace channel
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.h b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.h
index 5bb5fdb1a39..902d2a969b0 100644
--- a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.h
+++ b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager.h
@@ -12,8 +12,8 @@
#include "absl/types/optional.h"
#include "cast/common/channel/virtual_connection.h"
+namespace openscreen {
namespace cast {
-namespace channel {
// Maintains a collection of open VirtualConnections and associated data.
class VirtualConnectionManager {
@@ -22,7 +22,7 @@ class VirtualConnectionManager {
~VirtualConnectionManager();
void AddConnection(VirtualConnection virtual_connection,
- VirtualConnection::AssociatedData&& associated_data);
+ VirtualConnection::AssociatedData associated_data);
// Returns true if a connection matching |virtual_connection| was found and
// removed.
@@ -32,7 +32,7 @@ class VirtualConnectionManager {
// Returns the number of connections removed.
size_t RemoveConnectionsByLocalId(const std::string& local_id,
VirtualConnection::CloseReason reason);
- size_t RemoveConnectionsBySocketId(uint32_t socket_id,
+ size_t RemoveConnectionsBySocketId(int socket_id,
VirtualConnection::CloseReason reason);
// Returns the AssociatedData for |virtual_connection| if a connection exists,
@@ -50,12 +50,12 @@ class VirtualConnectionManager {
VirtualConnection::AssociatedData data;
};
- std::map<uint32_t /* socket_id */,
+ std::map<int /* socket_id */,
std::multimap<std::string /* local_id */, VCTail>>
connections_;
};
-} // namespace channel
} // namespace cast
+} // namespace openscreen
#endif // CAST_COMMON_CHANNEL_VIRTUAL_CONNECTION_MANAGER_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager_unittest.cc b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager_unittest.cc
index 8edf105adac..963fcac728f 100644
--- a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_manager_unittest.cc
@@ -8,13 +8,22 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+namespace openscreen {
namespace cast {
-namespace channel {
namespace {
-static_assert(CastMessage_ProtocolVersion_CASTV2_1_0 ==
+static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0 ==
static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_0),
"V2 1.0 constants must be equal");
+static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_1 ==
+ static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_1),
+ "V2 1.1 constants must be equal");
+static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2 ==
+ static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_2),
+ "V2 1.2 constants must be equal");
+static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3 ==
+ static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_3),
+ "V2 1.3 constants must be equal");
using ::testing::_;
using ::testing::Invoke;
@@ -129,5 +138,5 @@ TEST_F(VirtualConnectionManagerTest, RemoveConnectionsByIds) {
EXPECT_FALSE(manager_.GetConnectionData(vc3_));
}
-} // namespace channel
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.cc b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.cc
index 9a5d12832f0..7ceace603d9 100644
--- a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.cc
+++ b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.cc
@@ -11,8 +11,10 @@
#include "cast/common/channel/virtual_connection_manager.h"
#include "util/logging.h"
+namespace openscreen {
namespace cast {
-namespace channel {
+
+using ::cast::channel::CastMessage;
VirtualConnectionRouter::VirtualConnectionRouter(
VirtualConnectionManager* vc_manager)
@@ -35,14 +37,16 @@ bool VirtualConnectionRouter::RemoveHandlerForLocalId(
void VirtualConnectionRouter::TakeSocket(SocketErrorHandler* error_handler,
std::unique_ptr<CastSocket> socket) {
- uint32_t id = socket->socket_id();
+ int id = socket->socket_id();
socket->SetClient(this);
sockets_.emplace(id, SocketWithHandler{std::move(socket), error_handler});
}
-void VirtualConnectionRouter::CloseSocket(uint32_t id) {
+void VirtualConnectionRouter::CloseSocket(int id) {
auto it = sockets_.find(id);
if (it != sockets_.end()) {
+ vc_manager_->RemoveConnectionsBySocketId(
+ id, VirtualConnection::kTransportClosed);
std::unique_ptr<CastSocket> socket = std::move(it->second.socket);
SocketErrorHandler* error_handler = it->second.error_handler;
sockets_.erase(it);
@@ -50,26 +54,27 @@ void VirtualConnectionRouter::CloseSocket(uint32_t id) {
}
}
-Error VirtualConnectionRouter::SendMessage(VirtualConnection vconn,
- CastMessage&& message) {
+Error VirtualConnectionRouter::SendMessage(VirtualConnection virtual_conn,
+ CastMessage message) {
// TODO(btolsch): Check for broadcast message.
if (!IsTransportNamespace(message.namespace_()) &&
- !vc_manager_->GetConnectionData(vconn)) {
- return Error::Code::kUnknownError;
+ !vc_manager_->GetConnectionData(virtual_conn)) {
+ return Error::Code::kNoActiveConnection;
}
- auto it = sockets_.find(vconn.socket_id);
+ auto it = sockets_.find(virtual_conn.socket_id);
if (it == sockets_.end()) {
- return Error::Code::kUnknownError;
+ return Error::Code::kItemNotFound;
}
- message.set_source_id(std::move(vconn.local_id));
- message.set_destination_id(std::move(vconn.peer_id));
+ message.set_source_id(std::move(virtual_conn.local_id));
+ message.set_destination_id(std::move(virtual_conn.peer_id));
return it->second.socket->SendMessage(message);
}
void VirtualConnectionRouter::OnError(CastSocket* socket, Error error) {
- uint32_t id = socket->socket_id();
+ int id = socket->socket_id();
auto it = sockets_.find(id);
if (it != sockets_.end()) {
+ vc_manager_->RemoveConnectionsBySocketId(id, VirtualConnection::kUnknown);
std::unique_ptr<CastSocket> socket_owned = std::move(it->second.socket);
SocketErrorHandler* error_handler = it->second.error_handler;
sockets_.erase(it);
@@ -80,10 +85,10 @@ void VirtualConnectionRouter::OnError(CastSocket* socket, Error error) {
void VirtualConnectionRouter::OnMessage(CastSocket* socket,
CastMessage message) {
// TODO(btolsch): Check for broadcast message.
- VirtualConnection vconn{message.destination_id(), message.source_id(),
- socket->socket_id()};
+ VirtualConnection virtual_conn{message.destination_id(), message.source_id(),
+ socket->socket_id()};
if (!IsTransportNamespace(message.namespace_()) &&
- !vc_manager_->GetConnectionData(vconn)) {
+ !vc_manager_->GetConnectionData(virtual_conn)) {
return;
}
const std::string& local_id = message.destination_id();
@@ -93,5 +98,5 @@ void VirtualConnectionRouter::OnMessage(CastSocket* socket,
}
}
-} // namespace channel
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.h b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.h
index 8ec643bd779..375c24ff50b 100644
--- a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.h
+++ b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router.h
@@ -11,9 +11,10 @@
#include <string>
#include "cast/common/channel/cast_socket.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
+namespace openscreen {
namespace cast {
-namespace channel {
class CastMessageHandler;
struct VirtualConnection;
@@ -56,13 +57,15 @@ class VirtualConnectionRouter final : public CastSocket::Client {
// |error_handler| must live until either its OnError or OnClose is called.
void TakeSocket(SocketErrorHandler* error_handler,
std::unique_ptr<CastSocket> socket);
- void CloseSocket(uint32_t id);
+ void CloseSocket(int id);
- Error SendMessage(VirtualConnection vconn, CastMessage&& message);
+ Error SendMessage(VirtualConnection virtual_conn,
+ ::cast::channel::CastMessage message);
// CastSocket::Client overrides.
void OnError(CastSocket* socket, Error error) override;
- void OnMessage(CastSocket* socket, CastMessage message) override;
+ void OnMessage(CastSocket* socket,
+ ::cast::channel::CastMessage message) override;
private:
struct SocketWithHandler {
@@ -71,11 +74,11 @@ class VirtualConnectionRouter final : public CastSocket::Client {
};
VirtualConnectionManager* const vc_manager_;
- std::map<uint32_t, SocketWithHandler> sockets_;
+ std::map<int, SocketWithHandler> sockets_;
std::map<std::string /* local_id */, CastMessageHandler*> endpoints_;
};
-} // namespace channel
} // namespace cast
+} // namespace openscreen
#endif // CAST_COMMON_CHANNEL_VIRTUAL_CONNECTION_ROUTER_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router_unittest.cc b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router_unittest.cc
index 12d82f9096d..57e9fa1b883 100644
--- a/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/common/channel/virtual_connection_router_unittest.cc
@@ -6,30 +6,24 @@
#include "cast/common/channel/cast_socket.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
-#include "cast/common/channel/test/fake_cast_socket.h"
-#include "cast/common/channel/test/mock_cast_message_handler.h"
+#include "cast/common/channel/testing/fake_cast_socket.h"
+#include "cast/common/channel/testing/mock_cast_message_handler.h"
+#include "cast/common/channel/testing/mock_socket_error_handler.h"
#include "cast/common/channel/virtual_connection_manager.h"
-#include "gmock/gmock.h"
#include "gtest/gtest.h"
+namespace openscreen {
namespace cast {
-namespace channel {
namespace {
+using ::cast::channel::CastMessage;
using ::testing::_;
using ::testing::Invoke;
-class MockSocketErrorHandler
- : public VirtualConnectionRouter::SocketErrorHandler {
- public:
- MOCK_METHOD(void, OnClose, (CastSocket * socket), (override));
- MOCK_METHOD(void, OnError, (CastSocket * socket, Error error), (override));
-};
-
class VirtualConnectionRouterTest : public ::testing::Test {
public:
void SetUp() override {
- socket_id_ = fake_cast_socket_pair_.socket->socket_id();
+ socket_ = fake_cast_socket_pair_.socket.get();
router_.TakeSocket(&mock_error_handler_,
std::move(fake_cast_socket_pair_.socket));
}
@@ -38,7 +32,7 @@ class VirtualConnectionRouterTest : public ::testing::Test {
CastSocket& peer_socket() { return *fake_cast_socket_pair_.peer_socket; }
FakeCastSocketPair fake_cast_socket_pair_;
- uint32_t socket_id_;
+ CastSocket* socket_;
MockSocketErrorHandler mock_error_handler_;
MockCastMessageHandler mock_message_handler_;
@@ -52,53 +46,59 @@ class VirtualConnectionRouterTest : public ::testing::Test {
TEST_F(VirtualConnectionRouterTest, LocalIdHandler) {
router_.AddHandlerForLocalId("receiver-1234", &mock_message_handler_);
manager_.AddConnection(
- VirtualConnection{"receiver-1234", "sender-9873", socket_id_}, {});
+ VirtualConnection{"receiver-1234", "sender-9873", socket_->socket_id()},
+ {});
CastMessage message;
- message.set_protocol_version(CastMessage_ProtocolVersion_CASTV2_1_0);
+ message.set_protocol_version(
+ ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
message.set_namespace_("zrqvn");
message.set_source_id("sender-9873");
message.set_destination_id("receiver-1234");
message.set_payload_type(CastMessage::STRING);
message.set_payload_utf8("cnlybnq");
- EXPECT_CALL(mock_message_handler_, OnMessage(_, _, _));
- peer_socket().SendMessage(message);
+ EXPECT_CALL(mock_message_handler_, OnMessage(_, socket_, _));
+ EXPECT_TRUE(peer_socket().SendMessage(message).ok());
- EXPECT_CALL(mock_message_handler_, OnMessage(_, _, _));
- peer_socket().SendMessage(message);
+ EXPECT_CALL(mock_message_handler_, OnMessage(_, socket_, _));
+ EXPECT_TRUE(peer_socket().SendMessage(message).ok());
message.set_destination_id("receiver-4321");
EXPECT_CALL(mock_message_handler_, OnMessage(_, _, _)).Times(0);
- peer_socket().SendMessage(message);
+ EXPECT_TRUE(peer_socket().SendMessage(message).ok());
}
TEST_F(VirtualConnectionRouterTest, RemoveLocalIdHandler) {
router_.AddHandlerForLocalId("receiver-1234", &mock_message_handler_);
manager_.AddConnection(
- VirtualConnection{"receiver-1234", "sender-9873", socket_id_}, {});
+ VirtualConnection{"receiver-1234", "sender-9873", socket_->socket_id()},
+ {});
CastMessage message;
- message.set_protocol_version(CastMessage_ProtocolVersion_CASTV2_1_0);
+ message.set_protocol_version(
+ ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
message.set_namespace_("zrqvn");
message.set_source_id("sender-9873");
message.set_destination_id("receiver-1234");
message.set_payload_type(CastMessage::STRING);
message.set_payload_utf8("cnlybnq");
- EXPECT_CALL(mock_message_handler_, OnMessage(_, _, _));
- peer_socket().SendMessage(message);
+ EXPECT_CALL(mock_message_handler_, OnMessage(_, socket_, _));
+ EXPECT_TRUE(peer_socket().SendMessage(message).ok());
router_.RemoveHandlerForLocalId("receiver-1234");
- EXPECT_CALL(mock_message_handler_, OnMessage(_, _, _)).Times(0);
- peer_socket().SendMessage(message);
+ EXPECT_CALL(mock_message_handler_, OnMessage(_, socket_, _)).Times(0);
+ EXPECT_TRUE(peer_socket().SendMessage(message).ok());
}
TEST_F(VirtualConnectionRouterTest, SendMessage) {
manager_.AddConnection(
- VirtualConnection{"receiver-1234", "sender-4321", socket_id_}, {});
+ VirtualConnection{"receiver-1234", "sender-4321", socket_->socket_id()},
+ {});
CastMessage message;
- message.set_protocol_version(CastMessage_ProtocolVersion_CASTV2_1_0);
+ message.set_protocol_version(
+ ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
message.set_namespace_("zrqvn");
message.set_source_id("receiver-1234");
message.set_destination_id("sender-4321");
@@ -109,13 +109,25 @@ TEST_F(VirtualConnectionRouterTest, SendMessage) {
EXPECT_EQ(message.namespace_(), "zrqvn");
EXPECT_EQ(message.source_id(), "receiver-1234");
EXPECT_EQ(message.destination_id(), "sender-4321");
- ASSERT_EQ(message.payload_type(), CastMessage_PayloadType_STRING);
+ ASSERT_EQ(message.payload_type(),
+ ::cast::channel::CastMessage_PayloadType_STRING);
EXPECT_EQ(message.payload_utf8(), "cnlybnq");
}));
router_.SendMessage(
- VirtualConnection{"receiver-1234", "sender-4321", socket_id_},
+ VirtualConnection{"receiver-1234", "sender-4321", socket_->socket_id()},
std::move(message));
}
-} // namespace channel
+TEST_F(VirtualConnectionRouterTest, CloseSocketRemovesVirtualConnections) {
+ manager_.AddConnection(
+ VirtualConnection{"receiver-1234", "sender-4321", socket_->socket_id()},
+ {});
+
+ int id = socket_->socket_id();
+ router_.CloseSocket(id);
+ EXPECT_FALSE(manager_.GetConnectionData(
+ VirtualConnection{"receiver-1234", "sender-4321", id}));
+}
+
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/DEPS b/chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/DEPS
new file mode 100644
index 00000000000..75ecfdce80a
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/DEPS
@@ -0,0 +1,3 @@
+include_rules = [
+ '+platform/impl',
+]
diff --git a/chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/tests.cc b/chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/tests.cc
new file mode 100644
index 00000000000..2a2cb62ad54
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/tests.cc
@@ -0,0 +1,578 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <atomic>
+#include <functional>
+#include <map>
+
+#include "cast/common/public/service_info.h"
+#include "discovery/common/config.h"
+#include "discovery/common/reporting_client.h"
+#include "discovery/public/dns_sd_service_factory.h"
+#include "discovery/public/dns_sd_service_publisher.h"
+#include "discovery/public/dns_sd_service_watcher.h"
+#include "gtest/gtest.h"
+#include "platform/api/logging.h"
+#include "platform/api/udp_socket.h"
+#include "platform/base/interface_info.h"
+#include "platform/impl/logging.h"
+#include "platform/impl/network_interface.h"
+#include "platform/impl/platform_client_posix.h"
+#include "platform/impl/task_runner.h"
+
+namespace openscreen {
+namespace cast {
+namespace {
+
+// Total wait time = 4 seconds.
+constexpr std::chrono::milliseconds kWaitLoopSleepTime =
+ std::chrono::milliseconds(500);
+constexpr int kMaxWaitLoopIterations = 8;
+
+// Total wait time = 2.5 seconds.
+// NOTE: This must be less than the above wait time.
+constexpr std::chrono::milliseconds kCheckLoopSleepTime =
+ std::chrono::milliseconds(100);
+constexpr int kMaxCheckLoopIterations = 25;
+
+} // namespace
+
+// Publishes new service instances.
+class Publisher : public discovery::DnsSdServicePublisher<ServiceInfo> {
+ public:
+ Publisher(discovery::DnsSdService* service)
+ : DnsSdServicePublisher<ServiceInfo>(service,
+ kCastV2ServiceId,
+ ServiceInfoToDnsSdRecord) {
+ OSP_LOG << "Initializing Publisher...\n";
+ }
+
+ ~Publisher() override = default;
+
+ bool IsInstanceIdClaimed(const std::string& requested_id) {
+ auto it =
+ std::find(instance_ids_.begin(), instance_ids_.end(), requested_id);
+ return it != instance_ids_.end();
+ }
+
+ private:
+ // DnsSdPublisher::Client overrides.
+ void OnInstanceClaimed(const std::string& requested_id) override {
+ instance_ids_.push_back(requested_id);
+ }
+
+ std::vector<std::string> instance_ids_;
+};
+
+// Receives incoming services and outputs their results to stdout.
+class Receiver : public discovery::DnsSdServiceWatcher<ServiceInfo> {
+ public:
+ Receiver(discovery::DnsSdService* service)
+ : discovery::DnsSdServiceWatcher<ServiceInfo>(
+ service,
+ kCastV2ServiceId,
+ DnsSdRecordToServiceInfo,
+ [this](
+ std::vector<std::reference_wrapper<const ServiceInfo>> infos) {
+ ProcessResults(std::move(infos));
+ }) {
+ OSP_LOG << "Initializing Receiver...";
+ }
+
+ bool IsServiceFound(const ServiceInfo& check_service) {
+ return std::find_if(service_infos_.begin(), service_infos_.end(),
+ [&check_service](const ServiceInfo& info) {
+ return info.friendly_name ==
+ check_service.friendly_name;
+ }) != service_infos_.end();
+ }
+
+ void EraseReceivedServices() { service_infos_.clear(); }
+
+ private:
+ void ProcessResults(
+ std::vector<std::reference_wrapper<const ServiceInfo>> infos) {
+ service_infos_.clear();
+ for (const ServiceInfo& info : infos) {
+ service_infos_.push_back(info);
+ }
+ }
+
+ std::vector<ServiceInfo> service_infos_;
+};
+
+class FailOnErrorReporting : public discovery::ReportingClient {
+ void OnFatalError(Error error) override {
+ // TODO(rwkeane): Change this to OSP_NOTREACHED() pending resolution of
+ // socket initialization issue.
+ OSP_LOG << "Fatal error received: '" << error << "'";
+ }
+
+ void OnRecoverableError(Error error) override {
+ // Pending resolution of openscreen:105, logging recoverable errors is
+ // disabled, as this will end up polluting the output with logs related to
+ // mDNS messages received from non-loopback network interfaces over which
+ // we have no control.
+ }
+};
+
+discovery::Config GetConfigSettings() {
+ discovery::Config config;
+
+ // Get the loopback interface to run on.
+ absl::optional<InterfaceInfo> loopback = GetLoopbackInterfaceForTesting();
+ OSP_DCHECK(loopback.has_value());
+ config.interface = loopback.value();
+
+ return config;
+}
+
+class DiscoveryE2ETest : public testing::Test {
+ public:
+ DiscoveryE2ETest() {
+ // Sleep to let any packets clear off the network before further tests.
+ std::this_thread::sleep_for(std::chrono::milliseconds(500));
+
+ // Set log level so info logs go to stdout.
+ SetLogLevel(LogLevel::kInfo);
+
+ PlatformClientPosix::Create(Clock::duration{50}, Clock::duration{50});
+ task_runner_ = PlatformClientPosix::GetInstance()->GetTaskRunner();
+ }
+
+ ~DiscoveryE2ETest() {
+ OSP_LOG << "TEST COMPLETE!";
+ dnssd_service_.reset();
+ PlatformClientPosix::ShutDown();
+ }
+
+ protected:
+ ServiceInfo GetInfoV4() {
+ ServiceInfo hosted_service;
+ hosted_service.v4_address = IPAddress{10, 0, 0, 1};
+ hosted_service.port = 25252;
+ hosted_service.unique_id = "id1";
+ hosted_service.model_name = "openscreen-ModelV4";
+ hosted_service.friendly_name = "DemoV4!";
+ return hosted_service;
+ }
+
+ ServiceInfo GetInfoV6() {
+ ServiceInfo hosted_service;
+ hosted_service.v6_address = IPAddress{1, 2, 3, 4, 5, 6, 7, 8};
+ hosted_service.port = 25252;
+ hosted_service.unique_id = "id2";
+ hosted_service.model_name = "openscreen-ModelV6";
+ hosted_service.friendly_name = "DemoV6!";
+ return hosted_service;
+ }
+
+ ServiceInfo GetInfoV4V6() {
+ ServiceInfo hosted_service;
+ hosted_service.v4_address = IPAddress{10, 0, 0, 2};
+ hosted_service.v6_address = IPAddress{1, 2, 3, 4, 5, 6, 7, 9};
+ hosted_service.port = 25254;
+ hosted_service.unique_id = "id3";
+ hosted_service.model_name = "openscreen-ModelV4andV6";
+ hosted_service.friendly_name = "DemoV4andV6!";
+ return hosted_service;
+ }
+
+ void SetUpService(const discovery::Config& config) {
+ OSP_DCHECK(!dnssd_service_.get());
+ dnssd_service_ =
+ discovery::CreateDnsSdService(task_runner_, &reporting_client_, config);
+ receiver_ = std::make_unique<Receiver>(dnssd_service_.get());
+ publisher_ = std::make_unique<Publisher>(dnssd_service_.get());
+ }
+
+ void StartDiscovery() {
+ OSP_DCHECK(dnssd_service_.get());
+ task_runner_->PostTask([this]() { receiver_->StartDiscovery(); });
+ }
+
+ template <typename... RecordTypes>
+ void UpdateRecords(RecordTypes... records) {
+ OSP_DCHECK(dnssd_service_.get());
+ OSP_DCHECK(publisher_.get());
+
+ std::vector<ServiceInfo> record_set{std::move(records)...};
+ for (ServiceInfo& record : record_set) {
+ task_runner_->PostTask([this, r = std::move(record)]() {
+ auto error = publisher_->UpdateRegistration(r);
+ OSP_DCHECK(error.ok()) << "\tFailed to update service instance '"
+ << r.friendly_name << "': " << error << "!";
+ });
+ }
+ }
+
+ template <typename... RecordTypes>
+ void PublishRecords(RecordTypes... records) {
+ OSP_DCHECK(dnssd_service_.get());
+ OSP_DCHECK(publisher_.get());
+
+ std::vector<ServiceInfo> record_set{std::move(records)...};
+ for (ServiceInfo& record : record_set) {
+ task_runner_->PostTask([this, r = std::move(record)]() {
+ auto error = publisher_->Register(r);
+ OSP_DCHECK(error.ok()) << "\tFailed to publish service instance '"
+ << r.friendly_name << "': " << error << "!";
+ });
+ }
+ }
+
+ template <typename... AtomicBoolPtrs>
+ void WaitUntilSeen(bool should_be_seen, AtomicBoolPtrs... bools) {
+ OSP_DCHECK(dnssd_service_.get());
+ std::vector<std::atomic_bool*> atomic_bools{bools...};
+
+ int waiting_on = atomic_bools.size();
+ for (int i = 0; i < kMaxWaitLoopIterations; i++) {
+ waiting_on = atomic_bools.size();
+ for (std::atomic_bool* atomic : atomic_bools) {
+ if (*atomic) {
+ OSP_DCHECK(should_be_seen) << "Found service instance!";
+ waiting_on--;
+ }
+ }
+
+ if (waiting_on) {
+ OSP_LOG << "\tWaiting on " << waiting_on << "...";
+ std::this_thread::sleep_for(kWaitLoopSleepTime);
+ continue;
+ }
+ return;
+ }
+ OSP_DCHECK(!should_be_seen)
+ << "Could not find " << waiting_on << " service instances!";
+ }
+
+ void CheckForClaimedIds(ServiceInfo service_info,
+ std::atomic_bool* has_been_seen) {
+ OSP_DCHECK(dnssd_service_.get());
+ task_runner_->PostTask(
+ [this, info = std::move(service_info), has_been_seen]() mutable {
+ CheckForClaimedIds(std::move(info), has_been_seen, 0);
+ });
+ }
+
+ void CheckForPublishedService(ServiceInfo service_info,
+ std::atomic_bool* has_been_seen) {
+ OSP_DCHECK(dnssd_service_.get());
+ task_runner_->PostTask(
+ [this, info = std::move(service_info), has_been_seen]() mutable {
+ CheckForPublishedService(std::move(info), has_been_seen, 0, true);
+ });
+ }
+
+ void CheckNotPublishedService(ServiceInfo service_info,
+ std::atomic_bool* has_been_seen) {
+ OSP_DCHECK(dnssd_service_.get());
+ task_runner_->PostTask(
+ [this, info = std::move(service_info), has_been_seen]() mutable {
+ CheckForPublishedService(std::move(info), has_been_seen, 0, false);
+ });
+ }
+ TaskRunner* task_runner_;
+ FailOnErrorReporting reporting_client_;
+ SerialDeletePtr<discovery::DnsSdService> dnssd_service_;
+ std::unique_ptr<Receiver> receiver_;
+ std::unique_ptr<Publisher> publisher_;
+
+ private:
+ void CheckForClaimedIds(ServiceInfo service_info,
+ std::atomic_bool* has_been_seen,
+ int attempts) {
+ if (publisher_->IsInstanceIdClaimed(service_info.GetInstanceId())) {
+ // TODO(crbug.com/openscreen/110): Log the published service instance.
+ *has_been_seen = true;
+ return;
+ }
+
+ if (attempts++ > kMaxCheckLoopIterations) {
+ OSP_NOTREACHED() << "Service " << service_info.friendly_name
+ << " publication failed.";
+ }
+ task_runner_->PostTaskWithDelay(
+ [this, info = std::move(service_info), has_been_seen,
+ attempts]() mutable {
+ CheckForClaimedIds(std::move(info), has_been_seen, attempts);
+ },
+ kCheckLoopSleepTime);
+ }
+
+ void CheckForPublishedService(ServiceInfo service_info,
+ std::atomic_bool* has_been_seen,
+ int attempts,
+ bool expect_to_be_present) {
+ if (!receiver_->IsServiceFound(service_info)) {
+ if (attempts++ > kMaxCheckLoopIterations) {
+ OSP_DCHECK(!expect_to_be_present)
+ << "Service " << service_info.friendly_name << " discovery failed.";
+ return;
+ }
+ task_runner_->PostTaskWithDelay(
+ [this, info = std::move(service_info), has_been_seen, attempts,
+ expect_to_be_present]() mutable {
+ CheckForPublishedService(std::move(info), has_been_seen, attempts,
+ expect_to_be_present);
+ },
+ kCheckLoopSleepTime);
+ } else if (expect_to_be_present) {
+ // TODO(crbug.com/openscreen/110): Log the discovered service instance.
+ *has_been_seen = true;
+ } else {
+ OSP_NOTREACHED() << "Found instance '" << service_info.friendly_name
+ << "'!";
+ }
+ }
+};
+
+// The below runs an E2E tests. These test requires no user interaction and is
+// intended to perform a set series of actions to validate that discovery is
+// functioning as intended.
+//
+// Known issues:
+// - The ipv6 socket in discovery/mdns/service_impl.cc fails to bind to an ipv6
+// address on the loopback interface. Investigating this issue is pending
+// resolution of bug
+// https://bugs.chromium.org/p/openscreen/issues/detail?id=105.
+//
+// In this test, the following operations are performed:
+// 1) Start up the Cast platform for a posix system.
+// 2) Publish 3 CastV2 service instances to the loopback interface using mDNS,
+// with record announcement disabled.
+// 3) Wait for the probing phase to successfully complete.
+// 4) Query for records published over the loopback interface, and validate that
+// all 3 previously published services are discovered.
+TEST_F(DiscoveryE2ETest, ValidateQueryFlow) {
+ // Set up demo infra.
+ auto discovery_config = GetConfigSettings();
+ discovery_config.new_record_announcement_count = 0;
+ SetUpService(discovery_config);
+
+ auto v4 = GetInfoV4();
+ auto v6 = GetInfoV6();
+ auto multi_address = GetInfoV4V6();
+
+ // Start discovery and publication.
+ StartDiscovery();
+ PublishRecords(v4, v6, multi_address);
+
+ // Wait until all probe phases complete and all instance ids are claimed. At
+ // this point, all records should be published.
+ OSP_LOG << "Service publication in progress...";
+ std::atomic_bool v4_found{false};
+ std::atomic_bool v6_found{false};
+ std::atomic_bool multi_address_found{false};
+ CheckForClaimedIds(v4, &v4_found);
+ CheckForClaimedIds(v6, &v6_found);
+ CheckForClaimedIds(multi_address, &multi_address_found);
+ WaitUntilSeen(true, &v4_found, &v6_found, &multi_address_found);
+ OSP_LOG << "\tAll services successfully published!\n";
+
+ // Make sure all services are found through discovery.
+ OSP_LOG << "Service discovery in progress...";
+ v4_found = false;
+ v6_found = false;
+ multi_address_found = false;
+ CheckForPublishedService(v4, &v4_found);
+ CheckForPublishedService(v6, &v6_found);
+ CheckForPublishedService(multi_address, &multi_address_found);
+ WaitUntilSeen(true, &v4_found, &v6_found, &multi_address_found);
+}
+
+// In this test, the following operations are performed:
+// 1) Start up the Cast platform for a posix system.
+// 2) Start service discovery and new queries, with no query messages being
+// sent.
+// 3) Publish 3 CastV2 service instances to the loopback interface using mDNS,
+// with record announcement enabled.
+// 4) Ensure the correct records were published over the loopback interface.
+// 5) De-register all services.
+// 6) Ensure that goodbye records are received for all service instances.
+TEST_F(DiscoveryE2ETest, ValidateAnnouncementFlow) {
+ // Set up demo infra.
+ auto discovery_config = GetConfigSettings();
+ discovery_config.new_query_announcement_count = 0;
+ SetUpService(discovery_config);
+
+ auto v4 = GetInfoV4();
+ auto v6 = GetInfoV6();
+ auto multi_address = GetInfoV4V6();
+
+ // Start discovery and publication.
+ StartDiscovery();
+ PublishRecords(v4, v6, multi_address);
+
+ // Wait until all probe phases complete and all instance ids are claimed. At
+ // this point, all records should be published.
+ OSP_LOG << "Service publication in progress...";
+ std::atomic_bool v4_found{false};
+ std::atomic_bool v6_found{false};
+ std::atomic_bool multi_address_found{false};
+ CheckForClaimedIds(v4, &v4_found);
+ CheckForClaimedIds(v6, &v6_found);
+ CheckForClaimedIds(multi_address, &multi_address_found);
+ WaitUntilSeen(true, &v4_found, &v6_found, &multi_address_found);
+ OSP_LOG << "\tAll services successfully published and announced!\n";
+
+ // Make sure all services are found through discovery.
+ OSP_LOG << "Service discovery in progress...";
+ v4_found = false;
+ v6_found = false;
+ multi_address_found = false;
+ CheckForPublishedService(v4, &v4_found);
+ CheckForPublishedService(v6, &v6_found);
+ CheckForPublishedService(multi_address, &multi_address_found);
+ WaitUntilSeen(true, &v4_found, &v6_found, &multi_address_found);
+}
+
+// In this test, the following operations are performed:
+// 1) Start up the Cast platform for a posix system.
+// 2) Publish one service and ensure it is NOT received.
+// 3) Start service discovery and new queries.
+// 4) Ensure above published service IS received.
+// 5) Stop the started query.
+// 6) Update a service, and ensure that no callback is received.
+// 7) Restart the query and ensure that only the expected callbacks are
+// received.
+TEST_F(DiscoveryE2ETest, ValidateRecordsOnlyReceivedWhenQueryRunning) {
+ // Set up demo infra.
+ auto discovery_config = GetConfigSettings();
+ discovery_config.new_record_announcement_count = 1;
+ SetUpService(discovery_config);
+
+ auto v4 = GetInfoV4();
+
+ // Start discovery and publication.
+ PublishRecords(v4);
+
+ // Wait until all probe phases complete and all instance ids are claimed. At
+ // this point, all records should be published.
+ OSP_LOG << "Service publication in progress...";
+ std::atomic_bool v4_found{false};
+ CheckForClaimedIds(v4, &v4_found);
+ WaitUntilSeen(true, &v4_found);
+
+ // And ensure stopped discovery does not find the records.
+ OSP_LOG << "Validating no service discovery occurs when discovery stopped...";
+ v4_found = false;
+ CheckNotPublishedService(v4, &v4_found);
+ WaitUntilSeen(false, &v4_found);
+
+ // Make sure all services are found through discovery.
+ StartDiscovery();
+ OSP_LOG << "Service discovery in progress...";
+ v4_found = false;
+ CheckForPublishedService(v4, &v4_found);
+ WaitUntilSeen(true, &v4_found);
+
+ // Update discovery and ensure that the updated service is seen.
+ OSP_LOG << "Updating service and waiting for discovery...";
+ auto updated_v4 = v4;
+ updated_v4.friendly_name = "OtherName";
+ v4_found = false;
+ UpdateRecords(updated_v4);
+ CheckForPublishedService(updated_v4, &v4_found);
+ WaitUntilSeen(true, &v4_found);
+
+ // And ensure the old service has been removed.
+ v4_found = false;
+ CheckNotPublishedService(v4, &v4_found);
+ WaitUntilSeen(false, &v4_found);
+
+ // Stop discovery.
+ OSP_LOG << "Stopping discovery...";
+ task_runner_->PostTask([this]() { receiver_->StopDiscovery(); });
+
+ // Update discovery and ensure that the updated service is NOT seen.
+ OSP_LOG << "Updating service and validating the change isn't received...";
+ v4_found = false;
+ v4.friendly_name = "ThirdName";
+ UpdateRecords(v4);
+ CheckNotPublishedService(v4, &v4_found);
+ WaitUntilSeen(false, &v4_found);
+
+ // Restart discovery and ensure that only the updated record is returned.
+ StartDiscovery();
+ OSP_LOG << "Service discovery in progress...";
+ v4_found = false;
+ CheckNotPublishedService(updated_v4, &v4_found);
+ WaitUntilSeen(false, &v4_found);
+
+ v4_found = false;
+ CheckForPublishedService(v4, &v4_found);
+ WaitUntilSeen(true, &v4_found);
+}
+
+// In this test, the following operations are performed:
+// 1) Start up the Cast platform for a posix system.
+// 2) Start service discovery and new queries.
+// 3) Publish one service and ensure it is received.
+// 4) Hard reset discovery
+// 5) Ensure the same service is discovered
+// 6) Soft reset the service, and ensure that a callback is received.
+TEST_F(DiscoveryE2ETest, ValidateRefreshFlow) {
+ // Set up demo infra.
+ // NOTE: This configuration assumes that packets cannot be lost over the
+ // loopback interface.
+ auto discovery_config = GetConfigSettings();
+ discovery_config.new_record_announcement_count = 0;
+ discovery_config.new_query_announcement_count = 2;
+ constexpr std::chrono::seconds kMaxQueryDuration{3};
+ SetUpService(discovery_config);
+
+ auto v4 = GetInfoV4();
+
+ // Start discovery and publication.
+ StartDiscovery();
+ PublishRecords(v4);
+
+ // Wait until all probe phases complete and all instance ids are claimed. At
+ // this point, all records should be published.
+ OSP_LOG << "Service publication in progress...";
+ std::atomic_bool v4_found{false};
+ CheckForClaimedIds(v4, &v4_found);
+ WaitUntilSeen(true, &v4_found);
+
+ // Make sure all services are found through discovery.
+ OSP_LOG << "Service discovery in progress...";
+ v4_found = false;
+ CheckForPublishedService(v4, &v4_found);
+ WaitUntilSeen(true, &v4_found);
+
+ // Force refresh discovery, then ensure that the published service is
+ // re-discovered.
+ OSP_LOG << "Force refresh discovery...";
+ task_runner_->PostTask([this]() { receiver_->EraseReceivedServices(); });
+ std::this_thread::sleep_for(kMaxQueryDuration);
+ v4_found = false;
+ CheckNotPublishedService(v4, &v4_found);
+ WaitUntilSeen(false, &v4_found);
+ task_runner_->PostTask([this]() { receiver_->ForceRefresh(); });
+
+ OSP_LOG << "Ensure that the published service is re-discovered...";
+ v4_found = false;
+ CheckForPublishedService(v4, &v4_found);
+ WaitUntilSeen(true, &v4_found);
+
+ // Soft refresh discovery, then ensure that the published service is NOT
+ // re-discovered.
+ OSP_LOG << "Call DiscoverNow on discovery...";
+ task_runner_->PostTask([this]() { receiver_->EraseReceivedServices(); });
+ std::this_thread::sleep_for(kMaxQueryDuration);
+ v4_found = false;
+ CheckNotPublishedService(v4, &v4_found);
+ WaitUntilSeen(false, &v4_found);
+ task_runner_->PostTask([this]() { receiver_->DiscoverNow(); });
+
+ OSP_LOG << "Ensure that the published service is re-discovered...";
+ v4_found = false;
+ CheckForPublishedService(v4, &v4_found);
+ WaitUntilSeen(true, &v4_found);
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/public/DEPS b/chromium/third_party/openscreen/src/cast/common/public/DEPS
index 7060c48007f..c098d4db36a 100644
--- a/chromium/third_party/openscreen/src/cast/common/public/DEPS
+++ b/chromium/third_party/openscreen/src/cast/common/public/DEPS
@@ -1,12 +1,8 @@
# -*- Mode: Python; -*-
include_rules = [
- # By default, openscreen implementation libraries should not be exposed
- # through public APIs.
- '-base',
- '-platform',
-
# Dependencies on the implementation are not allowed in public/.
'-cast/common',
- '+cast/common/public'
+ '+cast/common/public',
+ '+discovery/dnssd/public'
]
diff --git a/chromium/third_party/openscreen/src/cast/common/public/service_info.cc b/chromium/third_party/openscreen/src/cast/common/public/service_info.cc
new file mode 100644
index 00000000000..a886f7d2c6b
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/public/service_info.cc
@@ -0,0 +1,243 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/common/public/service_info.h"
+
+#include <cctype>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_replace.h"
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+namespace {
+
+// The mask for the set of supported capabilities in v2 of the Cast protocol.
+const uint16_t kCapabilitiesMask = 0x1F;
+
+// Maximum size for registered MDNS service instance names.
+const size_t kMaxDeviceNameSize = 63;
+
+// Maximum size for the device model prefix at start of MDNS service instance
+// names. Any model names that are larger than this size will be truncated.
+const size_t kMaxDeviceModelSize = 20;
+
+// Build the MDNS instance name for service. This will be the device model (up
+// to 20 bytes) appended with the virtual device ID (device UUID) and optionally
+// appended with extension at the end to resolve name conflicts. The total MDNS
+// service instance name is kept below 64 bytes so it can easily fit into a
+// single domain name label.
+//
+// NOTE: This value is based on what is currently done by Eureka, not what is
+// called out in the CastV2 spec. Eureka uses |model|-|uuid|, so the same
+// convention will be followed here. That being said, the Eureka receiver does
+// not use the instance ID in any way, so the specific calculation used should
+// not be important.
+std::string CalculateInstanceId(const ServiceInfo& info) {
+ // First set the device model, truncated to 20 bytes at most. Replace any
+ // whitespace characters (" ") with hyphens ("-") in the device model before
+ // truncation.
+ std::string instance_name =
+ absl::StrReplaceAll(info.model_name, {{" ", "-"}});
+ instance_name = std::string(instance_name, 0, kMaxDeviceModelSize);
+
+ // Append the virtual device ID to the instance name separated by a single
+ // '-' character if not empty. Strip all hyphens from the device ID prior
+ // to appending it.
+ std::string device_id = absl::StrReplaceAll(info.unique_id, {{"-", ""}});
+
+ if (!instance_name.empty()) {
+ instance_name.push_back('-');
+ }
+ instance_name.append(device_id);
+
+ return std::string(instance_name, 0, kMaxDeviceNameSize);
+}
+
+// NOTE: Eureka uses base::StringToUint64 which takes in a string and reads it
+// left to right, converts it to a number sequence, and uses the sequence to
+// calculate the resulting integer. This process assumes that the input is in
+// base 10. For example, ['1', '2', '3'] converts to 123.
+//
+// The below 2 functions re-create this logic for converting to and from this
+// encoding scheme.
+inline std::string EncodeIntegerString(uint64_t value) {
+ return std::to_string(value);
+}
+
+ErrorOr<uint64_t> DecodeIntegerString(const std::string& value) {
+ uint64_t result;
+ if (!absl::SimpleAtoi(value, &result)) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ return result;
+}
+
+// Attempts to parse the string present at the provided |key| in the TXT record
+// |txt|, placing the result into |result| on success and error into |error| on
+// failure.
+bool TryParseString(const discovery::DnsSdTxtRecord& txt,
+ const std::string& key,
+ Error* error,
+ std::string* result) {
+ const ErrorOr<discovery::DnsSdTxtRecord::ValueRef> value = txt.GetValue(key);
+ if (value.is_error()) {
+ *error = value.error();
+ return false;
+ }
+
+ const std::vector<uint8_t>& txt_value = value.value().get();
+ *result = std::string(txt_value.begin(), txt_value.end());
+ return true;
+}
+// Attempts to parse the uint8_t present at the provided |key| in the TXT record
+// |txt|, placing the result into |result| on success and error into |error| on
+// failure.
+bool TryParseInt(const discovery::DnsSdTxtRecord& txt,
+ const std::string& key,
+ Error* error,
+ uint8_t* result) {
+ const ErrorOr<discovery::DnsSdTxtRecord::ValueRef> value = txt.GetValue(key);
+ if (value.is_error()) {
+ *error = value.error();
+ return false;
+ }
+
+ const std::vector<uint8_t>& txt_value = value.value().get();
+ if (txt_value.size() != 1) {
+ *error = Error::Code::kParameterInvalid;
+ return false;
+ }
+
+ *result = txt_value[0];
+ return true;
+}
+
+// Simplifies logic below by changing error into an output parameter instead of
+// a return value.
+bool IsError(Error error, Error* result) {
+ if (error.ok()) {
+ return false;
+ } else {
+ *result = error;
+ return true;
+ }
+}
+
+} // namespace
+
+const std::string& ServiceInfo::GetInstanceId() const {
+ if (instance_id_ == std::string("")) {
+ instance_id_ = CalculateInstanceId(*this);
+ }
+
+ return instance_id_;
+}
+
+bool ServiceInfo::IsValid() const {
+ std::string instance_id = GetInstanceId();
+ if (!discovery::IsInstanceValid(instance_id)) {
+ return false;
+ }
+
+ const std::string capabilities_str = EncodeIntegerString(capabilities);
+ if (!discovery::DnsSdTxtRecord::IsValidTxtValue(kUniqueIdKey, unique_id) ||
+ !discovery::DnsSdTxtRecord::IsValidTxtValue(kVersionId,
+ protocol_version) ||
+ !discovery::DnsSdTxtRecord::IsValidTxtValue(kCapabilitiesId,
+ capabilities_str) ||
+ !discovery::DnsSdTxtRecord::IsValidTxtValue(kStatusId, status) ||
+ !discovery::DnsSdTxtRecord::IsValidTxtValue(kFriendlyNameId,
+ friendly_name) ||
+ !discovery::DnsSdTxtRecord::IsValidTxtValue(kModelNameId, model_name)) {
+ return false;
+ }
+
+ return port && (v4_address || v6_address);
+}
+
+discovery::DnsSdInstanceRecord ServiceInfoToDnsSdRecord(
+ const ServiceInfo& service) {
+ OSP_DCHECK(discovery::IsServiceValid(kCastV2ServiceId));
+ OSP_DCHECK(discovery::IsDomainValid(kCastV2DomainId));
+
+ std::string instance_id = service.GetInstanceId();
+ OSP_DCHECK(discovery::IsInstanceValid(instance_id));
+
+ const std::string capabilities_str =
+ EncodeIntegerString(service.capabilities);
+
+ discovery::DnsSdTxtRecord txt;
+ Error error;
+ const bool set_txt =
+ !IsError(txt.SetValue(kUniqueIdKey, service.unique_id), &error) &&
+ !IsError(txt.SetValue(kVersionId, service.protocol_version), &error) &&
+ !IsError(txt.SetValue(kCapabilitiesId, capabilities_str), &error) &&
+ !IsError(txt.SetValue(kStatusId, service.status), &error) &&
+ !IsError(txt.SetValue(kFriendlyNameId, service.friendly_name), &error) &&
+ !IsError(txt.SetValue(kModelNameId, service.model_name), &error);
+ OSP_DCHECK(set_txt);
+
+ OSP_DCHECK(service.port);
+ OSP_DCHECK(service.v4_address || service.v6_address);
+ if (service.v4_address && service.v6_address) {
+ return discovery::DnsSdInstanceRecord(
+ instance_id, kCastV2ServiceId, kCastV2DomainId,
+ {service.v4_address, service.port}, {service.v6_address, service.port},
+ std::move(txt));
+ } else {
+ const IPAddress& address =
+ service.v4_address ? service.v4_address : service.v6_address;
+ return discovery::DnsSdInstanceRecord(
+ instance_id, kCastV2ServiceId, kCastV2DomainId, {address, service.port},
+ std::move(txt));
+ }
+}
+
+ErrorOr<ServiceInfo> DnsSdRecordToServiceInfo(
+ const discovery::DnsSdInstanceRecord& instance) {
+ if (instance.service_id() != kCastV2ServiceId) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ ServiceInfo record;
+ record.v4_address = instance.address_v4().address;
+ record.v6_address = instance.address_v6().address;
+ record.port = instance.address_v4().port ? instance.address_v4().port
+ : instance.address_v6().port;
+
+ const auto& txt = instance.txt();
+ std::string capabilities_base64;
+ std::string unique_id;
+ uint8_t status;
+ Error error;
+ if (!TryParseInt(txt, kVersionId, &error, &record.protocol_version) ||
+ !TryParseInt(txt, kStatusId, &error, &status) ||
+ !TryParseString(txt, kFriendlyNameId, &error, &record.friendly_name) ||
+ !TryParseString(txt, kModelNameId, &error, &record.model_name) ||
+ !TryParseString(txt, kCapabilitiesId, &error, &capabilities_base64) ||
+ !TryParseString(txt, kUniqueIdKey, &error, &record.unique_id)) {
+ return error;
+ }
+
+ record.status = static_cast<ReceiverStatus>(status);
+
+ const ErrorOr<uint64_t> capabilities_flags =
+ DecodeIntegerString(capabilities_base64);
+ if (capabilities_flags.is_error()) {
+ return capabilities_flags.error();
+ }
+ record.capabilities = static_cast<ReceiverCapabilities>(
+ capabilities_flags.value() & kCapabilitiesMask);
+
+ return record;
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/public/service_info.h b/chromium/third_party/openscreen/src/cast/common/public/service_info.h
new file mode 100644
index 00000000000..a2532440df0
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/public/service_info.h
@@ -0,0 +1,124 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_COMMON_PUBLIC_SERVICE_INFO_H_
+#define CAST_COMMON_PUBLIC_SERVICE_INFO_H_
+
+#include <memory>
+#include <string>
+
+#include "discovery/dnssd/public/dns_sd_instance_record.h"
+#include "platform/base/ip_address.h"
+
+namespace openscreen {
+namespace cast {
+
+// Constants to identify a CastV2 instance with DNS-SD.
+static constexpr char kCastV2ServiceId[] = "_googlecast._tcp";
+static constexpr char kCastV2DomainId[] = "local";
+
+// Constants to be used as keys when storing data inside of a DNS-SD TXT record.
+static constexpr char kUniqueIdKey[] = "id";
+static constexpr char kVersionId[] = "ve";
+static constexpr char kCapabilitiesId[] = "ca";
+static constexpr char kStatusId[] = "st";
+static constexpr char kFriendlyNameId[] = "fn";
+static constexpr char kModelNameId[] = "mn";
+
+// This represents the ‘st’ flag in the CastV2 TXT record.
+enum ReceiverStatus {
+ // The receiver is idle and does not need to be connected now.
+ kIdle = 0,
+
+ // The receiver is hosting an activity and invites the sender to join. The
+ // receiver should connect to the running activity using the channel
+ // establishment protocol, and then query the activity to determine the next
+ // step, such as showing a description of the activity and prompting the user
+ // to launch the corresponding app.
+ kBusy = 1,
+ kJoin = kBusy
+};
+
+// This represents the ‘ca’ field in the CastV2 spec.
+enum ReceiverCapabilities : uint64_t {
+ kNone = 0x00,
+ kHasVideoOutput = 0x01 << 0,
+ kHasVideoInput = 0x01 << 1,
+ kHasAudioOutput = 0x01 << 2,
+ kHasAudioInput = 0x01 << 3,
+ kIsDevModeEnabled = 0x01 << 4,
+};
+
+static constexpr uint8_t kCurrentCastVersion = 2;
+static constexpr ReceiverCapabilities kDefaultCapabilities =
+ ReceiverCapabilities::kNone;
+
+// This is the top-level service info class for CastV2. It describes a specific
+// service instance.
+// TODO(crbug.com/openscreen/112): Rename this to CastReceiverInfo or similar.
+struct ServiceInfo {
+ // returns the instance id associated with this ServiceInfo instance.
+ const std::string& GetInstanceId() const;
+
+ // Returns whether all fields of this ServiceInfo are valid.
+ bool IsValid() const;
+
+ // Addresses for the service. Present if an address of this address type
+ // exists and empty otherwise.
+ IPAddress v4_address;
+ IPAddress v6_address;
+
+ // Port at which this service can be reached.
+ uint16_t port;
+
+ // A UUID for the Cast receiver. This should be a universally unique
+ // identifier for the receiver, and should (but does not have to be) be stable
+ // across factory resets.
+ std::string unique_id;
+
+ // Cast protocol version supported. Begins at 2 and is incremented by 1 with
+ // each version.
+ uint8_t protocol_version = kCurrentCastVersion;
+
+ // Capabilities supported by this service instance.
+ ReceiverCapabilities capabilities = kDefaultCapabilities;
+
+ // Status of the service instance.
+ ReceiverStatus status = ReceiverStatus::kIdle;
+
+ // The model name of the device, e.g. “Eureka v1”, “Mollie”.
+ std::string model_name;
+
+ // The friendly name of the device, e.g. “Living Room TV".
+ std::string friendly_name;
+
+ private:
+ mutable std::string instance_id_ = "";
+};
+
+inline bool operator==(const ServiceInfo& lhs, const ServiceInfo& rhs) {
+ return lhs.v4_address == rhs.v4_address && lhs.v6_address == rhs.v6_address &&
+ lhs.port == rhs.port && lhs.unique_id == rhs.unique_id &&
+ lhs.protocol_version == rhs.protocol_version &&
+ lhs.capabilities == rhs.capabilities && lhs.status == rhs.status &&
+ lhs.model_name == rhs.model_name &&
+ lhs.friendly_name == rhs.friendly_name;
+}
+
+inline bool operator!=(const ServiceInfo& lhs, const ServiceInfo& rhs) {
+ return !(lhs == rhs);
+}
+
+// Functions responsible for converting between CastV2 and DNS-SD
+// representations of a service instance.
+discovery::DnsSdInstanceRecord ServiceInfoToDnsSdRecord(
+ const ServiceInfo& service);
+
+ErrorOr<ServiceInfo> DnsSdRecordToServiceInfo(
+ const discovery::DnsSdInstanceRecord& service);
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_COMMON_PUBLIC_SERVICE_INFO_H_
diff --git a/chromium/third_party/openscreen/src/cast/common/public/service_info_unittest.cc b/chromium/third_party/openscreen/src/cast/common/public/service_info_unittest.cc
new file mode 100644
index 00000000000..291f3526138
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/public/service_info_unittest.cc
@@ -0,0 +1,209 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/common/public/service_info.h"
+
+#include "cast/common/public/testing/discovery_utils.h"
+#include "discovery/dnssd/public/dns_sd_instance_record.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace openscreen {
+namespace cast {
+
+TEST(ServiceInfoTests, ConvertValidFromDnsSd) {
+ std::string instance = "InstanceId";
+ discovery::DnsSdTxtRecord txt = CreateValidTxt();
+ discovery::DnsSdInstanceRecord record(instance, kCastV2ServiceId,
+ kCastV2DomainId, kEndpointV4,
+ kEndpointV6, txt);
+ ErrorOr<ServiceInfo> info = DnsSdRecordToServiceInfo(record);
+ ASSERT_TRUE(info.is_value());
+ EXPECT_EQ(info.value().unique_id, kTestUniqueId);
+ EXPECT_TRUE(info.value().v4_address);
+ EXPECT_EQ(info.value().v4_address, kAddressV4);
+ EXPECT_TRUE(info.value().v6_address);
+ EXPECT_EQ(info.value().v6_address, kAddressV6);
+ EXPECT_EQ(info.value().port, kPort);
+ EXPECT_EQ(info.value().unique_id, kTestUniqueId);
+ EXPECT_EQ(info.value().protocol_version, kTestVersion);
+ EXPECT_EQ(info.value().capabilities, kCapabilitiesParsed);
+ EXPECT_EQ(info.value().status, kStatusParsed);
+ EXPECT_EQ(info.value().model_name, kModelName);
+ EXPECT_EQ(info.value().friendly_name, kFriendlyName);
+
+ record = discovery::DnsSdInstanceRecord(instance, kCastV2ServiceId,
+ kCastV2DomainId, kEndpointV4, txt);
+ ASSERT_FALSE(record.address_v6());
+ info = DnsSdRecordToServiceInfo(record);
+ ASSERT_TRUE(info.is_value());
+ EXPECT_EQ(info.value().unique_id, kTestUniqueId);
+ EXPECT_TRUE(info.value().v4_address);
+ EXPECT_EQ(info.value().v4_address, kAddressV4);
+ EXPECT_FALSE(info.value().v6_address);
+ EXPECT_EQ(info.value().port, kPort);
+ EXPECT_EQ(info.value().unique_id, kTestUniqueId);
+ EXPECT_EQ(info.value().protocol_version, kTestVersion);
+ EXPECT_EQ(info.value().capabilities, kCapabilitiesParsed);
+ EXPECT_EQ(info.value().status, kStatusParsed);
+ EXPECT_EQ(info.value().model_name, kModelName);
+ EXPECT_EQ(info.value().friendly_name, kFriendlyName);
+
+ record = discovery::DnsSdInstanceRecord(instance, kCastV2ServiceId,
+ kCastV2DomainId, kEndpointV6, txt);
+ ASSERT_FALSE(record.address_v4());
+ info = DnsSdRecordToServiceInfo(record);
+ ASSERT_TRUE(info.is_value());
+ EXPECT_EQ(info.value().unique_id, kTestUniqueId);
+ EXPECT_FALSE(info.value().v4_address);
+ EXPECT_TRUE(info.value().v6_address);
+ EXPECT_EQ(info.value().v6_address, kAddressV6);
+ EXPECT_EQ(info.value().unique_id, kTestUniqueId);
+ EXPECT_EQ(info.value().protocol_version, kTestVersion);
+ EXPECT_EQ(info.value().capabilities, kCapabilitiesParsed);
+ EXPECT_EQ(info.value().status, kStatusParsed);
+ EXPECT_EQ(info.value().model_name, kModelName);
+ EXPECT_EQ(info.value().friendly_name, kFriendlyName);
+}
+
+TEST(ServiceInfoTests, ConvertInvalidFromDnsSd) {
+ std::string instance = "InstanceId";
+ discovery::DnsSdTxtRecord txt = CreateValidTxt();
+ txt.ClearValue(kUniqueIdKey);
+ discovery::DnsSdInstanceRecord record(instance, kCastV2ServiceId,
+ kCastV2DomainId, kEndpointV4,
+ kEndpointV6, txt);
+ EXPECT_TRUE(DnsSdRecordToServiceInfo(record).is_error());
+
+ txt = CreateValidTxt();
+ txt.ClearValue(kVersionId);
+ record = discovery::DnsSdInstanceRecord(instance, kCastV2ServiceId,
+ kCastV2DomainId, kEndpointV4,
+ kEndpointV6, txt);
+ EXPECT_TRUE(DnsSdRecordToServiceInfo(record).is_error());
+
+ txt = CreateValidTxt();
+ txt.ClearValue(kCapabilitiesId);
+ record = discovery::DnsSdInstanceRecord(instance, kCastV2ServiceId,
+ kCastV2DomainId, kEndpointV4,
+ kEndpointV6, txt);
+ EXPECT_TRUE(DnsSdRecordToServiceInfo(record).is_error());
+
+ txt = CreateValidTxt();
+ txt.ClearValue(kStatusId);
+ record = discovery::DnsSdInstanceRecord(instance, kCastV2ServiceId,
+ kCastV2DomainId, kEndpointV4,
+ kEndpointV6, txt);
+ EXPECT_TRUE(DnsSdRecordToServiceInfo(record).is_error());
+
+ txt = CreateValidTxt();
+ txt.ClearValue(kFriendlyNameId);
+ record = discovery::DnsSdInstanceRecord(instance, kCastV2ServiceId,
+ kCastV2DomainId, kEndpointV4,
+ kEndpointV6, txt);
+ EXPECT_TRUE(DnsSdRecordToServiceInfo(record).is_error());
+
+ txt = CreateValidTxt();
+ txt.ClearValue(kModelNameId);
+ record = discovery::DnsSdInstanceRecord(instance, kCastV2ServiceId,
+ kCastV2DomainId, kEndpointV4,
+ kEndpointV6, txt);
+ EXPECT_TRUE(DnsSdRecordToServiceInfo(record).is_error());
+}
+
+TEST(ServiceInfoTests, ConvertValidToDnsSd) {
+ ServiceInfo info;
+ info.v4_address = kAddressV4;
+ info.v6_address = kAddressV6;
+ info.port = kPort;
+ info.unique_id = kTestUniqueId;
+ info.protocol_version = kTestVersion;
+ info.capabilities = kCapabilitiesParsed;
+ info.status = kStatusParsed;
+ info.model_name = kModelName;
+ info.friendly_name = kFriendlyName;
+ discovery::DnsSdInstanceRecord record = ServiceInfoToDnsSdRecord(info);
+ EXPECT_EQ(record.instance_id(), kInstanceId);
+ EXPECT_TRUE(record.address_v4());
+ EXPECT_EQ(record.address_v4(), kEndpointV4);
+ EXPECT_TRUE(record.address_v6());
+ EXPECT_EQ(record.address_v6(), kEndpointV6);
+ CompareTxtString(record.txt(), kUniqueIdKey, kTestUniqueId);
+ CompareTxtString(record.txt(), kCapabilitiesId, kCapabilitiesString);
+ CompareTxtString(record.txt(), kModelNameId, kModelName);
+ CompareTxtString(record.txt(), kFriendlyNameId, kFriendlyName);
+ CompareTxtInt(record.txt(), kVersionId, kTestVersion);
+ CompareTxtInt(record.txt(), kStatusId, kStatus);
+
+ info.v6_address = IPAddress{};
+ record = ServiceInfoToDnsSdRecord(info);
+ EXPECT_TRUE(record.address_v4());
+ EXPECT_EQ(record.address_v4(), kEndpointV4);
+ EXPECT_FALSE(record.address_v6());
+ CompareTxtString(record.txt(), kUniqueIdKey, kTestUniqueId);
+ CompareTxtString(record.txt(), kCapabilitiesId, kCapabilitiesString);
+ CompareTxtString(record.txt(), kModelNameId, kModelName);
+ CompareTxtString(record.txt(), kFriendlyNameId, kFriendlyName);
+ CompareTxtInt(record.txt(), kVersionId, kTestVersion);
+ CompareTxtInt(record.txt(), kStatusId, kStatus);
+
+ info.v6_address = kAddressV6;
+ info.v4_address = IPAddress{};
+ record = ServiceInfoToDnsSdRecord(info);
+ EXPECT_FALSE(record.address_v4());
+ EXPECT_TRUE(record.address_v6());
+ EXPECT_EQ(record.address_v6(), kEndpointV6);
+ CompareTxtString(record.txt(), kUniqueIdKey, kTestUniqueId);
+ CompareTxtString(record.txt(), kCapabilitiesId, kCapabilitiesString);
+ CompareTxtString(record.txt(), kModelNameId, kModelName);
+ CompareTxtString(record.txt(), kFriendlyNameId, kFriendlyName);
+ CompareTxtInt(record.txt(), kVersionId, kTestVersion);
+ CompareTxtInt(record.txt(), kStatusId, kStatus);
+}
+
+TEST(ServiceInfoTests, ConvertInvalidToDnsSd) {
+ ServiceInfo info;
+ info.unique_id = kTestUniqueId;
+ info.protocol_version = kTestVersion;
+ info.capabilities = kCapabilitiesParsed;
+ info.status = kStatusParsed;
+ info.model_name = kModelName;
+ info.friendly_name = kFriendlyName;
+ EXPECT_FALSE(info.IsValid());
+}
+
+TEST(ServiceInfoTests, IdentityChecks) {
+ ServiceInfo info;
+ info.v4_address = kAddressV4;
+ info.v6_address = kAddressV6;
+ info.port = kPort;
+ info.unique_id = kTestUniqueId;
+ info.protocol_version = kTestVersion;
+ info.capabilities = kCapabilitiesParsed;
+ info.status = kStatusParsed;
+ info.model_name = kModelName;
+ info.friendly_name = kFriendlyName;
+ ASSERT_TRUE(info.IsValid());
+ discovery::DnsSdInstanceRecord converted_record =
+ ServiceInfoToDnsSdRecord(info);
+ ErrorOr<ServiceInfo> identity_info =
+ DnsSdRecordToServiceInfo(converted_record);
+ ASSERT_TRUE(identity_info.is_value());
+ EXPECT_EQ(identity_info.value(), info);
+
+ discovery::DnsSdTxtRecord txt = CreateValidTxt();
+ txt.SetValue(kCapabilitiesId, kCapabilitiesString);
+ discovery::DnsSdInstanceRecord record(kInstanceId, kCastV2ServiceId,
+ kCastV2DomainId, kEndpointV4,
+ kEndpointV6, txt);
+ ErrorOr<ServiceInfo> converted_info = DnsSdRecordToServiceInfo(record);
+ ASSERT_TRUE(converted_info.is_value());
+ ASSERT_TRUE(converted_info.value().IsValid());
+ discovery::DnsSdInstanceRecord identity_record =
+ ServiceInfoToDnsSdRecord(converted_info.value());
+ EXPECT_EQ(identity_record, record);
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/public/testing/discovery_utils.cc b/chromium/third_party/openscreen/src/cast/common/public/testing/discovery_utils.cc
new file mode 100644
index 00000000000..4d39041cf4f
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/public/testing/discovery_utils.cc
@@ -0,0 +1,56 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/common/public/testing/discovery_utils.h"
+
+#include <sstream>
+
+#include "discovery/dnssd/public/dns_sd_instance_record.h"
+#include "util/stringprintf.h"
+
+namespace openscreen {
+namespace cast {
+
+discovery::DnsSdTxtRecord CreateValidTxt() {
+ discovery::DnsSdTxtRecord txt;
+ txt.SetValue(kUniqueIdKey, kTestUniqueId);
+ txt.SetValue(kVersionId, kTestVersion);
+ txt.SetValue(kCapabilitiesId, kCapabilitiesStringLong);
+ txt.SetValue(kStatusId, kStatus);
+ txt.SetValue(kFriendlyNameId, kFriendlyName);
+ txt.SetValue(kModelNameId, kModelName);
+ return txt;
+}
+
+void CompareTxtString(const discovery::DnsSdTxtRecord& txt,
+ const std::string& key,
+ const std::string& expected) {
+ ErrorOr<discovery::DnsSdTxtRecord::ValueRef> value = txt.GetValue(key);
+ ASSERT_FALSE(value.is_error())
+ << "expected value: '" << expected << "' for key: '" << key
+ << "'; got error: " << value.error();
+ const std::vector<uint8_t>& data = value.value().get();
+ std::string parsed_value = std::string(data.begin(), data.end());
+ EXPECT_EQ(parsed_value, expected) << "expected value '"
+ << "' for key: '" << key << "'";
+}
+
+void CompareTxtInt(const discovery::DnsSdTxtRecord& txt,
+ const std::string& key,
+ uint8_t expected) {
+ ErrorOr<discovery::DnsSdTxtRecord::ValueRef> value = txt.GetValue(key);
+ ASSERT_FALSE(value.is_error())
+ << "key: '" << key << "'' expected: '" << expected << "'";
+ const std::vector<uint8_t>& data = value.value().get();
+ std::string parsed_value = HexEncode(data);
+ ASSERT_EQ(data.size(), size_t{1})
+ << "expected one byte value for key: '" << key << "' got size: '"
+ << data.size() << "' bytes";
+ EXPECT_EQ(data[0], expected)
+ << "expected :" << std::hex << expected << "for key: '" << key
+ << "', got value: '" << parsed_value << "'";
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/common/public/testing/discovery_utils.h b/chromium/third_party/openscreen/src/cast/common/public/testing/discovery_utils.h
new file mode 100644
index 00000000000..bea238500f4
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/common/public/testing/discovery_utils.h
@@ -0,0 +1,48 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_COMMON_PUBLIC_TESTING_DISCOVERY_UTILS_H_
+#define CAST_COMMON_PUBLIC_TESTING_DISCOVERY_UTILS_H_
+
+#include "cast/common/public/service_info.h"
+#include "discovery/dnssd/public/dns_sd_txt_record.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "platform/base/ip_address.h"
+
+namespace openscreen {
+namespace cast {
+
+// Constants used for testing.
+static const IPAddress kAddressV4(192, 168, 0, 0);
+static const IPAddress kAddressV6(1, 2, 3, 4, 5, 6, 7, 8);
+static constexpr uint16_t kPort = 80;
+static const IPEndpoint kEndpointV4{kAddressV4, kPort};
+static const IPEndpoint kEndpointV6{kAddressV6, kPort};
+static constexpr char kTestUniqueId[] = "1234";
+static constexpr char kFriendlyName[] = "Friendly Name 123";
+static constexpr char kModelName[] = "Openscreen";
+static constexpr char kInstanceId[] = "Openscreen-1234";
+static constexpr uint8_t kTestVersion = 0;
+static constexpr char kCapabilitiesString[] = "3";
+static constexpr char kCapabilitiesStringLong[] = "000003";
+static constexpr ReceiverCapabilities kCapabilitiesParsed =
+ static_cast<ReceiverCapabilities>(0x03);
+static constexpr uint8_t kStatus = 0x01;
+static constexpr ReceiverStatus kStatusParsed = ReceiverStatus::kBusy;
+
+discovery::DnsSdTxtRecord CreateValidTxt();
+
+void CompareTxtString(const discovery::DnsSdTxtRecord& txt,
+ const std::string& key,
+ const std::string& expected);
+
+void CompareTxtInt(const discovery::DnsSdTxtRecord& txt,
+ const std::string& key,
+ uint8_t expected);
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_COMMON_PUBLIC_TESTING_DISCOVERY_UTILS_H_
diff --git a/chromium/third_party/openscreen/src/cast/receiver/BUILD.gn b/chromium/third_party/openscreen/src/cast/receiver/BUILD.gn
new file mode 100644
index 00000000000..f62bf78db77
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/receiver/BUILD.gn
@@ -0,0 +1,66 @@
+# Copyright 2019 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+source_set("channel") {
+ sources = [
+ "channel/device_auth_namespace_handler.cc",
+ "channel/device_auth_namespace_handler.h",
+ "channel/message_util.cc",
+ "channel/message_util.h",
+ "channel/receiver_socket_factory.cc",
+ "channel/receiver_socket_factory.h",
+ ]
+
+ public_deps = [
+ "../../platform",
+ "../../third_party/abseil",
+ "../../third_party/boringssl",
+ "../common:channel",
+ "../common/channel/proto:channel_proto",
+ ]
+
+ deps = [
+ "../../util",
+ "../common:certificate",
+ ]
+}
+
+source_set("test_helpers") {
+ testonly = true
+ sources = [
+ "channel/testing/device_auth_test_helpers.cc",
+ "channel/testing/device_auth_test_helpers.h",
+ ]
+
+ public_deps = [
+ ":channel",
+ "../../third_party/boringssl",
+ "../common:test_helpers",
+ ]
+ deps = [
+ "../../third_party/googletest:gtest",
+ "../common/channel/proto:channel_proto",
+ ]
+}
+
+source_set("unittests") {
+ testonly = true
+ sources = [
+ "channel/device_auth_namespace_handler_unittest.cc",
+ ]
+
+ deps = [
+ ":channel",
+ ":test_helpers",
+ "../../testing/util",
+ "../../third_party/googletest:gmock",
+ "../../third_party/googletest:gtest",
+ "../common:channel",
+ "../common/channel/proto:channel_proto",
+ ]
+
+ data = [
+ "../../test/data/cast/receiver/channel",
+ ]
+}
diff --git a/chromium/third_party/openscreen/src/cast/receiver/DEPS b/chromium/third_party/openscreen/src/cast/receiver/DEPS
index 7bdadde7bd7..a2def1b0c42 100644
--- a/chromium/third_party/openscreen/src/cast/receiver/DEPS
+++ b/chromium/third_party/openscreen/src/cast/receiver/DEPS
@@ -2,6 +2,6 @@
include_rules = [
# libcast receiver code must not depend on the sender.
- '+cast/common/public',
+ '+cast/common',
'+cast/receiver'
]
diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler.cc b/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler.cc
new file mode 100644
index 00000000000..9e9d5e3f170
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler.cc
@@ -0,0 +1,155 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/receiver/channel/device_auth_namespace_handler.h"
+
+#include <openssl/evp.h>
+
+#include "cast/common/certificate/cast_cert_validator.h"
+#include "cast/common/channel/message_util.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
+#include "cast/common/channel/virtual_connection.h"
+#include "cast/common/channel/virtual_connection_router.h"
+#include "platform/base/tls_credentials.h"
+#include "util/crypto/digest_sign.h"
+
+using ::cast::channel::AuthChallenge;
+using ::cast::channel::AuthError;
+using ::cast::channel::AuthResponse;
+using ::cast::channel::CastMessage;
+using ::cast::channel::DeviceAuthMessage;
+using ::cast::channel::HashAlgorithm;
+using ::cast::channel::SignatureAlgorithm;
+
+namespace openscreen {
+namespace cast {
+
+namespace {
+
+CastMessage GenerateErrorMessage(AuthError::ErrorType error_type) {
+ DeviceAuthMessage message;
+ AuthError* error = message.mutable_error();
+ error->set_error_type(error_type);
+ std::string payload;
+ message.SerializeToString(&payload);
+
+ CastMessage response;
+ response.set_protocol_version(
+ ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
+ response.set_namespace_(kAuthNamespace);
+ response.set_payload_type(::cast::channel::CastMessage_PayloadType_BINARY);
+ response.set_payload_binary(std::move(payload));
+ return response;
+}
+
+} // namespace
+
+DeviceAuthNamespaceHandler::DeviceAuthNamespaceHandler(
+ CredentialsProvider* creds_provider)
+ : creds_provider_(creds_provider) {}
+
+DeviceAuthNamespaceHandler::~DeviceAuthNamespaceHandler() = default;
+
+void DeviceAuthNamespaceHandler::OnMessage(VirtualConnectionRouter* router,
+ CastSocket* socket,
+ CastMessage message) {
+ if (message.payload_type() !=
+ ::cast::channel::CastMessage_PayloadType_BINARY) {
+ return;
+ }
+ const std::string& payload = message.payload_binary();
+ DeviceAuthMessage device_auth_message;
+ if (!device_auth_message.ParseFromArray(payload.data(), payload.length())) {
+ // TODO(btolsch): Consider all of these cases for future error reporting
+ // mechanism.
+ return;
+ }
+
+ if (!device_auth_message.has_challenge()) {
+ return;
+ }
+
+ if (device_auth_message.has_response() || device_auth_message.has_error()) {
+ return;
+ }
+
+ const VirtualConnection virtual_conn{
+ message.destination_id(), message.source_id(), socket->socket_id()};
+ const AuthChallenge& challenge = device_auth_message.challenge();
+ const SignatureAlgorithm sig_alg = challenge.signature_algorithm();
+ HashAlgorithm hash_alg = challenge.hash_algorithm();
+ // TODO(btolsch): Reconsider supporting SHA1 after further metrics
+ // investigation.
+ if ((sig_alg != ::cast::channel::UNSPECIFIED &&
+ sig_alg != ::cast::channel::RSASSA_PKCS1v15) ||
+ (hash_alg != ::cast::channel::SHA1 &&
+ hash_alg != ::cast::channel::SHA256)) {
+ router->SendMessage(
+ virtual_conn,
+ GenerateErrorMessage(AuthError::SIGNATURE_ALGORITHM_UNAVAILABLE));
+ return;
+ }
+ const EVP_MD* digest =
+ hash_alg == ::cast::channel::SHA256 ? EVP_sha256() : EVP_sha1();
+
+ const absl::Span<const uint8_t> tls_cert_der =
+ creds_provider_->GetCurrentTlsCertAsDer();
+ const DeviceCredentials& device_creds =
+ creds_provider_->GetCurrentDeviceCredentials();
+ if (tls_cert_der.empty() || device_creds.certs.empty() ||
+ !device_creds.private_key) {
+ // TODO(btolsch): Add this to future error reporting.
+ router->SendMessage(virtual_conn,
+ GenerateErrorMessage(AuthError::INTERNAL_ERROR));
+ return;
+ }
+
+ std::unique_ptr<AuthResponse> auth_response(new AuthResponse());
+ auth_response->set_client_auth_certificate(device_creds.certs[0]);
+ for (auto it = device_creds.certs.begin() + 1; it != device_creds.certs.end();
+ ++it) {
+ auth_response->add_intermediate_certificate(*it);
+ }
+ auth_response->set_signature_algorithm(::cast::channel::RSASSA_PKCS1v15);
+ auth_response->set_hash_algorithm(hash_alg);
+ std::string sender_nonce;
+ if (challenge.has_sender_nonce()) {
+ sender_nonce = challenge.sender_nonce();
+ auth_response->set_sender_nonce(sender_nonce);
+ }
+
+ auth_response->set_crl(device_creds.serialized_crl);
+
+ std::vector<uint8_t> to_be_signed;
+ to_be_signed.reserve(sender_nonce.size() + tls_cert_der.size());
+ to_be_signed.insert(to_be_signed.end(), sender_nonce.begin(),
+ sender_nonce.end());
+ to_be_signed.insert(to_be_signed.end(), tls_cert_der.begin(),
+ tls_cert_der.end());
+
+ ErrorOr<std::string> signature =
+ SignData(digest, device_creds.private_key.get(), to_be_signed);
+ if (!signature) {
+ router->SendMessage(virtual_conn,
+ GenerateErrorMessage(AuthError::INTERNAL_ERROR));
+ return;
+ }
+ auth_response->set_signature(std::move(signature.value()));
+
+ DeviceAuthMessage response_auth_message;
+ response_auth_message.set_allocated_response(auth_response.release());
+
+ std::string response_string;
+ response_auth_message.SerializeToString(&response_string);
+ CastMessage response;
+ response.set_protocol_version(
+ ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
+ response.set_namespace_(kAuthNamespace);
+ response.set_payload_type(::cast::channel::CastMessage_PayloadType_BINARY);
+ response.set_payload_binary(std::move(response_string));
+ router->SendMessage(virtual_conn, std::move(response));
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler.h b/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler.h
new file mode 100644
index 00000000000..ec918e33f16
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler.h
@@ -0,0 +1,57 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_RECEIVER_CHANNEL_DEVICE_AUTH_NAMESPACE_HANDLER_H_
+#define CAST_RECEIVER_CHANNEL_DEVICE_AUTH_NAMESPACE_HANDLER_H_
+
+#include <openssl/evp.h>
+
+#include <string>
+#include <vector>
+
+#include "absl/types/span.h"
+#include "cast/common/channel/cast_message_handler.h"
+
+namespace openscreen {
+namespace cast {
+
+struct DeviceCredentials {
+ // The device's certificate chain in DER form, where |certs[0]| is the
+ // device's certificate and |certs[certs.size()-1]| is the last intermediate
+ // before a Cast root certificate.
+ std::vector<std::string> certs;
+
+ // The device's private key that corresponds to the certificate in |certs[0]|.
+ bssl::UniquePtr<EVP_PKEY> private_key;
+
+ // If non-empty, this contains a serialized CrlBundle protobuf. This may be
+ // used by the sender as part of verifying |certs|.
+ std::string serialized_crl;
+};
+
+class DeviceAuthNamespaceHandler final : public CastMessageHandler {
+ public:
+ class CredentialsProvider {
+ public:
+ virtual absl::Span<const uint8_t> GetCurrentTlsCertAsDer() = 0;
+ virtual const DeviceCredentials& GetCurrentDeviceCredentials() = 0;
+ };
+
+ // |creds_provider| must outlive |this|.
+ explicit DeviceAuthNamespaceHandler(CredentialsProvider* creds_provider);
+ ~DeviceAuthNamespaceHandler();
+
+ // CastMessageHandler overrides.
+ void OnMessage(VirtualConnectionRouter* router,
+ CastSocket* socket,
+ ::cast::channel::CastMessage message) override;
+
+ private:
+ CredentialsProvider* const creds_provider_;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_RECEIVER_CHANNEL_DEVICE_AUTH_NAMESPACE_HANDLER_H_
diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler_unittest.cc b/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler_unittest.cc
new file mode 100644
index 00000000000..698cbd541d5
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/receiver/channel/device_auth_namespace_handler_unittest.cc
@@ -0,0 +1,219 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/receiver/channel/device_auth_namespace_handler.h"
+
+#include "cast/common/channel/cast_socket.h"
+#include "cast/common/channel/message_util.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
+#include "cast/common/channel/testing/fake_cast_socket.h"
+#include "cast/common/channel/testing/mock_socket_error_handler.h"
+#include "cast/common/channel/virtual_connection_manager.h"
+#include "cast/common/channel/virtual_connection_router.h"
+#include "cast/receiver/channel/testing/device_auth_test_helpers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "testing/util/read_file.h"
+
+namespace openscreen {
+namespace cast {
+namespace {
+
+using ::cast::channel::AuthResponse;
+using ::cast::channel::CastMessage;
+using ::cast::channel::DeviceAuthMessage;
+using ::cast::channel::SignatureAlgorithm;
+
+using ::testing::_;
+using ::testing::ElementsAreArray;
+using ::testing::Invoke;
+
+class DeviceAuthNamespaceHandlerTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ socket_ = fake_cast_socket_pair_.socket.get();
+ router_.TakeSocket(&mock_error_handler_,
+ std::move(fake_cast_socket_pair_.socket));
+ router_.AddHandlerForLocalId(kPlatformReceiverId, &auth_handler_);
+ }
+
+ protected:
+ FakeCastSocketPair fake_cast_socket_pair_;
+ MockSocketErrorHandler mock_error_handler_;
+ CastSocket* socket_;
+
+ StaticCredentialsProvider creds_;
+ VirtualConnectionManager manager_;
+ VirtualConnectionRouter router_{&manager_};
+ DeviceAuthNamespaceHandler auth_handler_{&creds_};
+};
+
+#define TEST_DATA_PREFIX OPENSCREEN_TEST_DATA_DIR "cast/receiver/channel/"
+
+// The tests in this file use a pre-recorded AuthChallenge as input and a
+// matching pre-recorded AuthResponse for verification. This is to make it
+// easier to keep sender and receiver code separate, because the code that would
+// really generate an AuthChallenge and verify an AuthResponse is under
+// //cast/sender. The pre-recorded messages come from an integration test which
+// _does_ properly call both sender and receiver sides, but can optionally
+// record the messages for use in these unit tests. That test is currently
+// under //cast/test. See //cast/test/README.md for more information.
+//
+// The tests generally follow this procedure:
+// 1. Read a fake device certificate chain + TLS certificate from disk.
+// 2. Read a pre-recorded CastMessage proto containing an AuthChallenge.
+// 3. Send this CastMessage over a CastSocket to a DeviceAuthNamespaceHandler.
+// 4. Catch the CastMessage response and check that it has an AuthResponse.
+// 5. Check the AuthResponse against another pre-recorded protobuf.
+
+TEST_F(DeviceAuthNamespaceHandlerTest, AuthResponse) {
+ InitStaticCredentialsFromFiles(
+ &creds_, nullptr, nullptr, TEST_DATA_PREFIX "device_key.pem",
+ TEST_DATA_PREFIX "device_chain.pem", TEST_DATA_PREFIX "device_tls.pem");
+
+ // Send an auth challenge. |auth_handler_| will automatically respond via
+ // |router_| and we will catch the result in |challenge_reply|.
+ CastMessage auth_challenge;
+ const std::string auth_challenge_string =
+ ReadEntireFileToString(TEST_DATA_PREFIX "auth_challenge.pb");
+ ASSERT_TRUE(auth_challenge.ParseFromString(auth_challenge_string));
+
+ CastMessage challenge_reply;
+ EXPECT_CALL(fake_cast_socket_pair_.mock_peer_client, OnMessage(_, _))
+ .WillOnce(
+ Invoke([&challenge_reply](CastSocket* socket, CastMessage message) {
+ challenge_reply = std::move(message);
+ }));
+ ASSERT_TRUE(
+ fake_cast_socket_pair_.peer_socket->SendMessage(std::move(auth_challenge))
+ .ok());
+
+ const std::string auth_response_string =
+ ReadEntireFileToString(TEST_DATA_PREFIX "auth_response.pb");
+ AuthResponse expected_auth_response;
+ ASSERT_TRUE(expected_auth_response.ParseFromString(auth_response_string));
+
+ DeviceAuthMessage auth_message;
+ ASSERT_EQ(challenge_reply.payload_type(),
+ ::cast::channel::CastMessage_PayloadType_BINARY);
+ ASSERT_TRUE(auth_message.ParseFromString(challenge_reply.payload_binary()));
+ ASSERT_TRUE(auth_message.has_response());
+ ASSERT_FALSE(auth_message.has_challenge());
+ ASSERT_FALSE(auth_message.has_error());
+ const AuthResponse& auth_response = auth_message.response();
+
+ EXPECT_EQ(expected_auth_response.signature(), auth_response.signature());
+ EXPECT_EQ(expected_auth_response.client_auth_certificate(),
+ auth_response.client_auth_certificate());
+ EXPECT_EQ(expected_auth_response.signature_algorithm(),
+ auth_response.signature_algorithm());
+ EXPECT_EQ(expected_auth_response.sender_nonce(),
+ auth_response.sender_nonce());
+ EXPECT_EQ(expected_auth_response.hash_algorithm(),
+ auth_response.hash_algorithm());
+ EXPECT_EQ(expected_auth_response.crl(), auth_response.crl());
+ EXPECT_THAT(
+ auth_response.intermediate_certificate(),
+ ElementsAreArray(expected_auth_response.intermediate_certificate()));
+}
+
+TEST_F(DeviceAuthNamespaceHandlerTest, BadNonce) {
+ InitStaticCredentialsFromFiles(
+ &creds_, nullptr, nullptr, TEST_DATA_PREFIX "device_key.pem",
+ TEST_DATA_PREFIX "device_chain.pem", TEST_DATA_PREFIX "device_tls.pem");
+
+ // Send an auth challenge. |auth_handler_| will automatically respond via
+ // |router_| and we will catch the result in |challenge_reply|.
+ CastMessage auth_challenge;
+ const std::string auth_challenge_string =
+ ReadEntireFileToString(TEST_DATA_PREFIX "auth_challenge.pb");
+ ASSERT_TRUE(auth_challenge.ParseFromString(auth_challenge_string));
+
+ // Change the nonce to be different from what was used to record the correct
+ // response originally.
+ DeviceAuthMessage msg;
+ ASSERT_EQ(auth_challenge.payload_type(),
+ ::cast::channel::CastMessage_PayloadType_BINARY);
+ ASSERT_TRUE(msg.ParseFromString(auth_challenge.payload_binary()));
+ ASSERT_TRUE(msg.has_challenge());
+ std::string* nonce = msg.mutable_challenge()->mutable_sender_nonce();
+ (*nonce)[0] = ~(*nonce)[0];
+ std::string new_payload;
+ ASSERT_TRUE(msg.SerializeToString(&new_payload));
+ auth_challenge.set_payload_binary(new_payload);
+
+ CastMessage challenge_reply;
+ EXPECT_CALL(fake_cast_socket_pair_.mock_peer_client, OnMessage(_, _))
+ .WillOnce(
+ Invoke([&challenge_reply](CastSocket* socket, CastMessage message) {
+ challenge_reply = std::move(message);
+ }));
+ ASSERT_TRUE(
+ fake_cast_socket_pair_.peer_socket->SendMessage(std::move(auth_challenge))
+ .ok());
+
+ const std::string auth_response_string =
+ ReadEntireFileToString(TEST_DATA_PREFIX "auth_response.pb");
+ AuthResponse expected_auth_response;
+ ASSERT_TRUE(expected_auth_response.ParseFromString(auth_response_string));
+
+ DeviceAuthMessage auth_message;
+ ASSERT_EQ(challenge_reply.payload_type(),
+ ::cast::channel::CastMessage_PayloadType_BINARY);
+ ASSERT_TRUE(auth_message.ParseFromString(challenge_reply.payload_binary()));
+ ASSERT_TRUE(auth_message.has_response());
+ ASSERT_FALSE(auth_message.has_challenge());
+ ASSERT_FALSE(auth_message.has_error());
+ const AuthResponse& auth_response = auth_message.response();
+
+ // NOTE: This is the ultimate result of the nonce-mismatch.
+ EXPECT_NE(expected_auth_response.signature(), auth_response.signature());
+}
+
+TEST_F(DeviceAuthNamespaceHandlerTest, UnsupportedSignatureAlgorithm) {
+ InitStaticCredentialsFromFiles(
+ &creds_, nullptr, nullptr, TEST_DATA_PREFIX "device_key.pem",
+ TEST_DATA_PREFIX "device_chain.pem", TEST_DATA_PREFIX "device_tls.pem");
+
+ // Send an auth challenge. |auth_handler_| will automatically respond via
+ // |router_| and we will catch the result in |challenge_reply|.
+ CastMessage auth_challenge;
+ const std::string auth_challenge_string =
+ ReadEntireFileToString(TEST_DATA_PREFIX "auth_challenge.pb");
+ ASSERT_TRUE(auth_challenge.ParseFromString(auth_challenge_string));
+
+ // Change the signature algorithm an unsupported value.
+ DeviceAuthMessage msg;
+ ASSERT_EQ(auth_challenge.payload_type(),
+ ::cast::channel::CastMessage_PayloadType_BINARY);
+ ASSERT_TRUE(msg.ParseFromString(auth_challenge.payload_binary()));
+ ASSERT_TRUE(msg.has_challenge());
+ msg.mutable_challenge()->set_signature_algorithm(
+ SignatureAlgorithm::RSASSA_PSS);
+ std::string new_payload;
+ ASSERT_TRUE(msg.SerializeToString(&new_payload));
+ auth_challenge.set_payload_binary(new_payload);
+
+ CastMessage challenge_reply;
+ EXPECT_CALL(fake_cast_socket_pair_.mock_peer_client, OnMessage(_, _))
+ .WillOnce(
+ Invoke([&challenge_reply](CastSocket* socket, CastMessage message) {
+ challenge_reply = std::move(message);
+ }));
+ ASSERT_TRUE(
+ fake_cast_socket_pair_.peer_socket->SendMessage(std::move(auth_challenge))
+ .ok());
+
+ DeviceAuthMessage auth_message;
+ ASSERT_EQ(challenge_reply.payload_type(),
+ ::cast::channel::CastMessage_PayloadType_BINARY);
+ ASSERT_TRUE(auth_message.ParseFromString(challenge_reply.payload_binary()));
+ ASSERT_FALSE(auth_message.has_response());
+ ASSERT_FALSE(auth_message.has_challenge());
+ ASSERT_TRUE(auth_message.has_error());
+}
+
+} // namespace
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/message_util.cc b/chromium/third_party/openscreen/src/cast/receiver/channel/message_util.cc
new file mode 100644
index 00000000000..f011c21182d
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/receiver/channel/message_util.cc
@@ -0,0 +1,65 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/receiver/channel/message_util.h"
+
+#include "util/json/json_serialization.h"
+#include "util/json/json_value.h"
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+
+using ::cast::channel::CastMessage;
+
+namespace {
+
+ErrorOr<CastMessage> CreateAppAvailabilityResponse(
+ int request_id,
+ const std::string& sender_id,
+ const std::string& app_id,
+ AppAvailabilityResult availability_result) {
+ CastMessage availability_response;
+ Json::Value dict(Json::ValueType::objectValue);
+ dict[kMessageKeyRequestId] = request_id;
+ Json::Value availability(Json::ValueType::objectValue);
+ availability[app_id.c_str()] =
+ availability_result == AppAvailabilityResult::kAvailable
+ ? kMessageValueAppAvailable
+ : kMessageValueAppUnavailable;
+ dict[kMessageKeyAvailability] = std::move(availability);
+ ErrorOr<std::string> serialized = json::Stringify(dict);
+ if (!serialized) {
+ return Error::Code::kJsonWriteError;
+ }
+
+ availability_response.set_source_id(kPlatformReceiverId);
+ availability_response.set_destination_id(sender_id);
+ availability_response.set_namespace_(kReceiverNamespace);
+ availability_response.set_protocol_version(
+ ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
+ availability_response.set_payload_utf8(std::move(serialized.value()));
+ availability_response.set_payload_type(
+ ::cast::channel::CastMessage_PayloadType_STRING);
+ return availability_response;
+}
+
+} // namespace
+
+ErrorOr<CastMessage> CreateAppAvailableResponse(int request_id,
+ const std::string& sender_id,
+ const std::string& app_id) {
+ return CreateAppAvailabilityResponse(request_id, sender_id, app_id,
+ AppAvailabilityResult::kAvailable);
+}
+
+ErrorOr<CastMessage> CreateAppUnavailableResponse(int request_id,
+ const std::string& sender_id,
+ const std::string& app_id) {
+ return CreateAppAvailabilityResponse(request_id, sender_id, app_id,
+ AppAvailabilityResult::kUnavailable);
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/message_util.h b/chromium/third_party/openscreen/src/cast/receiver/channel/message_util.h
new file mode 100644
index 00000000000..191ff1e913a
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/receiver/channel/message_util.h
@@ -0,0 +1,30 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_RECEIVER_CHANNEL_MESSAGE_UTIL_H_
+#define CAST_RECEIVER_CHANNEL_MESSAGE_UTIL_H_
+
+#include "cast/common/channel/message_util.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
+#include "platform/base/error.h"
+
+namespace openscreen {
+namespace cast {
+
+// Creates a message that responds to a previous app availability request with
+// ID |request_id| which declares |app_id| to have availability of either
+// available or unavailable respectively.
+ErrorOr<::cast::channel::CastMessage> CreateAppAvailableResponse(
+ int request_id,
+ const std::string& sender_id,
+ const std::string& app_id);
+ErrorOr<::cast::channel::CastMessage> CreateAppUnavailableResponse(
+ int request_id,
+ const std::string& sender_id,
+ const std::string& app_id);
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_RECEIVER_CHANNEL_MESSAGE_UTIL_H_
diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/receiver_socket_factory.cc b/chromium/third_party/openscreen/src/cast/receiver/channel/receiver_socket_factory.cc
new file mode 100644
index 00000000000..f0a642e743b
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/receiver/channel/receiver_socket_factory.cc
@@ -0,0 +1,52 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/receiver/channel/receiver_socket_factory.h"
+
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+
+ReceiverSocketFactory::ReceiverSocketFactory(Client* client,
+ CastSocket::Client* socket_client)
+ : client_(client), socket_client_(socket_client) {
+ OSP_DCHECK(client);
+ OSP_DCHECK(socket_client);
+}
+
+ReceiverSocketFactory::~ReceiverSocketFactory() = default;
+
+void ReceiverSocketFactory::OnAccepted(
+ TlsConnectionFactory* factory,
+ std::vector<uint8_t> der_x509_peer_cert,
+ std::unique_ptr<TlsConnection> connection) {
+ IPEndpoint endpoint = connection->GetRemoteEndpoint();
+ auto socket =
+ std::make_unique<CastSocket>(std::move(connection), socket_client_);
+ client_->OnConnected(this, endpoint, std::move(socket));
+}
+
+void ReceiverSocketFactory::OnConnected(
+ TlsConnectionFactory* factory,
+ std::vector<uint8_t> der_x509_peer_cert,
+ std::unique_ptr<TlsConnection> connection) {
+ OSP_NOTREACHED() << "This factory is accept-only.";
+}
+
+void ReceiverSocketFactory::OnConnectionFailed(
+ TlsConnectionFactory* factory,
+ const IPEndpoint& remote_address) {
+ OSP_DVLOG << "Receiving connection from endpoint failed: " << remote_address;
+ client_->OnError(this, Error(Error::Code::kConnectionFailed,
+ "Accepting connection failed."));
+}
+
+void ReceiverSocketFactory::OnError(TlsConnectionFactory* factory,
+ Error error) {
+ client_->OnError(this, error);
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/receiver_socket_factory.h b/chromium/third_party/openscreen/src/cast/receiver/channel/receiver_socket_factory.h
new file mode 100644
index 00000000000..d8bde8cad6b
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/receiver/channel/receiver_socket_factory.h
@@ -0,0 +1,51 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_RECEIVER_CHANNEL_RECEIVER_SOCKET_FACTORY_H_
+#define CAST_RECEIVER_CHANNEL_RECEIVER_SOCKET_FACTORY_H_
+
+#include <vector>
+
+#include "cast/common/channel/cast_socket.h"
+#include "platform/api/tls_connection_factory.h"
+#include "platform/base/ip_address.h"
+
+namespace openscreen {
+namespace cast {
+
+class ReceiverSocketFactory final : public TlsConnectionFactory::Client {
+ public:
+ class Client {
+ public:
+ virtual void OnConnected(ReceiverSocketFactory* factory,
+ const IPEndpoint& endpoint,
+ std::unique_ptr<CastSocket> socket) = 0;
+ virtual void OnError(ReceiverSocketFactory* factory, Error error) = 0;
+ };
+
+ // |client| and |socket_client| must outlive |this|.
+ // TODO(btolsch): Add TaskRunner argument just for sequence checking.
+ ReceiverSocketFactory(Client* client, CastSocket::Client* socket_client);
+ ~ReceiverSocketFactory();
+
+ // TlsConnectionFactory::Client overrides.
+ void OnAccepted(TlsConnectionFactory* factory,
+ std::vector<uint8_t> der_x509_peer_cert,
+ std::unique_ptr<TlsConnection> connection) override;
+ void OnConnected(TlsConnectionFactory* factory,
+ std::vector<uint8_t> der_x509_peer_cert,
+ std::unique_ptr<TlsConnection> connection) override;
+ void OnConnectionFailed(TlsConnectionFactory* factory,
+ const IPEndpoint& remote_address) override;
+ void OnError(TlsConnectionFactory* factory, Error error) override;
+
+ private:
+ Client* const client_;
+ CastSocket::Client* const socket_client_;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_RECEIVER_CHANNEL_RECEIVER_SOCKET_FACTORY_H_
diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/testing/device_auth_test_helpers.cc b/chromium/third_party/openscreen/src/cast/receiver/channel/testing/device_auth_test_helpers.cc
new file mode 100644
index 00000000000..51d7ebaa32a
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/receiver/channel/testing/device_auth_test_helpers.cc
@@ -0,0 +1,51 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/receiver/channel/testing/device_auth_test_helpers.h"
+
+#include "gtest/gtest.h"
+
+namespace openscreen {
+namespace cast {
+
+void InitStaticCredentialsFromFiles(StaticCredentialsProvider* creds,
+ bssl::UniquePtr<X509>* parsed_cert,
+ TrustStore* fake_trust_store,
+ absl::string_view privkey_filename,
+ absl::string_view chain_filename,
+ absl::string_view tls_filename) {
+ auto private_key = testing::ReadKeyFromPemFile(privkey_filename);
+ ASSERT_TRUE(private_key);
+ std::vector<std::string> certs =
+ testing::ReadCertificatesFromPemFile(chain_filename);
+ ASSERT_GT(certs.size(), 1u);
+
+ // Use the root of the chain as the trust store for the test.
+ auto* data = reinterpret_cast<const uint8_t*>(certs.back().data());
+ auto fake_root =
+ bssl::UniquePtr<X509>(d2i_X509(nullptr, &data, certs.back().size()));
+ ASSERT_TRUE(fake_root);
+ certs.pop_back();
+ if (fake_trust_store) {
+ fake_trust_store->certs.emplace_back(fake_root.release());
+ }
+
+ creds->device_creds = DeviceCredentials{
+ std::move(certs), std::move(private_key), std::string()};
+
+ const std::vector<std::string> tls_cert =
+ testing::ReadCertificatesFromPemFile(tls_filename);
+ ASSERT_EQ(tls_cert.size(), 1u);
+ data = reinterpret_cast<const uint8_t*>(tls_cert[0].data());
+ if (parsed_cert) {
+ *parsed_cert =
+ bssl::UniquePtr<X509>(d2i_X509(nullptr, &data, tls_cert[0].size()));
+ ASSERT_TRUE(*parsed_cert);
+ }
+ const auto* begin = reinterpret_cast<const uint8_t*>(tls_cert[0].data());
+ creds->tls_cert_der.assign(begin, begin + tls_cert[0].size());
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/receiver/channel/testing/device_auth_test_helpers.h b/chromium/third_party/openscreen/src/cast/receiver/channel/testing/device_auth_test_helpers.h
new file mode 100644
index 00000000000..65ddccf3b9e
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/receiver/channel/testing/device_auth_test_helpers.h
@@ -0,0 +1,46 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_RECEIVER_CHANNEL_TESTING_DEVICE_AUTH_TEST_HELPERS_H_
+#define CAST_RECEIVER_CHANNEL_TESTING_DEVICE_AUTH_TEST_HELPERS_H_
+
+#include <openssl/x509.h>
+
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "cast/common/certificate/testing/test_helpers.h"
+#include "cast/receiver/channel/device_auth_namespace_handler.h"
+
+namespace openscreen {
+namespace cast {
+
+class StaticCredentialsProvider final
+ : public DeviceAuthNamespaceHandler::CredentialsProvider {
+ public:
+ StaticCredentialsProvider() = default;
+ ~StaticCredentialsProvider() = default;
+
+ absl::Span<const uint8_t> GetCurrentTlsCertAsDer() override {
+ return absl::Span<uint8_t>(tls_cert_der);
+ }
+ const DeviceCredentials& GetCurrentDeviceCredentials() override {
+ return device_creds;
+ }
+
+ DeviceCredentials device_creds;
+ std::vector<uint8_t> tls_cert_der;
+};
+
+void InitStaticCredentialsFromFiles(StaticCredentialsProvider* creds,
+ bssl::UniquePtr<X509>* parsed_cert,
+ TrustStore* fake_trust_store,
+ absl::string_view privkey_filename,
+ absl::string_view chain_filename,
+ absl::string_view tls_filename);
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_RECEIVER_CHANNEL_TESTING_DEVICE_AUTH_TEST_HELPERS_H_
diff --git a/chromium/third_party/openscreen/src/cast/sender/BUILD.gn b/chromium/third_party/openscreen/src/cast/sender/BUILD.gn
index 31c21215547..f500f766e86 100644
--- a/chromium/third_party/openscreen/src/cast/sender/BUILD.gn
+++ b/chromium/third_party/openscreen/src/cast/sender/BUILD.gn
@@ -13,7 +13,7 @@ source_set("channel") {
]
deps = [
- "../../util",
+ "../common:channel",
"../common/certificate/proto:certificate_proto",
"../common/channel/proto:channel_proto",
]
@@ -21,19 +21,75 @@ source_set("channel") {
public_deps = [
"../../platform",
"../../third_party/boringssl",
+ "../../util",
+ "../common:certificate",
+ "../common:channel",
+ ]
+}
+
+source_set("sender") {
+ sources = [
+ "cast_app_availability_tracker.cc",
+ "cast_app_availability_tracker.h",
+ "cast_app_discovery_service_impl.cc",
+ "cast_app_discovery_service_impl.h",
+ "cast_platform_client.cc",
+ "cast_platform_client.h",
+ "public/cast_app_discovery_service.cc",
+ "public/cast_app_discovery_service.h",
+ "public/cast_media_source.cc",
+ "public/cast_media_source.h",
+ ]
+
+ public_deps = [
+ ":channel",
+ "../../platform",
+ "../../third_party/abseil",
+ "../../util",
+ "../common:channel",
+ "../common:public",
+ ]
+}
+
+source_set("test_helpers") {
+ testonly = true
+ sources = [
+ "testing/test_helpers.cc",
+ "testing/test_helpers.h",
+ ]
+
+ deps = [
+ "../../third_party/googletest:gtest",
+ "../../util",
+ "../common:channel",
+ "../receiver:channel",
+ ]
+
+ public_deps = [
+ ":channel",
]
}
source_set("unittests") {
testonly = true
sources = [
+ "cast_app_availability_tracker_unittest.cc",
+ "cast_app_discovery_service_impl_unittest.cc",
+ "cast_platform_client_unittest.cc",
"channel/cast_auth_util_unittest.cc",
]
deps = [
":channel",
+ ":sender",
+ ":test_helpers",
"../../platform",
+ "../../platform:test",
+ "../../testing/util",
+ "../../third_party/googletest:gmock",
"../../third_party/googletest:gtest",
+ "../../util",
+ "../common:test_helpers",
"../common/certificate/proto:certificate_proto",
"../common/certificate/proto:certificate_unittest_proto",
]
diff --git a/chromium/third_party/openscreen/src/cast/sender/DEPS b/chromium/third_party/openscreen/src/cast/sender/DEPS
index 7ab7a51a926..48f4daeefbe 100644
--- a/chromium/third_party/openscreen/src/cast/sender/DEPS
+++ b/chromium/third_party/openscreen/src/cast/sender/DEPS
@@ -1,7 +1,7 @@
# -*- Mode: Python; -*-
include_rules = [
- # libcast sender code must not depend on the receiver.
- '+cast/common',
- '+cast/sender'
+ # libcast sender code must not depend on the receiver.
+ '+cast/common',
+ '+cast/sender',
]
diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker.cc b/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker.cc
new file mode 100644
index 00000000000..6c21c799d29
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker.cc
@@ -0,0 +1,165 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/sender/cast_app_availability_tracker.h"
+
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+
+CastAppAvailabilityTracker::CastAppAvailabilityTracker() = default;
+CastAppAvailabilityTracker::~CastAppAvailabilityTracker() = default;
+
+std::vector<std::string> CastAppAvailabilityTracker::RegisterSource(
+ const CastMediaSource& source) {
+ if (registered_sources_.find(source.source_id()) !=
+ registered_sources_.end()) {
+ return {};
+ }
+
+ registered_sources_.emplace(source.source_id(), source);
+
+ std::vector<std::string> new_app_ids;
+ for (const std::string& app_id : source.app_ids()) {
+ if (++registration_count_by_app_id_[app_id] == 1) {
+ new_app_ids.push_back(app_id);
+ }
+ }
+ return new_app_ids;
+}
+
+void CastAppAvailabilityTracker::UnregisterSource(
+ const CastMediaSource& source) {
+ UnregisterSource(source.source_id());
+}
+
+void CastAppAvailabilityTracker::UnregisterSource(
+ const std::string& source_id) {
+ auto it = registered_sources_.find(source_id);
+ if (it == registered_sources_.end()) {
+ return;
+ }
+
+ for (const std::string& app_id : it->second.app_ids()) {
+ auto count_it = registration_count_by_app_id_.find(app_id);
+ OSP_DCHECK(count_it != registration_count_by_app_id_.end());
+ if (--(count_it->second) == 0) {
+ registration_count_by_app_id_.erase(count_it);
+ }
+ }
+
+ registered_sources_.erase(it);
+}
+
+std::vector<CastMediaSource> CastAppAvailabilityTracker::UpdateAppAvailability(
+ const std::string& device_id,
+ const std::string& app_id,
+ AppAvailability availability) {
+ auto& availabilities = app_availabilities_[device_id];
+ auto it = availabilities.find(app_id);
+
+ AppAvailabilityResult old_availability = it == availabilities.end()
+ ? AppAvailabilityResult::kUnknown
+ : it->second.availability;
+ AppAvailabilityResult new_availability = availability.availability;
+
+ // Updated if status changes from/to kAvailable.
+ bool updated = (old_availability == AppAvailabilityResult::kAvailable ||
+ new_availability == AppAvailabilityResult::kAvailable) &&
+ old_availability != new_availability;
+ availabilities[app_id] = availability;
+
+ if (!updated) {
+ return {};
+ }
+
+ std::vector<CastMediaSource> affected_sources;
+ for (const auto& source : registered_sources_) {
+ if (source.second.ContainsAppId(app_id)) {
+ affected_sources.push_back(source.second);
+ }
+ }
+ return affected_sources;
+}
+
+std::vector<CastMediaSource> CastAppAvailabilityTracker::RemoveResultsForDevice(
+ const std::string& device_id) {
+ auto affected_sources = GetSupportedSources(device_id);
+ app_availabilities_.erase(device_id);
+ return affected_sources;
+}
+
+std::vector<CastMediaSource> CastAppAvailabilityTracker::GetSupportedSources(
+ const std::string& device_id) const {
+ auto it = app_availabilities_.find(device_id);
+ if (it == app_availabilities_.end()) {
+ return std::vector<CastMediaSource>();
+ }
+
+ // Find all app IDs that are available on the device.
+ std::vector<std::string> supported_app_ids;
+ for (const auto& availability : it->second) {
+ if (availability.second.availability == AppAvailabilityResult::kAvailable) {
+ supported_app_ids.push_back(availability.first);
+ }
+ }
+
+ // Find all registered sources whose query results contain the device ID.
+ std::vector<CastMediaSource> sources;
+ for (const auto& source : registered_sources_) {
+ if (source.second.ContainsAnyAppIdFrom(supported_app_ids)) {
+ sources.push_back(source.second);
+ }
+ }
+ return sources;
+}
+
+CastAppAvailabilityTracker::AppAvailability
+CastAppAvailabilityTracker::GetAvailability(const std::string& device_id,
+ const std::string& app_id) const {
+ auto availabilities_it = app_availabilities_.find(device_id);
+ if (availabilities_it == app_availabilities_.end()) {
+ return {AppAvailabilityResult::kUnknown, Clock::time_point{}};
+ }
+
+ const auto& availability_map = availabilities_it->second;
+ auto availability_it = availability_map.find(app_id);
+ if (availability_it == availability_map.end()) {
+ return {AppAvailabilityResult::kUnknown, Clock::time_point{}};
+ }
+
+ return availability_it->second;
+}
+
+std::vector<std::string> CastAppAvailabilityTracker::GetRegisteredApps() const {
+ std::vector<std::string> registered_apps;
+ for (const auto& app_ids_and_count : registration_count_by_app_id_) {
+ registered_apps.push_back(app_ids_and_count.first);
+ }
+
+ return registered_apps;
+}
+
+std::vector<std::string> CastAppAvailabilityTracker::GetAvailableDevices(
+ const CastMediaSource& source) const {
+ std::vector<std::string> device_ids;
+ // For each device, check if there is at least one available app in |source|.
+ for (const auto& availabilities : app_availabilities_) {
+ for (const std::string& app_id : source.app_ids()) {
+ const auto& availabilities_map = availabilities.second;
+ auto availability_it = availabilities_map.find(app_id);
+ if (availability_it != availabilities_map.end() &&
+ availability_it->second.availability ==
+ AppAvailabilityResult::kAvailable) {
+ device_ids.push_back(availabilities.first);
+ break;
+ }
+ }
+ }
+ return device_ids;
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker.h b/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker.h
new file mode 100644
index 00000000000..c0bded96334
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker.h
@@ -0,0 +1,125 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_SENDER_CAST_APP_AVAILABILITY_TRACKER_H_
+#define CAST_SENDER_CAST_APP_AVAILABILITY_TRACKER_H_
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "cast/sender/channel/message_util.h"
+#include "cast/sender/public/cast_media_source.h"
+#include "platform/api/time.h"
+
+namespace openscreen {
+namespace cast {
+
+// Tracks device queries and their extracted Cast app IDs and their
+// availabilities on discovered devices.
+// Example usage:
+///
+// (1) A page is interested in a Cast URL (e.g. by creating a
+// PresentationRequest with the URL) like "cast:foo". To register the source to
+// be tracked:
+// CastAppAvailabilityTracker tracker;
+// auto source = CastMediaSource::From("cast:foo");
+// auto new_app_ids = tracker.RegisterSource(source.value());
+//
+// (2) The set of app IDs returned by the tracker can then be used by the caller
+// to send an app availability request to each of the discovered devices.
+//
+// (3) Once the caller knows the availability value for a (device, app) pair, it
+// may inform the tracker to update its results:
+// auto affected_sources =
+// tracker.UpdateAppAvailability(device_id, app_id, {availability, now});
+//
+// (4) The tracker returns a subset of discovered sources that were affected by
+// the update. The caller can then call |GetAvailableDevices()| to get the
+// updated results for each affected source.
+//
+// (5a): At any time, the caller may call |RemoveResultsForDevice()| to remove
+// cached results pertaining to the device, when it detects that a device is
+// removed or no longer valid.
+//
+// (5b): At any time, the caller may call |GetAvailableDevices()| (even before
+// the source is registered) to determine if there are cached results available.
+// TODO(crbug.com/openscreen/112): Device -> Receiver renaming.
+class CastAppAvailabilityTracker {
+ public:
+ // The result of an app availability request and the time when it is obtained.
+ struct AppAvailability {
+ AppAvailabilityResult availability;
+ Clock::time_point time;
+ };
+
+ CastAppAvailabilityTracker();
+ ~CastAppAvailabilityTracker();
+
+ CastAppAvailabilityTracker(const CastAppAvailabilityTracker&) = delete;
+ CastAppAvailabilityTracker& operator=(const CastAppAvailabilityTracker&) =
+ delete;
+
+ // Registers |source| with the tracker. Returns a list of new app IDs that
+ // were previously not known to the tracker.
+ std::vector<std::string> RegisterSource(const CastMediaSource& source);
+
+ // Unregisters the source given by |source| or |source_id| with the tracker.
+ void UnregisterSource(const std::string& source_id);
+ void UnregisterSource(const CastMediaSource& source);
+
+ // Updates the availability of |app_id| on |device_id| to |availability|.
+ // Returns a list of registered CastMediaSources for which the set of
+ // available devices might have been updated by this call. The caller should
+ // call |GetAvailableDevices| with the returned CastMediaSources to get the
+ // updated lists.
+ std::vector<CastMediaSource> UpdateAppAvailability(
+ const std::string& device_id,
+ const std::string& app_id,
+ AppAvailability availability);
+
+ // Removes all results associated with |device_id|, i.e. when the device
+ // becomes invalid. Returns a list of registered CastMediaSources for which
+ // the set of available devices might have been updated by this call. The
+ // caller should call |GetAvailableDevices| with the returned CastMediaSources
+ // to get the updated lists.
+ std::vector<CastMediaSource> RemoveResultsForDevice(
+ const std::string& device_id);
+
+ // Returns a list of registered CastMediaSources supported by |device_id|.
+ std::vector<CastMediaSource> GetSupportedSources(
+ const std::string& device_id) const;
+
+ // Returns the availability for |app_id| on |device_id| and the time at which
+ // the availability was determined. If availability is kUnknown, then the time
+ // may be null (e.g. if an availability request was never sent).
+ AppAvailability GetAvailability(const std::string& device_id,
+ const std::string& app_id) const;
+
+ // Returns a list of registered app IDs.
+ std::vector<std::string> GetRegisteredApps() const;
+
+ // Returns a list of device IDs compatible with |source|, using the current
+ // availability info.
+ std::vector<std::string> GetAvailableDevices(
+ const CastMediaSource& source) const;
+
+ private:
+ // App ID to availability.
+ using AppAvailabilityMap = std::map<std::string, AppAvailability>;
+
+ // Registered sources and corresponding CastMediaSources.
+ std::map<std::string, CastMediaSource> registered_sources_;
+
+ // App IDs tracked and the number of registered sources containing them.
+ std::map<std::string, int> registration_count_by_app_id_;
+
+ // IDs and app availabilities of known devices.
+ std::map<std::string, AppAvailabilityMap> app_availabilities_;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_SENDER_CAST_APP_AVAILABILITY_TRACKER_H_
diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker_unittest.cc b/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker_unittest.cc
new file mode 100644
index 00000000000..b45d356365f
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/sender/cast_app_availability_tracker_unittest.cc
@@ -0,0 +1,159 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/sender/cast_app_availability_tracker.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "platform/test/fake_clock.h"
+
+namespace openscreen {
+namespace cast {
+namespace {
+
+using ::testing::UnorderedElementsAreArray;
+
+MATCHER_P(CastMediaSourcesEqual, expected, "") {
+ if (expected.size() != arg.size())
+ return false;
+ return std::equal(
+ expected.begin(), expected.end(), arg.begin(),
+ [](const CastMediaSource& source1, const CastMediaSource& source2) {
+ return source1.source_id() == source2.source_id();
+ });
+}
+
+} // namespace
+
+class CastAppAvailabilityTrackerTest : public ::testing::Test {
+ public:
+ CastAppAvailabilityTrackerTest() : clock_(Clock::now()) {}
+ ~CastAppAvailabilityTrackerTest() override = default;
+
+ Clock::time_point Now() const { return clock_.now(); }
+
+ protected:
+ FakeClock clock_;
+ CastAppAvailabilityTracker tracker_;
+};
+
+TEST_F(CastAppAvailabilityTrackerTest, RegisterSource) {
+ CastMediaSource source1("cast:AAA?clientId=1", {"AAA"});
+ CastMediaSource source2("cast:AAA?clientId=2", {"AAA"});
+
+ std::vector<std::string> expected_app_ids = {"AAA"};
+ EXPECT_EQ(expected_app_ids, tracker_.RegisterSource(source1));
+
+ EXPECT_EQ(std::vector<std::string>{}, tracker_.RegisterSource(source1));
+ EXPECT_EQ(std::vector<std::string>{}, tracker_.RegisterSource(source2));
+
+ tracker_.UnregisterSource(source1);
+ tracker_.UnregisterSource(source2);
+
+ EXPECT_EQ(expected_app_ids, tracker_.RegisterSource(source1));
+ EXPECT_EQ(expected_app_ids, tracker_.GetRegisteredApps());
+}
+
+TEST_F(CastAppAvailabilityTrackerTest, RegisterSourceReturnsMultipleAppIds) {
+ CastMediaSource source1("urn:x-org.chromium.media:source:tab:1",
+ {"0F5096E8", "85CDB22F"});
+
+ // Mirorring app ids.
+ std::vector<std::string> expected_app_ids = {"0F5096E8", "85CDB22F"};
+ EXPECT_THAT(tracker_.RegisterSource(source1),
+ UnorderedElementsAreArray(expected_app_ids));
+ EXPECT_THAT(tracker_.GetRegisteredApps(),
+ UnorderedElementsAreArray(expected_app_ids));
+}
+
+TEST_F(CastAppAvailabilityTrackerTest, MultipleAppIdsAlreadyTrackingOne) {
+ // One of the mirroring app IDs.
+ CastMediaSource source1("cast:0F5096E8?clientId=123", {"0F5096E8"});
+
+ std::vector<std::string> new_app_ids = {"0F5096E8"};
+ std::vector<std::string> registered_app_ids = {"0F5096E8"};
+ EXPECT_EQ(new_app_ids, tracker_.RegisterSource(source1));
+ EXPECT_EQ(registered_app_ids, tracker_.GetRegisteredApps());
+
+ CastMediaSource source2("urn:x-org.chromium.media:source:tab:1",
+ {"0F5096E8", "85CDB22F"});
+
+ new_app_ids = {"85CDB22F"};
+ registered_app_ids = {"0F5096E8", "85CDB22F"};
+
+ EXPECT_EQ(new_app_ids, tracker_.RegisterSource(source2));
+ EXPECT_THAT(tracker_.GetRegisteredApps(),
+ UnorderedElementsAreArray(registered_app_ids));
+}
+
+TEST_F(CastAppAvailabilityTrackerTest, UpdateAppAvailability) {
+ CastMediaSource source1("cast:AAA?clientId=1", {"AAA"});
+ CastMediaSource source2("cast:AAA?clientId=2", {"AAA"});
+ CastMediaSource source3("cast:BBB?clientId=3", {"BBB"});
+
+ tracker_.RegisterSource(source3);
+
+ // |source3| not affected.
+ EXPECT_THAT(
+ tracker_.UpdateAppAvailability(
+ "deviceId1", "AAA", {AppAvailabilityResult::kAvailable, Now()}),
+ CastMediaSourcesEqual(std::vector<CastMediaSource>()));
+
+ std::vector<std::string> devices_1 = {"deviceId1"};
+ std::vector<std::string> devices_1_2 = {"deviceId1", "deviceId2"};
+ std::vector<CastMediaSource> sources_1 = {source1};
+ std::vector<CastMediaSource> sources_1_2 = {source1, source2};
+
+ // Tracker returns available devices even though sources aren't registered.
+ EXPECT_EQ(devices_1, tracker_.GetAvailableDevices(source1));
+ EXPECT_EQ(devices_1, tracker_.GetAvailableDevices(source2));
+ EXPECT_TRUE(tracker_.GetAvailableDevices(source3).empty());
+
+ tracker_.RegisterSource(source1);
+ // Only |source1| is registered for this app.
+ EXPECT_THAT(
+ tracker_.UpdateAppAvailability(
+ "deviceId2", "AAA", {AppAvailabilityResult::kAvailable, Now()}),
+ CastMediaSourcesEqual(sources_1));
+ EXPECT_THAT(tracker_.GetAvailableDevices(source1),
+ UnorderedElementsAreArray(devices_1_2));
+ EXPECT_THAT(tracker_.GetAvailableDevices(source2),
+ UnorderedElementsAreArray(devices_1_2));
+ EXPECT_TRUE(tracker_.GetAvailableDevices(source3).empty());
+
+ tracker_.RegisterSource(source2);
+ EXPECT_THAT(
+ tracker_.UpdateAppAvailability(
+ "deviceId2", "AAA", {AppAvailabilityResult::kUnavailable, Now()}),
+ CastMediaSourcesEqual(sources_1_2));
+ EXPECT_EQ(devices_1, tracker_.GetAvailableDevices(source1));
+ EXPECT_EQ(devices_1, tracker_.GetAvailableDevices(source2));
+ EXPECT_TRUE(tracker_.GetAvailableDevices(source3).empty());
+}
+
+TEST_F(CastAppAvailabilityTrackerTest, RemoveResultsForDevice) {
+ CastMediaSource source1("cast:AAA?clientId=1", {"AAA"});
+
+ tracker_.UpdateAppAvailability("deviceId1", "AAA",
+ {AppAvailabilityResult::kAvailable, Now()});
+ EXPECT_EQ(AppAvailabilityResult::kAvailable,
+ tracker_.GetAvailability("deviceId1", "AAA").availability);
+
+ std::vector<std::string> expected_device_ids = {"deviceId1"};
+ EXPECT_EQ(expected_device_ids, tracker_.GetAvailableDevices(source1));
+
+ // Unrelated device ID.
+ tracker_.RemoveResultsForDevice("deviceId2");
+ EXPECT_EQ(AppAvailabilityResult::kAvailable,
+ tracker_.GetAvailability("deviceId1", "AAA").availability);
+ EXPECT_EQ(expected_device_ids, tracker_.GetAvailableDevices(source1));
+
+ tracker_.RemoveResultsForDevice("deviceId1");
+ EXPECT_EQ(AppAvailabilityResult::kUnknown,
+ tracker_.GetAvailability("deviceId1", "AAA").availability);
+ EXPECT_EQ(std::vector<std::string>{}, tracker_.GetAvailableDevices(source1));
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl.cc b/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl.cc
new file mode 100644
index 00000000000..69028aacff7
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl.cc
@@ -0,0 +1,211 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/sender/cast_app_discovery_service_impl.h"
+
+#include <algorithm>
+#include <chrono>
+
+#include "cast/sender/public/cast_media_source.h"
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+namespace {
+
+// The minimum time that must elapse before an app availability result can be
+// force refreshed.
+static constexpr std::chrono::minutes kRefreshThreshold =
+ std::chrono::minutes(1);
+
+} // namespace
+
+CastAppDiscoveryServiceImpl::CastAppDiscoveryServiceImpl(
+ CastPlatformClient* platform_client,
+ ClockNowFunctionPtr clock)
+ : platform_client_(platform_client), clock_(clock), weak_factory_(this) {
+ OSP_DCHECK(platform_client_);
+ OSP_DCHECK(clock_);
+}
+
+CastAppDiscoveryServiceImpl::~CastAppDiscoveryServiceImpl() {
+ OSP_CHECK_EQ(avail_queries_.size(), 0u);
+}
+
+CastAppDiscoveryService::Subscription
+CastAppDiscoveryServiceImpl::StartObservingAvailability(
+ const CastMediaSource& source,
+ AvailabilityCallback callback) {
+ const std::string& source_id = source.source_id();
+
+ // Return cached results immediately, if available.
+ std::vector<std::string> cached_device_ids =
+ availability_tracker_.GetAvailableDevices(source);
+ if (!cached_device_ids.empty()) {
+ callback(source, GetReceiversByIds(cached_device_ids));
+ }
+
+ auto& callbacks = avail_queries_[source_id];
+ uint32_t query_id = GetNextAvailabilityQueryId();
+ callbacks.push_back({query_id, std::move(callback)});
+ if (callbacks.size() == 1) {
+ // NOTE: Even though we retain availability results for an app unregistered
+ // from the tracker, we will refresh the results when the app is
+ // re-registered.
+ std::vector<std::string> new_app_ids =
+ availability_tracker_.RegisterSource(source);
+ for (const auto& app_id : new_app_ids) {
+ for (const auto& entry : receivers_by_id_) {
+ RequestAppAvailability(entry.first, app_id);
+ }
+ }
+ }
+
+ return MakeSubscription(this, query_id);
+}
+
+void CastAppDiscoveryServiceImpl::Refresh() {
+ const auto app_ids = availability_tracker_.GetRegisteredApps();
+ for (const auto& entry : receivers_by_id_) {
+ for (const auto& app_id : app_ids) {
+ RequestAppAvailability(entry.first, app_id);
+ }
+ }
+}
+
+void CastAppDiscoveryServiceImpl::AddOrUpdateReceiver(
+ const ServiceInfo& receiver) {
+ const std::string& device_id = receiver.unique_id;
+ receivers_by_id_[device_id] = receiver;
+
+ // Any queries that currently contain this receiver should be updated.
+ UpdateAvailabilityQueries(
+ availability_tracker_.GetSupportedSources(device_id));
+
+ for (const std::string& app_id : availability_tracker_.GetRegisteredApps()) {
+ RequestAppAvailability(device_id, app_id);
+ }
+}
+
+void CastAppDiscoveryServiceImpl::RemoveReceiver(const ServiceInfo& receiver) {
+ const std::string& device_id = receiver.unique_id;
+ receivers_by_id_.erase(device_id);
+ UpdateAvailabilityQueries(
+ availability_tracker_.RemoveResultsForDevice(device_id));
+}
+
+void CastAppDiscoveryServiceImpl::RequestAppAvailability(
+ const std::string& device_id,
+ const std::string& app_id) {
+ if (ShouldRefreshAppAvailability(device_id, app_id, clock_())) {
+ platform_client_->RequestAppAvailability(
+ device_id, app_id,
+ [self = weak_factory_.GetWeakPtr(), device_id](
+ const std::string& app_id, AppAvailabilityResult availability) {
+ if (self) {
+ self->UpdateAppAvailability(device_id, app_id, availability);
+ }
+ });
+ }
+}
+
+void CastAppDiscoveryServiceImpl::UpdateAppAvailability(
+ const std::string& device_id,
+ const std::string& app_id,
+ AppAvailabilityResult availability) {
+ if (receivers_by_id_.find(device_id) == receivers_by_id_.end()) {
+ return;
+ }
+
+ OSP_DVLOG << "App " << app_id << " on receiver " << device_id << " is "
+ << ToString(availability);
+
+ UpdateAvailabilityQueries(availability_tracker_.UpdateAppAvailability(
+ device_id, app_id, {availability, clock_()}));
+}
+
+void CastAppDiscoveryServiceImpl::UpdateAvailabilityQueries(
+ const std::vector<CastMediaSource>& sources) {
+ for (const auto& source : sources) {
+ const std::string& source_id = source.source_id();
+ auto it = avail_queries_.find(source_id);
+ if (it == avail_queries_.end())
+ continue;
+ std::vector<std::string> device_ids =
+ availability_tracker_.GetAvailableDevices(source);
+ std::vector<ServiceInfo> receivers = GetReceiversByIds(device_ids);
+ for (const auto& callback : it->second) {
+ callback.callback(source, receivers);
+ }
+ }
+}
+
+std::vector<ServiceInfo> CastAppDiscoveryServiceImpl::GetReceiversByIds(
+ const std::vector<std::string>& device_ids) const {
+ std::vector<ServiceInfo> receivers;
+ for (const std::string& device_id : device_ids) {
+ auto entry = receivers_by_id_.find(device_id);
+ if (entry != receivers_by_id_.end()) {
+ receivers.push_back(entry->second);
+ }
+ }
+ return receivers;
+}
+
+bool CastAppDiscoveryServiceImpl::ShouldRefreshAppAvailability(
+ const std::string& device_id,
+ const std::string& app_id,
+ Clock::time_point now) const {
+ // TODO(btolsch): Consider an exponential backoff mechanism instead.
+ // Receivers will typically respond with "unavailable" immediately after boot
+ // and then become available 10-30 seconds later.
+ auto availability = availability_tracker_.GetAvailability(device_id, app_id);
+ switch (availability.availability) {
+ case AppAvailabilityResult::kAvailable:
+ return false;
+ case AppAvailabilityResult::kUnavailable:
+ return (now - availability.time) > kRefreshThreshold;
+ // TODO(btolsch): Should there be a background task for periodically
+ // refreshing kUnknown (or even kUnavailable) results?
+ case AppAvailabilityResult::kUnknown:
+ return true;
+ }
+
+ OSP_NOTREACHED();
+ return false;
+}
+
+uint32_t CastAppDiscoveryServiceImpl::GetNextAvailabilityQueryId() {
+ if (free_query_ids_.empty()) {
+ return next_avail_query_id_++;
+ } else {
+ uint32_t id = free_query_ids_.back();
+ free_query_ids_.pop_back();
+ return id;
+ }
+}
+
+void CastAppDiscoveryServiceImpl::RemoveAvailabilityCallback(uint32_t id) {
+ for (auto entry = avail_queries_.begin(); entry != avail_queries_.end();
+ ++entry) {
+ const std::string& source_id = entry->first;
+ auto& callbacks = entry->second;
+ auto it =
+ std::find_if(callbacks.begin(), callbacks.end(),
+ [id](const AvailabilityCallbackEntry& callback_entry) {
+ return callback_entry.id == id;
+ });
+ if (it != callbacks.end()) {
+ callbacks.erase(it);
+ if (callbacks.empty()) {
+ availability_tracker_.UnregisterSource(source_id);
+ avail_queries_.erase(entry);
+ }
+ return;
+ }
+ }
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl.h b/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl.h
new file mode 100644
index 00000000000..4093311af35
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl.h
@@ -0,0 +1,99 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_SENDER_CAST_APP_DISCOVERY_SERVICE_IMPL_H_
+#define CAST_SENDER_CAST_APP_DISCOVERY_SERVICE_IMPL_H_
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "cast/common/public/service_info.h"
+#include "cast/sender/cast_app_availability_tracker.h"
+#include "cast/sender/cast_platform_client.h"
+#include "cast/sender/public/cast_app_discovery_service.h"
+#include "platform/api/time.h"
+#include "util/weak_ptr.h"
+
+namespace openscreen {
+namespace cast {
+
+// Keeps track of availability queries, receives receiver updates, and issues
+// app availability requests based on these signals.
+class CastAppDiscoveryServiceImpl : public CastAppDiscoveryService {
+ public:
+ // |platform_client| must outlive |this|.
+ CastAppDiscoveryServiceImpl(CastPlatformClient* platform_client,
+ ClockNowFunctionPtr clock);
+ ~CastAppDiscoveryServiceImpl() override;
+
+ // CastAppDiscoveryService implementation.
+ Subscription StartObservingAvailability(
+ const CastMediaSource& source,
+ AvailabilityCallback callback) override;
+
+ // Reissues app availability requests for currently registered (device_id,
+ // app_id) pairs whose status is kUnavailable or kUnknown.
+ void Refresh() override;
+
+ void AddOrUpdateReceiver(const ServiceInfo& receiver);
+ void RemoveReceiver(const ServiceInfo& receiver);
+
+ private:
+ struct AvailabilityCallbackEntry {
+ uint32_t id;
+ AvailabilityCallback callback;
+ };
+
+ // Issues an app availability request for |app_id| to the receiver given by
+ // |device_id|.
+ void RequestAppAvailability(const std::string& device_id,
+ const std::string& app_id);
+
+ // Updates the availability result for |device_id| and |app_id| with |result|,
+ // and notifies callbacks with updated availability query results.
+ void UpdateAppAvailability(const std::string& device_id,
+ const std::string& app_id,
+ AppAvailabilityResult result);
+
+ // Updates the availability query results for |sources|.
+ void UpdateAvailabilityQueries(const std::vector<CastMediaSource>& sources);
+
+ std::vector<ServiceInfo> GetReceiversByIds(
+ const std::vector<std::string>& device_ids) const;
+
+ // Returns true if an app availability request should be issued for
+ // |device_id| and |app_id|. |now| is used for checking whether previously
+ // cached results should be refreshed.
+ bool ShouldRefreshAppAvailability(const std::string& device_id,
+ const std::string& app_id,
+ Clock::time_point now) const;
+
+ uint32_t GetNextAvailabilityQueryId();
+
+ void RemoveAvailabilityCallback(uint32_t id) override;
+
+ std::map<std::string, ServiceInfo> receivers_by_id_;
+
+ // Registered availability queries and their associated callbacks keyed by
+ // media source IDs.
+ std::map<std::string, std::vector<AvailabilityCallbackEntry>> avail_queries_;
+
+ // Callback ID tracking.
+ uint32_t next_avail_query_id_;
+ std::vector<uint32_t> free_query_ids_;
+
+ CastPlatformClient* const platform_client_;
+
+ CastAppAvailabilityTracker availability_tracker_;
+
+ const ClockNowFunctionPtr clock_;
+
+ WeakPtrFactory<CastAppDiscoveryServiceImpl> weak_factory_;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_SENDER_CAST_APP_DISCOVERY_SERVICE_IMPL_H_
diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl_unittest.cc b/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl_unittest.cc
new file mode 100644
index 00000000000..863eb3472c4
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/sender/cast_app_discovery_service_impl_unittest.cc
@@ -0,0 +1,350 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/sender/cast_app_discovery_service_impl.h"
+
+#include "cast/common/channel/testing/fake_cast_socket.h"
+#include "cast/common/channel/testing/mock_socket_error_handler.h"
+#include "cast/common/channel/virtual_connection_manager.h"
+#include "cast/common/channel/virtual_connection_router.h"
+#include "cast/common/public/service_info.h"
+#include "cast/sender/testing/test_helpers.h"
+#include "gtest/gtest.h"
+#include "platform/test/fake_clock.h"
+#include "platform/test/fake_task_runner.h"
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+
+using ::cast::channel::CastMessage;
+
+using ::testing::_;
+
+class CastAppDiscoveryServiceImplTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ socket_id_ = fake_cast_socket_pair_.socket->socket_id();
+ router_.TakeSocket(&mock_error_handler_,
+ std::move(fake_cast_socket_pair_.socket));
+
+ receiver_.v4_address = fake_cast_socket_pair_.remote_endpoint.address;
+ receiver_.port = fake_cast_socket_pair_.remote_endpoint.port;
+ receiver_.unique_id = "deviceId1";
+ receiver_.friendly_name = "Some Name";
+ }
+
+ protected:
+ CastSocket& peer_socket() { return *fake_cast_socket_pair_.peer_socket; }
+ MockCastSocketClient& peer_client() {
+ return fake_cast_socket_pair_.mock_peer_client;
+ }
+
+ void AddOrUpdateReceiver(const ServiceInfo& receiver, int32_t socket_id) {
+ platform_client_.AddOrUpdateReceiver(receiver, socket_id);
+ app_discovery_service_.AddOrUpdateReceiver(receiver);
+ }
+
+ CastAppDiscoveryService::Subscription StartObservingAvailability(
+ const CastMediaSource& source,
+ std::vector<ServiceInfo>* save_receivers) {
+ return app_discovery_service_.StartObservingAvailability(
+ source, [save_receivers](const CastMediaSource& source,
+ const std::vector<ServiceInfo>& receivers) {
+ *save_receivers = receivers;
+ });
+ }
+
+ CastAppDiscoveryService::Subscription StartSourceA1Query(
+ std::vector<ServiceInfo>* receivers,
+ int* request_id,
+ std::string* sender_id) {
+ auto subscription = StartObservingAvailability(source_a_1_, receivers);
+
+ // Adding a receiver after app registered causes app availability request to
+ // be sent.
+ *request_id = -1;
+ *sender_id = "";
+ EXPECT_CALL(peer_client(), OnMessage(_, _))
+ .WillOnce([request_id, sender_id](CastSocket*, CastMessage message) {
+ VerifyAppAvailabilityRequest(message, "AAA", request_id, sender_id);
+ });
+
+ AddOrUpdateReceiver(receiver_, socket_id_);
+
+ return subscription;
+ }
+
+ FakeCastSocketPair fake_cast_socket_pair_;
+ int32_t socket_id_;
+ MockSocketErrorHandler mock_error_handler_;
+ VirtualConnectionManager manager_;
+ VirtualConnectionRouter router_{&manager_};
+ FakeClock clock_{Clock::now()};
+ FakeTaskRunner task_runner_{&clock_};
+ CastPlatformClient platform_client_{&router_, &manager_, &FakeClock::now,
+ &task_runner_};
+ CastAppDiscoveryServiceImpl app_discovery_service_{&platform_client_,
+ &FakeClock::now};
+
+ CastMediaSource source_a_1_{"cast:AAA?clientId=1", {"AAA"}};
+ CastMediaSource source_a_2_{"cast:AAA?clientId=2", {"AAA"}};
+ CastMediaSource source_b_1_{"cast:BBB?clientId=1", {"BBB"}};
+
+ ServiceInfo receiver_;
+};
+
+TEST_F(CastAppDiscoveryServiceImplTest, StartObservingAvailability) {
+ std::vector<ServiceInfo> receivers1;
+ int request_id;
+ std::string sender_id;
+ auto subscription1 = StartSourceA1Query(&receivers1, &request_id, &sender_id);
+
+ // Same app ID should not trigger another request.
+ EXPECT_CALL(peer_client(), OnMessage(_, _)).Times(0);
+ std::vector<ServiceInfo> receivers2;
+ auto subscription2 = StartObservingAvailability(source_a_2_, &receivers2);
+
+ CastMessage availability_response =
+ CreateAppAvailableResponseChecked(request_id, sender_id, "AAA");
+ EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok());
+ ASSERT_EQ(receivers1.size(), 1u);
+ ASSERT_EQ(receivers2.size(), 1u);
+ EXPECT_EQ(receivers1[0].unique_id, "deviceId1");
+ EXPECT_EQ(receivers2[0].unique_id, "deviceId1");
+
+ // No more updates for |source_a_1_| (i.e. |receivers1|).
+ subscription1.Reset();
+ platform_client_.RemoveReceiver(receiver_);
+ app_discovery_service_.RemoveReceiver(receiver_);
+ ASSERT_EQ(receivers1.size(), 1u);
+ EXPECT_EQ(receivers2.size(), 0u);
+ EXPECT_EQ(receivers1[0].unique_id, "deviceId1");
+}
+
+TEST_F(CastAppDiscoveryServiceImplTest, ReAddAvailQueryUsesCachedValue) {
+ std::vector<ServiceInfo> receivers1;
+ int request_id;
+ std::string sender_id;
+ auto subscription1 = StartSourceA1Query(&receivers1, &request_id, &sender_id);
+
+ CastMessage availability_response =
+ CreateAppAvailableResponseChecked(request_id, sender_id, "AAA");
+ EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok());
+ ASSERT_EQ(receivers1.size(), 1u);
+ EXPECT_EQ(receivers1[0].unique_id, "deviceId1");
+
+ subscription1.Reset();
+ receivers1.clear();
+
+ // Request not re-sent; cached kAvailable value is used.
+ EXPECT_CALL(peer_client(), OnMessage(_, _)).Times(0);
+ subscription1 = StartObservingAvailability(source_a_1_, &receivers1);
+ ASSERT_EQ(receivers1.size(), 1u);
+ EXPECT_EQ(receivers1[0].unique_id, "deviceId1");
+}
+
+TEST_F(CastAppDiscoveryServiceImplTest, AvailQueryUpdatedOnReceiverUpdate) {
+ std::vector<ServiceInfo> receivers1;
+ int request_id;
+ std::string sender_id;
+ auto subscription1 = StartSourceA1Query(&receivers1, &request_id, &sender_id);
+
+ // Result set now includes |receiver_|.
+ CastMessage availability_response =
+ CreateAppAvailableResponseChecked(request_id, sender_id, "AAA");
+ EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok());
+ ASSERT_EQ(receivers1.size(), 1u);
+ EXPECT_EQ(receivers1[0].unique_id, "deviceId1");
+
+ // Updating |receiver_| causes |source_a_1_| query to be updated, but it's too
+ // soon for a new message to be sent.
+ EXPECT_CALL(peer_client(), OnMessage(_, _)).Times(0);
+ receiver_.friendly_name = "New Name";
+
+ AddOrUpdateReceiver(receiver_, socket_id_);
+
+ ASSERT_EQ(receivers1.size(), 1u);
+ EXPECT_EQ(receivers1[0].friendly_name, "New Name");
+}
+
+TEST_F(CastAppDiscoveryServiceImplTest, Refresh) {
+ std::vector<ServiceInfo> receivers1;
+ auto subscription1 = StartObservingAvailability(source_a_1_, &receivers1);
+ std::vector<ServiceInfo> receivers2;
+ auto subscription2 = StartObservingAvailability(source_b_1_, &receivers2);
+
+ // Adding a receiver after app registered causes two separate app availability
+ // requests to be sent.
+ int request_idA = -1;
+ int request_idB = -1;
+ std::string sender_id = "";
+ EXPECT_CALL(peer_client(), OnMessage(_, _))
+ .Times(2)
+ .WillRepeatedly([&request_idA, &request_idB, &sender_id](
+ CastSocket*, CastMessage message) {
+ std::string app_id;
+ int request_id = -1;
+ VerifyAppAvailabilityRequest(message, &app_id, &request_id, &sender_id);
+ if (app_id == "AAA") {
+ EXPECT_EQ(request_idA, -1);
+ request_idA = request_id;
+ } else if (app_id == "BBB") {
+ EXPECT_EQ(request_idB, -1);
+ request_idB = request_id;
+ } else {
+ EXPECT_TRUE(false);
+ }
+ });
+
+ AddOrUpdateReceiver(receiver_, socket_id_);
+
+ CastMessage availability_response =
+ CreateAppAvailableResponseChecked(request_idA, sender_id, "AAA");
+ EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok());
+ availability_response =
+ CreateAppUnavailableResponseChecked(request_idB, sender_id, "BBB");
+ EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok());
+ ASSERT_EQ(receivers1.size(), 1u);
+ ASSERT_EQ(receivers2.size(), 0u);
+ EXPECT_EQ(receivers1[0].unique_id, "deviceId1");
+
+ // Not enough time has passed for a refresh.
+ clock_.Advance(std::chrono::seconds(30));
+ EXPECT_CALL(peer_client(), OnMessage(_, _)).Times(0);
+ app_discovery_service_.Refresh();
+
+ // Refresh will now query again for unavailable app IDs.
+ clock_.Advance(std::chrono::minutes(2));
+ EXPECT_CALL(peer_client(), OnMessage(_, _))
+ .WillOnce([&request_idB, &sender_id](CastSocket*, CastMessage message) {
+ VerifyAppAvailabilityRequest(message, "BBB", &request_idB, &sender_id);
+ });
+ app_discovery_service_.Refresh();
+}
+
+TEST_F(CastAppDiscoveryServiceImplTest,
+ StartObservingAvailabilityAfterReceiverAdded) {
+ // No registered apps.
+ EXPECT_CALL(peer_client(), OnMessage(_, _)).Times(0);
+ AddOrUpdateReceiver(receiver_, socket_id_);
+
+ // Registering apps immediately sends requests to |receiver_|.
+ int request_idA = -1;
+ std::string sender_id = "";
+ EXPECT_CALL(peer_client(), OnMessage(_, _))
+ .WillOnce([&request_idA, &sender_id](CastSocket*, CastMessage message) {
+ VerifyAppAvailabilityRequest(message, "AAA", &request_idA, &sender_id);
+ });
+ std::vector<ServiceInfo> receivers1;
+ auto subscription1 = StartObservingAvailability(source_a_1_, &receivers1);
+
+ int request_idB = -1;
+ EXPECT_CALL(peer_client(), OnMessage(_, _))
+ .WillOnce([&request_idB, &sender_id](CastSocket*, CastMessage message) {
+ VerifyAppAvailabilityRequest(message, "BBB", &request_idB, &sender_id);
+ });
+ std::vector<ServiceInfo> receivers2;
+ auto subscription2 = StartObservingAvailability(source_b_1_, &receivers2);
+
+ // Add a new receiver with a corresponding socket.
+ FakeCastSocketPair fake_sockets2({{192, 168, 1, 17}, 2345},
+ {{192, 168, 1, 19}, 2345});
+ CastSocket* socket2 = fake_sockets2.socket.get();
+ router_.TakeSocket(&mock_error_handler_, std::move(fake_sockets2.socket));
+ ServiceInfo receiver2;
+ receiver2.unique_id = "deviceId2";
+ receiver2.v4_address = fake_sockets2.remote_endpoint.address;
+ receiver2.port = fake_sockets2.remote_endpoint.port;
+
+ // Adding new receiver causes availability requests for both apps to be sent
+ // to the new receiver.
+ request_idA = -1;
+ request_idB = -1;
+ EXPECT_CALL(fake_sockets2.mock_peer_client, OnMessage(_, _))
+ .Times(2)
+ .WillRepeatedly([&request_idA, &request_idB, &sender_id](
+ CastSocket*, CastMessage message) {
+ std::string app_id;
+ int request_id = -1;
+ VerifyAppAvailabilityRequest(message, &app_id, &request_id, &sender_id);
+ if (app_id == "AAA") {
+ EXPECT_EQ(request_idA, -1);
+ request_idA = request_id;
+ } else if (app_id == "BBB") {
+ EXPECT_EQ(request_idB, -1);
+ request_idB = request_id;
+ } else {
+ EXPECT_TRUE(false);
+ }
+ });
+
+ AddOrUpdateReceiver(receiver2, socket2->socket_id());
+}
+
+TEST_F(CastAppDiscoveryServiceImplTest, StartObservingAvailabilityCachedValue) {
+ std::vector<ServiceInfo> receivers1;
+ int request_id;
+ std::string sender_id;
+ auto subscription1 = StartSourceA1Query(&receivers1, &request_id, &sender_id);
+
+ CastMessage availability_response =
+ CreateAppAvailableResponseChecked(request_id, sender_id, "AAA");
+ EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok());
+ ASSERT_EQ(receivers1.size(), 1u);
+ EXPECT_EQ(receivers1[0].unique_id, "deviceId1");
+
+ // Same app ID should not trigger another request, but it should return
+ // cached value.
+ EXPECT_CALL(peer_client(), OnMessage(_, _)).Times(0);
+ std::vector<ServiceInfo> receivers2;
+ auto subscription2 = StartObservingAvailability(source_a_2_, &receivers2);
+ ASSERT_EQ(receivers2.size(), 1u);
+ EXPECT_EQ(receivers2[0].unique_id, "deviceId1");
+}
+
+TEST_F(CastAppDiscoveryServiceImplTest, AvailabilityUnknownOrUnavailable) {
+ std::vector<ServiceInfo> receivers1;
+ int request_id;
+ std::string sender_id;
+ auto subscription1 = StartSourceA1Query(&receivers1, &request_id, &sender_id);
+
+ // The request will timeout resulting in unknown app availability.
+ clock_.Advance(std::chrono::seconds(10));
+ task_runner_.RunTasksUntilIdle();
+ EXPECT_EQ(receivers1.size(), 0u);
+
+ // Receiver updated together with unknown app availability will cause a
+ // request to be sent again.
+ request_id = -1;
+ EXPECT_CALL(peer_client(), OnMessage(_, _))
+ .WillOnce([&request_id, &sender_id](CastSocket*, CastMessage message) {
+ VerifyAppAvailabilityRequest(message, "AAA", &request_id, &sender_id);
+ });
+ AddOrUpdateReceiver(receiver_, socket_id_);
+
+ CastMessage availability_response =
+ CreateAppUnavailableResponseChecked(request_id, sender_id, "AAA");
+ EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok());
+
+ // Known availability so no request sent.
+ EXPECT_CALL(peer_client(), OnMessage(_, _)).Times(0);
+ AddOrUpdateReceiver(receiver_, socket_id_);
+
+ // Removing the receiver will also remove previous availability information.
+ // Next time the receiver is added, a new request will be sent.
+ platform_client_.RemoveReceiver(receiver_);
+ app_discovery_service_.RemoveReceiver(receiver_);
+
+ request_id = -1;
+ EXPECT_CALL(peer_client(), OnMessage(_, _))
+ .WillOnce([&request_id, &sender_id](CastSocket*, CastMessage message) {
+ VerifyAppAvailabilityRequest(message, "AAA", &request_id, &sender_id);
+ });
+
+ AddOrUpdateReceiver(receiver_, socket_id_);
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_platform_client.cc b/chromium/third_party/openscreen/src/cast/sender/cast_platform_client.cc
new file mode 100644
index 00000000000..6e3d13da4bf
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/sender/cast_platform_client.cc
@@ -0,0 +1,224 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/sender/cast_platform_client.h"
+
+#include <random>
+
+#include "absl/strings/str_cat.h"
+#include "cast/common/channel/cast_socket.h"
+#include "cast/common/channel/virtual_connection_manager.h"
+#include "cast/common/channel/virtual_connection_router.h"
+#include "cast/common/public/service_info.h"
+#include "util/json/json_serialization.h"
+#include "util/logging.h"
+#include "util/stringprintf.h"
+
+namespace openscreen {
+namespace cast {
+
+static constexpr std::chrono::seconds kRequestTimeout = std::chrono::seconds(5);
+
+namespace {
+
+std::string MakeRandomSenderId() {
+ static auto& rd = *new std::random_device();
+ static auto& gen = *new std::mt19937(rd());
+ static auto& dist = *new std::uniform_int_distribution<>(1, 1000000);
+ return absl::StrCat("sender-", dist(gen));
+}
+
+} // namespace
+
+CastPlatformClient::CastPlatformClient(VirtualConnectionRouter* router,
+ VirtualConnectionManager* manager,
+ ClockNowFunctionPtr clock,
+ TaskRunner* task_runner)
+ : sender_id_(MakeRandomSenderId()),
+ virtual_conn_router_(router),
+ virtual_conn_manager_(manager),
+ clock_(clock),
+ task_runner_(task_runner) {
+ OSP_DCHECK(virtual_conn_manager_);
+ OSP_DCHECK(clock_);
+ OSP_DCHECK(task_runner_);
+ virtual_conn_router_->AddHandlerForLocalId(sender_id_, this);
+}
+
+CastPlatformClient::~CastPlatformClient() {
+ virtual_conn_router_->RemoveHandlerForLocalId(sender_id_);
+
+ for (auto& pending_requests : pending_requests_by_device_id_) {
+ for (auto& avail_request : pending_requests.second.availability) {
+ avail_request.callback(avail_request.app_id,
+ AppAvailabilityResult::kUnknown);
+ }
+ }
+}
+
+absl::optional<int> CastPlatformClient::RequestAppAvailability(
+ const std::string& device_id,
+ const std::string& app_id,
+ AppAvailabilityCallback callback) {
+ auto entry = socket_id_by_device_id_.find(device_id);
+ if (entry == socket_id_by_device_id_.end()) {
+ callback(app_id, AppAvailabilityResult::kUnknown);
+ return absl::nullopt;
+ }
+ int socket_id = entry->second;
+
+ int request_id = GetNextRequestId();
+ ErrorOr<::cast::channel::CastMessage> message =
+ CreateAppAvailabilityRequest(sender_id_, request_id, app_id);
+ OSP_DCHECK(message);
+
+ PendingRequests& pending_requests = pending_requests_by_device_id_[device_id];
+ auto timeout = std::make_unique<Alarm>(clock_, task_runner_);
+ timeout->ScheduleFromNow(
+ [this, request_id]() { CancelAppAvailabilityRequest(request_id); },
+ kRequestTimeout);
+ pending_requests.availability.push_back(AvailabilityRequest{
+ request_id, app_id, std::move(timeout), std::move(callback)});
+
+ VirtualConnection virtual_conn{sender_id_, kPlatformReceiverId, socket_id};
+ if (!virtual_conn_manager_->GetConnectionData(virtual_conn)) {
+ virtual_conn_manager_->AddConnection(virtual_conn,
+ VirtualConnection::AssociatedData{});
+ }
+
+ virtual_conn_router_->SendMessage(std::move(virtual_conn),
+ std::move(message.value()));
+
+ return request_id;
+}
+
+void CastPlatformClient::AddOrUpdateReceiver(const ServiceInfo& device,
+ int socket_id) {
+ socket_id_by_device_id_[device.unique_id] = socket_id;
+}
+
+void CastPlatformClient::RemoveReceiver(const ServiceInfo& device) {
+ auto pending_requests_it =
+ pending_requests_by_device_id_.find(device.unique_id);
+ if (pending_requests_it != pending_requests_by_device_id_.end()) {
+ for (const AvailabilityRequest& availability :
+ pending_requests_it->second.availability) {
+ availability.callback(availability.app_id,
+ AppAvailabilityResult::kUnknown);
+ }
+ pending_requests_by_device_id_.erase(pending_requests_it);
+ }
+ socket_id_by_device_id_.erase(device.unique_id);
+}
+
+void CastPlatformClient::CancelRequest(int request_id) {
+ for (auto entry = pending_requests_by_device_id_.begin();
+ entry != pending_requests_by_device_id_.end(); ++entry) {
+ auto& pending_requests = entry->second;
+ auto it = std::find_if(pending_requests.availability.begin(),
+ pending_requests.availability.end(),
+ [request_id](const AvailabilityRequest& request) {
+ return request.request_id == request_id;
+ });
+ if (it != pending_requests.availability.end()) {
+ pending_requests.availability.erase(it);
+ break;
+ }
+ }
+}
+
+void CastPlatformClient::OnMessage(VirtualConnectionRouter* router,
+ CastSocket* socket,
+ ::cast::channel::CastMessage message) {
+ if (message.payload_type() !=
+ ::cast::channel::CastMessage_PayloadType_STRING ||
+ message.namespace_() != kReceiverNamespace ||
+ message.source_id() != kPlatformReceiverId) {
+ return;
+ }
+ ErrorOr<Json::Value> dict_or_error = json::Parse(message.payload_utf8());
+ if (dict_or_error.is_error()) {
+ OSP_DVLOG << "Failed to deserialize CastMessage payload.";
+ return;
+ }
+
+ Json::Value& dict = dict_or_error.value();
+ absl::optional<int> request_id =
+ MaybeGetInt(dict, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyRequestId));
+ if (request_id) {
+ auto entry = std::find_if(
+ socket_id_by_device_id_.begin(), socket_id_by_device_id_.end(),
+ [socket](const std::pair<std::string, int>& entry) {
+ return entry.second == socket->socket_id();
+ });
+ if (entry != socket_id_by_device_id_.end()) {
+ HandleResponse(entry->first, request_id.value(), dict);
+ }
+ }
+}
+
+void CastPlatformClient::HandleResponse(const std::string& device_id,
+ int request_id,
+ const Json::Value& message) {
+ auto entry = pending_requests_by_device_id_.find(device_id);
+ if (entry == pending_requests_by_device_id_.end()) {
+ return;
+ }
+ PendingRequests& pending_requests = entry->second;
+ auto it = std::find_if(pending_requests.availability.begin(),
+ pending_requests.availability.end(),
+ [request_id](const AvailabilityRequest& request) {
+ return request.request_id == request_id;
+ });
+ if (it != pending_requests.availability.end()) {
+ // TODO(btolsch): Can all of this manual parsing/checking be cleaned up into
+ // a single parsing API along with other message handling?
+ const Json::Value* maybe_availability =
+ message.find(JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyAvailability));
+ if (maybe_availability && maybe_availability->isObject()) {
+ absl::optional<absl::string_view> result =
+ MaybeGetString(*maybe_availability, &it->app_id[0],
+ &it->app_id[0] + it->app_id.size());
+ if (result) {
+ AppAvailabilityResult availability_result =
+ AppAvailabilityResult::kUnknown;
+ if (result.value() == kMessageValueAppAvailable) {
+ availability_result = AppAvailabilityResult::kAvailable;
+ } else if (result.value() == kMessageValueAppUnavailable) {
+ availability_result = AppAvailabilityResult::kUnavailable;
+ } else {
+ OSP_DVLOG << "Invalid availability result: " << result.value();
+ }
+ it->callback(it->app_id, availability_result);
+ }
+ }
+ pending_requests.availability.erase(it);
+ }
+}
+
+void CastPlatformClient::CancelAppAvailabilityRequest(int request_id) {
+ for (auto& entry : pending_requests_by_device_id_) {
+ PendingRequests& pending_requests = entry.second;
+ auto it = std::find_if(pending_requests.availability.begin(),
+ pending_requests.availability.end(),
+ [request_id](const AvailabilityRequest& request) {
+ return request.request_id == request_id;
+ });
+ if (it != pending_requests.availability.end()) {
+ it->callback(it->app_id, AppAvailabilityResult::kUnknown);
+ pending_requests.availability.erase(it);
+ }
+ }
+}
+
+// static
+int CastPlatformClient::GetNextRequestId() {
+ return next_request_id_++;
+}
+
+// static
+int CastPlatformClient::next_request_id_ = 0;
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_platform_client.h b/chromium/third_party/openscreen/src/cast/sender/cast_platform_client.h
new file mode 100644
index 00000000000..41ad7fc704b
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/sender/cast_platform_client.h
@@ -0,0 +1,97 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_SENDER_CAST_PLATFORM_CLIENT_H_
+#define CAST_SENDER_CAST_PLATFORM_CLIENT_H_
+
+#include <functional>
+#include <map>
+#include <string>
+
+#include "absl/types/optional.h"
+#include "cast/common/channel/cast_message_handler.h"
+#include "cast/sender/channel/message_util.h"
+#include "util/alarm.h"
+#include "util/json/json_value.h"
+
+namespace openscreen {
+namespace cast {
+
+struct ServiceInfo;
+class VirtualConnectionManager;
+class VirtualConnectionRouter;
+
+// This class handles Cast messages that generally relate to the "platform", in
+// other words not a specific app currently running (e.g. app availability,
+// receiver status). These messages follow a request/response format, so each
+// request requires a corresponding response callback. These requests will also
+// timeout if there is no response after a certain amount of time (currently 5
+// seconds). The timeout callbacks will be called on the thread managed by
+// |task_runner|.
+class CastPlatformClient final : public CastMessageHandler {
+ public:
+ using AppAvailabilityCallback =
+ std::function<void(const std::string& app_id, AppAvailabilityResult)>;
+
+ CastPlatformClient(VirtualConnectionRouter* router,
+ VirtualConnectionManager* manager,
+ ClockNowFunctionPtr clock,
+ TaskRunner* task_runner);
+ ~CastPlatformClient() override;
+
+ // Requests availability information for |app_id| from the receiver identified
+ // by |device_id|. |callback| will be called exactly once with a result.
+ absl::optional<int> RequestAppAvailability(const std::string& device_id,
+ const std::string& app_id,
+ AppAvailabilityCallback callback);
+
+ // Notifies this object about general receiver connectivity or property
+ // changes.
+ void AddOrUpdateReceiver(const ServiceInfo& device, int socket_id);
+ void RemoveReceiver(const ServiceInfo& device);
+
+ void CancelRequest(int request_id);
+
+ private:
+ struct AvailabilityRequest {
+ int request_id;
+ std::string app_id;
+ std::unique_ptr<Alarm> timeout;
+ AppAvailabilityCallback callback;
+ };
+
+ struct PendingRequests {
+ std::vector<AvailabilityRequest> availability;
+ };
+
+ // CastMessageHandler overrides.
+ void OnMessage(VirtualConnectionRouter* router,
+ CastSocket* socket,
+ ::cast::channel::CastMessage message) override;
+
+ void HandleResponse(const std::string& device_id,
+ int request_id,
+ const Json::Value& message);
+
+ void CancelAppAvailabilityRequest(int request_id);
+
+ static int GetNextRequestId();
+
+ static int next_request_id_;
+
+ const std::string sender_id_;
+ VirtualConnectionRouter* const virtual_conn_router_;
+ VirtualConnectionManager* const virtual_conn_manager_;
+ std::map<std::string /* device_id */, int> socket_id_by_device_id_;
+ std::map<std::string /* device_id */, PendingRequests>
+ pending_requests_by_device_id_;
+
+ const ClockNowFunctionPtr clock_;
+ TaskRunner* const task_runner_;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_SENDER_CAST_PLATFORM_CLIENT_H_
diff --git a/chromium/third_party/openscreen/src/cast/sender/cast_platform_client_unittest.cc b/chromium/third_party/openscreen/src/cast/sender/cast_platform_client_unittest.cc
new file mode 100644
index 00000000000..44e99660b06
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/sender/cast_platform_client_unittest.cc
@@ -0,0 +1,110 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/sender/cast_platform_client.h"
+
+#include "cast/common/channel/testing/fake_cast_socket.h"
+#include "cast/common/channel/testing/mock_socket_error_handler.h"
+#include "cast/common/channel/virtual_connection_manager.h"
+#include "cast/common/channel/virtual_connection_router.h"
+#include "cast/common/public/service_info.h"
+#include "cast/sender/testing/test_helpers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "platform/test/fake_clock.h"
+#include "platform/test/fake_task_runner.h"
+#include "util/json/json_serialization.h"
+#include "util/json/json_value.h"
+
+namespace openscreen {
+namespace cast {
+
+using ::cast::channel::CastMessage;
+
+using ::testing::_;
+
+class CastPlatformClientTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ socket_ = fake_cast_socket_pair_.socket.get();
+ router_.TakeSocket(&mock_error_handler_,
+ std::move(fake_cast_socket_pair_.socket));
+
+ receiver_.v4_address = IPAddress{192, 168, 0, 17};
+ receiver_.port = 4434;
+ receiver_.unique_id = "deviceId1";
+ platform_client_.AddOrUpdateReceiver(receiver_, socket_->socket_id());
+ }
+
+ protected:
+ CastSocket& peer_socket() { return *fake_cast_socket_pair_.peer_socket; }
+ MockCastSocketClient& peer_client() {
+ return fake_cast_socket_pair_.mock_peer_client;
+ }
+
+ FakeCastSocketPair fake_cast_socket_pair_;
+ CastSocket* socket_ = nullptr;
+ MockSocketErrorHandler mock_error_handler_;
+ VirtualConnectionManager manager_;
+ VirtualConnectionRouter router_{&manager_};
+ FakeClock clock_{Clock::now()};
+ FakeTaskRunner task_runner_{&clock_};
+ CastPlatformClient platform_client_{&router_, &manager_, &FakeClock::now,
+ &task_runner_};
+ ServiceInfo receiver_;
+};
+
+TEST_F(CastPlatformClientTest, AppAvailability) {
+ int request_id = -1;
+ std::string sender_id;
+ EXPECT_CALL(peer_client(), OnMessage(_, _))
+ .WillOnce([&request_id, &sender_id](CastSocket* socket,
+ CastMessage message) {
+ VerifyAppAvailabilityRequest(message, "AAA", &request_id, &sender_id);
+ });
+ bool ran = false;
+ platform_client_.RequestAppAvailability(
+ "deviceId1", "AAA",
+ [&ran](const std::string& app_id, AppAvailabilityResult availability) {
+ EXPECT_EQ("AAA", app_id);
+ EXPECT_EQ(availability, AppAvailabilityResult::kAvailable);
+ ran = true;
+ });
+
+ CastMessage availability_response =
+ CreateAppAvailableResponseChecked(request_id, sender_id, "AAA");
+ EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok());
+ EXPECT_TRUE(ran);
+
+ // NOTE: Callback should only fire once, so it should not fire again here.
+ ran = false;
+ EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok());
+ EXPECT_FALSE(ran);
+}
+
+TEST_F(CastPlatformClientTest, CancelRequest) {
+ int request_id = -1;
+ std::string sender_id;
+ EXPECT_CALL(peer_client(), OnMessage(_, _))
+ .WillOnce([&request_id, &sender_id](CastSocket* socket,
+ CastMessage message) {
+ VerifyAppAvailabilityRequest(message, "AAA", &request_id, &sender_id);
+ });
+ absl::optional<int> maybe_request_id =
+ platform_client_.RequestAppAvailability(
+ "deviceId1", "AAA",
+ [](const std::string& app_id, AppAvailabilityResult availability) {
+ EXPECT_TRUE(false);
+ });
+ ASSERT_TRUE(maybe_request_id);
+ int local_request_id = maybe_request_id.value();
+ platform_client_.CancelRequest(local_request_id);
+
+ CastMessage availability_response =
+ CreateAppAvailableResponseChecked(request_id, sender_id, "AAA");
+ EXPECT_TRUE(peer_socket().SendMessage(availability_response).ok());
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.cc b/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.cc
index 9f706259d0d..fc103e9bc90 100644
--- a/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.cc
+++ b/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.cc
@@ -6,17 +6,24 @@
#include <openssl/rand.h>
-#include <vector>
+#include <algorithm>
#include "cast/common/certificate/cast_cert_validator.h"
#include "cast/common/certificate/cast_cert_validator_internal.h"
#include "cast/common/certificate/cast_crl.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
#include "platform/api/time.h"
#include "platform/base/error.h"
#include "util/logging.h"
+namespace openscreen {
namespace cast {
-namespace channel {
+
+using ::cast::channel::AuthResponse;
+using ::cast::channel::CastMessage;
+using ::cast::channel::DeviceAuthMessage;
+using ::cast::channel::HashAlgorithm;
+
namespace {
#define PARSE_ERROR_PREFIX "Failed to parse auth message: "
@@ -30,40 +37,35 @@ const int kNonceSizeInBytes = 16;
// The number of hours after which a nonce is regenerated.
long kNonceExpirationTimeInHours = 24;
-using CastCertError = openscreen::Error::Code;
-
// Extracts an embedded DeviceAuthMessage payload from an auth challenge reply
// message.
-openscreen::Error ParseAuthMessage(const CastMessage& challenge_reply,
- DeviceAuthMessage* auth_message) {
- if (challenge_reply.payload_type() != CastMessage_PayloadType_BINARY) {
- return openscreen::Error(CastCertError::kCastV2WrongPayloadType,
- PARSE_ERROR_PREFIX
- "Wrong payload type in challenge reply");
+Error ParseAuthMessage(const CastMessage& challenge_reply,
+ DeviceAuthMessage* auth_message) {
+ if (challenge_reply.payload_type() !=
+ ::cast::channel::CastMessage_PayloadType_BINARY) {
+ return Error(Error::Code::kCastV2WrongPayloadType,
+ PARSE_ERROR_PREFIX "Wrong payload type in challenge reply");
}
if (!challenge_reply.has_payload_binary()) {
- return openscreen::Error(
- CastCertError::kCastV2NoPayload, PARSE_ERROR_PREFIX
- "Payload type is binary but payload_binary field not set");
+ return Error(Error::Code::kCastV2NoPayload, PARSE_ERROR_PREFIX
+ "Payload type is binary but payload_binary field not set");
}
if (!auth_message->ParseFromString(challenge_reply.payload_binary())) {
- return openscreen::Error(
- CastCertError::kCastV2PayloadParsingFailed, PARSE_ERROR_PREFIX
- "Cannot parse binary payload into DeviceAuthMessage");
+ return Error(Error::Code::kCastV2PayloadParsingFailed, PARSE_ERROR_PREFIX
+ "Cannot parse binary payload into DeviceAuthMessage");
}
if (auth_message->has_error()) {
std::stringstream ss;
ss << PARSE_ERROR_PREFIX "Auth message error: "
<< auth_message->error().error_type();
- return openscreen::Error(CastCertError::kCastV2MessageError, ss.str());
+ return Error(Error::Code::kCastV2MessageError, ss.str());
}
if (!auth_message->has_response()) {
- return openscreen::Error(CastCertError::kCastV2NoResponse,
- PARSE_ERROR_PREFIX
- "Auth message has no response field");
+ return Error(Error::Code::kCastV2NoResponse,
+ PARSE_ERROR_PREFIX "Auth message has no response field");
}
- return openscreen::Error::None();
+ return Error::None();
}
class CastNonce {
@@ -84,11 +86,11 @@ class CastNonce {
OSP_CHECK_EQ(
RAND_bytes(reinterpret_cast<uint8_t*>(&nonce_[0]), kNonceSizeInBytes),
1);
- nonce_generation_time_ = openscreen::platform::GetWallTimeSinceUnixEpoch();
+ nonce_generation_time_ = GetWallTimeSinceUnixEpoch();
}
void EnsureNonceTimely() {
- if (openscreen::platform::GetWallTimeSinceUnixEpoch() >
+ if (GetWallTimeSinceUnixEpoch() >
(nonce_generation_time_ +
std::chrono::hours(kNonceExpirationTimeInHours))) {
GenerateNonce();
@@ -101,62 +103,60 @@ class CastNonce {
std::chrono::seconds nonce_generation_time_;
};
-// Maps CastCertError from certificate verification to openscreen::Error.
+// Maps Error::Code from certificate verification to Error.
// If crl_required is set to false, all revocation related errors are ignored.
-openscreen::Error MapToOpenscreenError(CastCertError error, bool crl_required) {
+Error MapToOpenscreenError(Error::Code error, bool crl_required) {
switch (error) {
- case CastCertError::kErrCertsMissing:
- return openscreen::Error(CastCertError::kCastV2PeerCertEmpty,
- "Failed to locate certificates.");
- case CastCertError::kErrCertsParse:
- return openscreen::Error(CastCertError::kErrCertsParse,
- "Failed to parse certificates.");
- case CastCertError::kErrCertsDateInvalid:
- return openscreen::Error(CastCertError::kCastV2CertNotSignedByTrustedCa,
- "Failed date validity check.");
- case CastCertError::kErrCertsVerifyGeneric:
- return openscreen::Error(
- CastCertError::kCastV2CertNotSignedByTrustedCa,
- "Failed with a generic certificate verification error.");
- case CastCertError::kErrCertsRestrictions:
- return openscreen::Error(CastCertError::kCastV2CertNotSignedByTrustedCa,
- "Failed certificate restrictions.");
- case CastCertError::kErrCrlInvalid:
+ case Error::Code::kErrCertsMissing:
+ return Error(Error::Code::kCastV2PeerCertEmpty,
+ "Failed to locate certificates.");
+ case Error::Code::kErrCertsParse:
+ return Error(Error::Code::kErrCertsParse,
+ "Failed to parse certificates.");
+ case Error::Code::kErrCertsDateInvalid:
+ return Error(Error::Code::kCastV2CertNotSignedByTrustedCa,
+ "Failed date validity check.");
+ case Error::Code::kErrCertsVerifyGeneric:
+ return Error(Error::Code::kCastV2CertNotSignedByTrustedCa,
+ "Failed with a generic certificate verification error.");
+ case Error::Code::kErrCertsRestrictions:
+ return Error(Error::Code::kCastV2CertNotSignedByTrustedCa,
+ "Failed certificate restrictions.");
+ case Error::Code::kErrCrlInvalid:
// This error is only encountered if |crl_required| is true.
OSP_DCHECK(crl_required);
- return openscreen::Error(CastCertError::kErrCrlInvalid,
- "Failed to provide a valid CRL.");
- case CastCertError::kErrCertsRevoked:
- return openscreen::Error(CastCertError::kErrCertsRevoked,
- "Failed certificate revocation check.");
- case CastCertError::kNone:
- return openscreen::Error::None();
+ return Error(Error::Code::kErrCrlInvalid,
+ "Failed to provide a valid CRL.");
+ case Error::Code::kErrCertsRevoked:
+ return Error(Error::Code::kErrCertsRevoked,
+ "Failed certificate revocation check.");
+ case Error::Code::kNone:
+ return Error::None();
default:
- return openscreen::Error(CastCertError::kCastV2CertNotSignedByTrustedCa,
- "Failed verifying cast device certificate.");
+ return Error(Error::Code::kCastV2CertNotSignedByTrustedCa,
+ "Failed verifying cast device certificate.");
}
- return openscreen::Error::None();
+ return Error::None();
}
-openscreen::Error VerifyAndMapDigestAlgorithm(
- HashAlgorithm response_digest_algorithm,
- certificate::DigestAlgorithm* digest_algorithm,
- bool enforce_sha256_checking) {
+Error VerifyAndMapDigestAlgorithm(HashAlgorithm response_digest_algorithm,
+ DigestAlgorithm* digest_algorithm,
+ bool enforce_sha256_checking) {
switch (response_digest_algorithm) {
- case SHA1:
+ case ::cast::channel::SHA1:
if (enforce_sha256_checking) {
- return openscreen::Error(CastCertError::kCastV2DigestUnsupported,
- "Unsupported digest algorithm.");
+ return Error(Error::Code::kCastV2DigestUnsupported,
+ "Unsupported digest algorithm.");
}
- *digest_algorithm = certificate::DigestAlgorithm::kSha1;
+ *digest_algorithm = DigestAlgorithm::kSha1;
break;
- case SHA256:
- *digest_algorithm = certificate::DigestAlgorithm::kSha256;
+ case ::cast::channel::SHA256:
+ *digest_algorithm = DigestAlgorithm::kSha256;
break;
default:
- return CastCertError::kCastV2DigestUnsupported;
+ return Error::Code::kCastV2DigestUnsupported;
}
- return openscreen::Error::None();
+ return Error::None();
}
} // namespace
@@ -170,70 +170,80 @@ AuthContext::AuthContext(const std::string& nonce) : nonce_(nonce) {}
AuthContext::~AuthContext() {}
-openscreen::Error AuthContext::VerifySenderNonce(
- const std::string& nonce_response,
- bool enforce_nonce_checking) const {
+Error AuthContext::VerifySenderNonce(const std::string& nonce_response,
+ bool enforce_nonce_checking) const {
if (nonce_ != nonce_response) {
if (enforce_nonce_checking) {
- return openscreen::Error(CastCertError::kCastV2SenderNonceMismatch,
- "Sender nonce mismatched.");
+ return Error(Error::Code::kCastV2SenderNonceMismatch,
+ "Sender nonce mismatched.");
}
}
- return openscreen::Error::None();
+ return Error::None();
}
-openscreen::Error VerifyTLSCertificateValidity(
- X509* peer_cert,
- std::chrono::seconds verification_time) {
+Error VerifyTLSCertificateValidity(X509* peer_cert,
+ std::chrono::seconds verification_time) {
// Ensure the peer cert is valid and doesn't have an excessive remaining
// lifetime. Although it is not verified as an X.509 certificate, the entire
// structure is signed by the AuthResponse, so the validity field from X.509
// is repurposed as this signature's expiration.
- certificate::DateTime not_before;
- certificate::DateTime not_after;
- if (!certificate::GetCertValidTimeRange(peer_cert, &not_before, &not_after)) {
- return openscreen::Error(CastCertError::kErrCertsParse, PARSE_ERROR_PREFIX
- "Parsing validity fields failed.");
+ DateTime not_before;
+ DateTime not_after;
+ if (!GetCertValidTimeRange(peer_cert, &not_before, &not_after)) {
+ return Error(Error::Code::kErrCertsParse,
+ PARSE_ERROR_PREFIX "Parsing validity fields failed.");
}
std::chrono::seconds lifetime_limit =
verification_time +
std::chrono::hours(24 * kMaxSelfSignedCertLifetimeInDays);
- certificate::DateTime verification_time_exploded = {};
- certificate::DateTime lifetime_limit_exploded = {};
- OSP_CHECK(certificate::DateTimeFromSeconds(verification_time.count(),
- &verification_time_exploded));
- OSP_CHECK(certificate::DateTimeFromSeconds(lifetime_limit.count(),
- &lifetime_limit_exploded));
+ DateTime verification_time_exploded = {};
+ DateTime lifetime_limit_exploded = {};
+ OSP_CHECK(DateTimeFromSeconds(verification_time.count(),
+ &verification_time_exploded));
+ OSP_CHECK(
+ DateTimeFromSeconds(lifetime_limit.count(), &lifetime_limit_exploded));
if (verification_time_exploded < not_before) {
- return openscreen::Error(
- CastCertError::kCastV2TlsCertValidStartDateInFuture,
- PARSE_ERROR_PREFIX "Certificate's valid start date is in the future.");
+ return Error(Error::Code::kCastV2TlsCertValidStartDateInFuture,
+ PARSE_ERROR_PREFIX
+ "Certificate's valid start date is in the future.");
}
if (not_after < verification_time_exploded) {
- return openscreen::Error(CastCertError::kCastV2TlsCertExpired,
- PARSE_ERROR_PREFIX "Certificate has expired.");
+ return Error(Error::Code::kCastV2TlsCertExpired,
+ PARSE_ERROR_PREFIX "Certificate has expired.");
}
if (lifetime_limit_exploded < not_after) {
- return openscreen::Error(CastCertError::kCastV2TlsCertValidityPeriodTooLong,
- PARSE_ERROR_PREFIX
- "Peer cert lifetime is too long.");
+ return Error(Error::Code::kCastV2TlsCertValidityPeriodTooLong,
+ PARSE_ERROR_PREFIX "Peer cert lifetime is too long.");
}
- return openscreen::Error::None();
+ return Error::None();
}
-ErrorOr<CastDeviceCertPolicy> AuthenticateChallengeReply(
+ErrorOr<CastDeviceCertPolicy> VerifyCredentialsImpl(
+ const AuthResponse& response,
+ const std::vector<uint8_t>& signature_input,
+ const CRLPolicy& crl_policy,
+ TrustStore* cast_trust_store,
+ TrustStore* crl_trust_store,
+ const DateTime& verification_time,
+ bool enforce_sha256_checking);
+
+ErrorOr<CastDeviceCertPolicy> AuthenticateChallengeReplyImpl(
const CastMessage& challenge_reply,
X509* peer_cert,
- const AuthContext& auth_context) {
+ const AuthContext& auth_context,
+ const CRLPolicy& crl_policy,
+ TrustStore* cast_trust_store,
+ TrustStore* crl_trust_store,
+ const DateTime& verification_time) {
DeviceAuthMessage auth_message;
- openscreen::Error result = ParseAuthMessage(challenge_reply, &auth_message);
+ Error result = ParseAuthMessage(challenge_reply, &auth_message);
if (!result.ok()) {
return result;
}
- result = VerifyTLSCertificateValidity(
- peer_cert, openscreen::platform::GetWallTimeSinceUnixEpoch());
+ result = VerifyTLSCertificateValidity(peer_cert,
+ DateTimeToSeconds(verification_time));
if (!result.ok()) {
return result;
}
@@ -246,22 +256,48 @@ ErrorOr<CastDeviceCertPolicy> AuthenticateChallengeReply(
return result;
}
- int len = i2d_X509(peer_cert, nullptr);
- if (len <= 0) {
- return openscreen::Error(CastCertError::kErrCertsParse,
- "Serializing cert failed.");
+ int cert_len = i2d_X509(peer_cert, nullptr);
+ if (cert_len <= 0) {
+ return Error(Error::Code::kErrCertsParse, "Serializing cert failed.");
}
- std::string peer_cert_der(len, 0);
- uint8_t* data = reinterpret_cast<uint8_t*>(&peer_cert_der[0]);
+ size_t nonce_response_size = nonce_response.size();
+ std::vector<uint8_t> nonce_plus_peer_cert_der(nonce_response_size + cert_len,
+ 0);
+ std::copy(nonce_response.begin(), nonce_response.end(),
+ &nonce_plus_peer_cert_der[0]);
+ uint8_t* data = &nonce_plus_peer_cert_der[nonce_response_size];
if (!i2d_X509(peer_cert, &data)) {
- return openscreen::Error(CastCertError::kErrCertsParse,
- "Serializing cert failed.");
+ return Error(Error::Code::kErrCertsParse, "Serializing cert failed.");
}
- size_t actual_size = data - reinterpret_cast<uint8_t*>(&peer_cert_der[0]);
- OSP_DCHECK_EQ(actual_size, peer_cert_der.size());
- peer_cert_der.resize(actual_size);
- return VerifyCredentials(response, nonce_response + peer_cert_der);
+ return VerifyCredentialsImpl(response, nonce_plus_peer_cert_der, crl_policy,
+ cast_trust_store, crl_trust_store,
+ verification_time, false);
+}
+
+ErrorOr<CastDeviceCertPolicy> AuthenticateChallengeReply(
+ const CastMessage& challenge_reply,
+ X509* peer_cert,
+ const AuthContext& auth_context) {
+ DateTime now = {};
+ OSP_CHECK(DateTimeFromSeconds(GetWallTimeSinceUnixEpoch().count(), &now));
+ CRLPolicy policy = CRLPolicy::kCrlOptional;
+ return AuthenticateChallengeReplyImpl(
+ challenge_reply, peer_cert, auth_context, policy,
+ /* cast_trust_store */ nullptr, /* crl_trust_store */ nullptr, now);
+}
+
+ErrorOr<CastDeviceCertPolicy> AuthenticateChallengeReplyForTest(
+ const CastMessage& challenge_reply,
+ X509* peer_cert,
+ const AuthContext& auth_context,
+ CRLPolicy crl_policy,
+ TrustStore* cast_trust_store,
+ TrustStore* crl_trust_store,
+ const DateTime& verification_time) {
+ return AuthenticateChallengeReplyImpl(
+ challenge_reply, peer_cert, auth_context, crl_policy, cast_trust_store,
+ crl_trust_store, verification_time);
}
// This function does the following
@@ -283,19 +319,18 @@ ErrorOr<CastDeviceCertPolicy> AuthenticateChallengeReply(
// |signature_input| by |response.client_auth_certificate|'s public key.
ErrorOr<CastDeviceCertPolicy> VerifyCredentialsImpl(
const AuthResponse& response,
- const std::string& signature_input,
- const certificate::CRLPolicy& crl_policy,
- certificate::TrustStore* cast_trust_store,
- certificate::TrustStore* crl_trust_store,
- const certificate::DateTime& verification_time,
+ const std::vector<uint8_t>& signature_input,
+ const CRLPolicy& crl_policy,
+ TrustStore* cast_trust_store,
+ TrustStore* crl_trust_store,
+ const DateTime& verification_time,
bool enforce_sha256_checking) {
if (response.signature().empty() && !signature_input.empty()) {
- return openscreen::Error(CastCertError::kCastV2SignatureEmpty,
- "Signature is empty.");
+ return Error(Error::Code::kCastV2SignatureEmpty, "Signature is empty.");
}
// Verify the certificate
- std::unique_ptr<certificate::CertVerificationContext> verification_context;
+ std::unique_ptr<CertVerificationContext> verification_context;
// Build a single vector containing the certificate chain.
std::vector<std::string> cert_chain;
@@ -305,43 +340,41 @@ ErrorOr<CastDeviceCertPolicy> VerifyCredentialsImpl(
response.intermediate_certificate().end());
// Parse the CRL.
- std::unique_ptr<certificate::CastCRL> crl;
+ std::unique_ptr<CastCRL> crl;
if (!response.crl().empty()) {
- crl = certificate::ParseAndVerifyCRL(response.crl(), verification_time,
- crl_trust_store);
+ crl = ParseAndVerifyCRL(response.crl(), verification_time, crl_trust_store);
}
// Perform certificate verification.
- certificate::CastDeviceCertPolicy device_policy;
- openscreen::Error verify_result = certificate::VerifyDeviceCert(
- cert_chain, verification_time, &verification_context, &device_policy,
- crl.get(), crl_policy, cast_trust_store);
+ CastDeviceCertPolicy device_policy;
+ Error verify_result =
+ VerifyDeviceCert(cert_chain, verification_time, &verification_context,
+ &device_policy, crl.get(), crl_policy, cast_trust_store);
// Handle and report errors.
- openscreen::Error result = MapToOpenscreenError(
- verify_result.code(), crl_policy == certificate::CRLPolicy::kCrlRequired);
+ Error result = MapToOpenscreenError(verify_result.code(),
+ crl_policy == CRLPolicy::kCrlRequired);
if (!result.ok()) {
return result;
}
// The certificate is verified at this point.
- certificate::DigestAlgorithm digest_algorithm;
- openscreen::Error digest_result = VerifyAndMapDigestAlgorithm(
+ DigestAlgorithm digest_algorithm;
+ Error digest_result = VerifyAndMapDigestAlgorithm(
response.hash_algorithm(), &digest_algorithm, enforce_sha256_checking);
if (!digest_result.ok()) {
return digest_result;
}
- certificate::ConstDataSpan signature = {
+ ConstDataSpan signature = {
reinterpret_cast<const uint8_t*>(response.signature().data()),
static_cast<uint32_t>(response.signature().size())};
- certificate::ConstDataSpan siginput = {
- reinterpret_cast<const uint8_t*>(signature_input.data()),
- static_cast<uint32_t>(signature_input.size())};
+ ConstDataSpan siginput = {signature_input.data(),
+ static_cast<uint32_t>(signature_input.size())};
if (!verification_context->VerifySignatureOverData(signature, siginput,
digest_algorithm)) {
- return openscreen::Error(CastCertError::kCastV2SignedBlobsMismatch,
- "Failed verifying signature over data.");
+ return Error(Error::Code::kCastV2SignedBlobsMismatch,
+ "Failed verifying signature over data.");
}
return device_policy;
@@ -349,31 +382,29 @@ ErrorOr<CastDeviceCertPolicy> VerifyCredentialsImpl(
ErrorOr<CastDeviceCertPolicy> VerifyCredentials(
const AuthResponse& response,
- const std::string& signature_input,
+ const std::vector<uint8_t>& signature_input,
bool enforce_revocation_checking,
bool enforce_sha256_checking) {
- certificate::DateTime now = {};
- OSP_CHECK(certificate::DateTimeFromSeconds(
- openscreen::platform::GetWallTimeSinceUnixEpoch().count(), &now));
- certificate::CRLPolicy policy = (enforce_revocation_checking)
- ? certificate::CRLPolicy::kCrlRequired
- : certificate::CRLPolicy::kCrlOptional;
+ DateTime now = {};
+ OSP_CHECK(DateTimeFromSeconds(GetWallTimeSinceUnixEpoch().count(), &now));
+ CRLPolicy policy = (enforce_revocation_checking) ? CRLPolicy::kCrlRequired
+ : CRLPolicy::kCrlOptional;
return VerifyCredentialsImpl(response, signature_input, policy, nullptr,
nullptr, now, enforce_sha256_checking);
}
ErrorOr<CastDeviceCertPolicy> VerifyCredentialsForTest(
const AuthResponse& response,
- const std::string& signature_input,
- certificate::CRLPolicy crl_policy,
- certificate::TrustStore* cast_trust_store,
- certificate::TrustStore* crl_trust_store,
- const certificate::DateTime& verification_time,
+ const std::vector<uint8_t>& signature_input,
+ CRLPolicy crl_policy,
+ TrustStore* cast_trust_store,
+ TrustStore* crl_trust_store,
+ const DateTime& verification_time,
bool enforce_sha256_checking) {
return VerifyCredentialsImpl(response, signature_input, crl_policy,
cast_trust_store, crl_trust_store,
verification_time, enforce_sha256_checking);
}
-} // namespace channel
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.h b/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.h
index 35f1d028f5f..4ca6df5e17c 100644
--- a/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.h
+++ b/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.h
@@ -7,28 +7,26 @@
#include <openssl/x509.h>
+#include <chrono> // NOLINT
#include <string>
+#include <vector>
#include "cast/common/certificate/cast_cert_validator.h"
-#include "cast/common/channel/proto/cast_channel.pb.h"
#include "platform/base/error.h"
namespace cast {
-namespace certificate {
-enum class CRLPolicy;
-struct DateTime;
-struct TrustStore;
-} // namespace certificate
-} // namespace cast
-
-namespace cast {
namespace channel {
-
class AuthResponse;
class CastMessage;
+} // namespace channel
+} // namespace cast
+
+namespace openscreen {
+namespace cast {
-using openscreen::ErrorOr;
-using CastDeviceCertPolicy = certificate::CastDeviceCertPolicy;
+enum class CRLPolicy;
+struct DateTime;
+struct TrustStore;
class AuthContext {
public:
@@ -40,9 +38,8 @@ class AuthContext {
// Verifies the nonce received in the response is equivalent to the one sent.
// Returns success if |nonce_response| matches nonce_
- openscreen::Error VerifySenderNonce(
- const std::string& nonce_response,
- bool enforce_nonce_checking = false) const;
+ Error VerifySenderNonce(const std::string& nonce_response,
+ bool enforce_nonce_checking = false) const;
// The nonce challenge.
const std::string& nonce() const { return nonce_; }
@@ -55,40 +52,52 @@ class AuthContext {
// Authenticates the given |challenge_reply|:
// 1. Signature contained in the reply is valid.
-// 2. Certficate used to sign is rooted to a trusted CA.
+// 2. certificate used to sign is rooted to a trusted CA.
ErrorOr<CastDeviceCertPolicy> AuthenticateChallengeReply(
- const CastMessage& challenge_reply,
+ const ::cast::channel::CastMessage& challenge_reply,
X509* peer_cert,
const AuthContext& auth_context);
-// Performs a quick check of the TLS certificate for time validity requirements.
-openscreen::Error VerifyTLSCertificateValidity(
+// Exposed for testing only.
+//
+// Overloaded version of AuthenticateChallengeReply that allows modifying the
+// crl policy, trust stores, and verification times.
+ErrorOr<CastDeviceCertPolicy> AuthenticateChallengeReplyForTest(
+ const ::cast::channel::CastMessage& challenge_reply,
X509* peer_cert,
- std::chrono::seconds verification_time);
+ const AuthContext& auth_context,
+ CRLPolicy crl_policy,
+ TrustStore* cast_trust_store,
+ TrustStore* crl_trust_store,
+ const DateTime& verification_time);
+
+// Performs a quick check of the TLS certificate for time validity requirements.
+Error VerifyTLSCertificateValidity(X509* peer_cert,
+ std::chrono::seconds verification_time);
// Auth-library specific implementation of cryptographic signature verification
// routines. Verifies that |response| contains a valid signature of
// |signature_input|.
ErrorOr<CastDeviceCertPolicy> VerifyCredentials(
- const AuthResponse& response,
- const std::string& signature_input,
+ const ::cast::channel::AuthResponse& response,
+ const std::vector<uint8_t>& signature_input,
bool enforce_revocation_checking = false,
bool enforce_sha256_checking = false);
// Exposed for testing only.
//
-// Overloaded version of VerifyCredentials that allows modifying
-// the crl policy, trust stores, and verification times.
+// Overloaded version of VerifyCredentials that allows modifying the crl policy,
+// trust stores, and verification times.
ErrorOr<CastDeviceCertPolicy> VerifyCredentialsForTest(
- const AuthResponse& response,
- const std::string& signature_input,
- certificate::CRLPolicy crl_policy,
- certificate::TrustStore* cast_trust_store,
- certificate::TrustStore* crl_trust_store,
- const certificate::DateTime& verification_time,
+ const ::cast::channel::AuthResponse& response,
+ const std::vector<uint8_t>& signature_input,
+ CRLPolicy crl_policy,
+ TrustStore* cast_trust_store,
+ TrustStore* crl_trust_store,
+ const DateTime& verification_time,
bool enforce_sha256_checking = false);
-} // namespace channel
} // namespace cast
+} // namespace openscreen
#endif // CAST_SENDER_CHANNEL_CAST_AUTH_UTIL_H_
diff --git a/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util_unittest.cc b/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util_unittest.cc
index 6ea4ea6a141..0b76c18ba3d 100644
--- a/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util_unittest.cc
@@ -9,19 +9,27 @@
#include "cast/common/certificate/cast_cert_validator.h"
#include "cast/common/certificate/cast_crl.h"
#include "cast/common/certificate/proto/test_suite.pb.h"
-#include "cast/common/certificate/test_helpers.h"
+#include "cast/common/certificate/testing/test_helpers.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
#include "gtest/gtest.h"
#include "platform/api/time.h"
+#include "testing/util/read_file.h"
#include "util/logging.h"
+namespace openscreen {
namespace cast {
-namespace channel {
+
+// TODO(crbug.com/openscreen/90): Remove these after Chromium is migrated to
+// openscreen::cast
+using DeviceCertTestSuite = ::cast::certificate::DeviceCertTestSuite;
+using VerificationResult = ::cast::certificate::VerificationResult;
+using DeviceCertTest = ::cast::certificate::DeviceCertTest;
+
namespace {
-using ErrorCode = openscreen::Error::Code;
+using ::cast::channel::AuthResponse;
-bool ConvertTimeSeconds(const certificate::DateTime& time, uint64_t* seconds) {
+bool ConvertTimeSeconds(const DateTime& time, uint64_t* seconds) {
static constexpr uint64_t kDaysPerYear = 365;
static constexpr uint64_t kHoursPerDay = 24;
static constexpr uint64_t kMinutesPerHour = 60;
@@ -101,7 +109,7 @@ bool ConvertTimeSeconds(const certificate::DateTime& time, uint64_t* seconds) {
#define TEST_DATA_PREFIX OPENSCREEN_TEST_DATA_DIR "cast/common/certificate/"
-class CastAuthUtilTest : public testing::Test {
+class CastAuthUtilTest : public ::testing::Test {
public:
CastAuthUtilTest() {}
~CastAuthUtilTest() override {}
@@ -109,141 +117,151 @@ class CastAuthUtilTest : public testing::Test {
void SetUp() override {}
protected:
- static AuthResponse CreateAuthResponse(std::string* signed_data,
- HashAlgorithm digest_algorithm) {
- std::vector<std::string> chain =
- certificate::testing::ReadCertificatesFromPemFile(
- TEST_DATA_PREFIX "certificates/chromecast_gen1.pem");
+ static AuthResponse CreateAuthResponse(
+ std::vector<uint8_t>* signed_data,
+ ::cast::channel::HashAlgorithm digest_algorithm) {
+ std::vector<std::string> chain = testing::ReadCertificatesFromPemFile(
+ TEST_DATA_PREFIX "certificates/chromecast_gen1.pem");
OSP_CHECK(!chain.empty());
- certificate::testing::SignatureTestData signatures =
- certificate::testing::ReadSignatureTestData(
- TEST_DATA_PREFIX "signeddata/2ZZBG9_FA8FCA3EF91A.pem");
+ testing::SignatureTestData signatures = testing::ReadSignatureTestData(
+ TEST_DATA_PREFIX "signeddata/2ZZBG9_FA8FCA3EF91A.pem");
AuthResponse response;
response.set_client_auth_certificate(chain[0]);
- for (size_t i = 1; i < chain.size(); ++i)
+ for (size_t i = 1; i < chain.size(); ++i) {
response.add_intermediate_certificate(chain[i]);
+ }
response.set_hash_algorithm(digest_algorithm);
switch (digest_algorithm) {
- case SHA1:
+ case ::cast::channel::SHA1:
response.set_signature(
std::string(reinterpret_cast<const char*>(signatures.sha1.data),
signatures.sha1.length));
break;
- case SHA256:
+ case ::cast::channel::SHA256:
response.set_signature(
std::string(reinterpret_cast<const char*>(signatures.sha256.data),
signatures.sha256.length));
break;
}
- signed_data->assign(reinterpret_cast<const char*>(signatures.message.data),
- signatures.message.length);
+ *signed_data = std::vector<uint8_t>(
+ signatures.message.data,
+ signatures.message.data + signatures.message.length);
return response;
}
// Mangles a string by inverting the first byte.
static void MangleString(std::string* str) { (*str)[0] = ~(*str)[0]; }
+
+ // Mangles a vector by inverting the first byte.
+ static void MangleData(std::vector<uint8_t>* data) {
+ (*data)[0] = ~(*data)[0];
+ }
};
// Note on expiration: VerifyCredentials() depends on the system clock. In
// practice this shouldn't be a problem though since the certificate chain
// being verified doesn't expire until 2032.
TEST_F(CastAuthUtilTest, VerifySuccess) {
- std::string signed_data;
- AuthResponse auth_response = CreateAuthResponse(&signed_data, SHA256);
- certificate::DateTime now = {};
- ASSERT_TRUE(certificate::DateTimeFromSeconds(
- openscreen::platform::GetWallTimeSinceUnixEpoch().count(), &now));
- ErrorOr<CastDeviceCertPolicy> result = VerifyCredentialsForTest(
- auth_response, signed_data, certificate::CRLPolicy::kCrlOptional, nullptr,
- nullptr, now);
+ std::vector<uint8_t> signed_data;
+ AuthResponse auth_response =
+ CreateAuthResponse(&signed_data, ::cast::channel::SHA256);
+ DateTime now = {};
+ ASSERT_TRUE(DateTimeFromSeconds(GetWallTimeSinceUnixEpoch().count(), &now));
+ ErrorOr<CastDeviceCertPolicy> result =
+ VerifyCredentialsForTest(auth_response, signed_data,
+ CRLPolicy::kCrlOptional, nullptr, nullptr, now);
EXPECT_TRUE(result);
- EXPECT_EQ(certificate::CastDeviceCertPolicy::kUnrestricted, result.value());
+ EXPECT_EQ(CastDeviceCertPolicy::kUnrestricted, result.value());
}
TEST_F(CastAuthUtilTest, VerifyBadCA) {
- std::string signed_data;
- AuthResponse auth_response = CreateAuthResponse(&signed_data, SHA256);
+ std::vector<uint8_t> signed_data;
+ AuthResponse auth_response =
+ CreateAuthResponse(&signed_data, ::cast::channel::SHA256);
MangleString(auth_response.mutable_intermediate_certificate(0));
ErrorOr<CastDeviceCertPolicy> result =
VerifyCredentials(auth_response, signed_data);
EXPECT_FALSE(result);
- EXPECT_EQ(ErrorCode::kErrCertsParse, result.error().code());
+ EXPECT_EQ(Error::Code::kErrCertsParse, result.error().code());
}
TEST_F(CastAuthUtilTest, VerifyBadClientAuthCert) {
- std::string signed_data;
- AuthResponse auth_response = CreateAuthResponse(&signed_data, SHA256);
+ std::vector<uint8_t> signed_data;
+ AuthResponse auth_response =
+ CreateAuthResponse(&signed_data, ::cast::channel::SHA256);
MangleString(auth_response.mutable_client_auth_certificate());
ErrorOr<CastDeviceCertPolicy> result =
VerifyCredentials(auth_response, signed_data);
EXPECT_FALSE(result);
- EXPECT_EQ(ErrorCode::kErrCertsParse, result.error().code());
+ EXPECT_EQ(Error::Code::kErrCertsParse, result.error().code());
}
TEST_F(CastAuthUtilTest, VerifyBadSignature) {
- std::string signed_data;
- AuthResponse auth_response = CreateAuthResponse(&signed_data, SHA256);
+ std::vector<uint8_t> signed_data;
+ AuthResponse auth_response =
+ CreateAuthResponse(&signed_data, ::cast::channel::SHA256);
MangleString(auth_response.mutable_signature());
ErrorOr<CastDeviceCertPolicy> result =
VerifyCredentials(auth_response, signed_data);
EXPECT_FALSE(result);
- EXPECT_EQ(ErrorCode::kCastV2SignedBlobsMismatch, result.error().code());
+ EXPECT_EQ(Error::Code::kCastV2SignedBlobsMismatch, result.error().code());
}
TEST_F(CastAuthUtilTest, VerifyEmptySignature) {
- std::string signed_data;
- AuthResponse auth_response = CreateAuthResponse(&signed_data, SHA256);
+ std::vector<uint8_t> signed_data;
+ AuthResponse auth_response =
+ CreateAuthResponse(&signed_data, ::cast::channel::SHA256);
auth_response.mutable_signature()->clear();
ErrorOr<CastDeviceCertPolicy> result =
VerifyCredentials(auth_response, signed_data);
EXPECT_FALSE(result);
- EXPECT_EQ(ErrorCode::kCastV2SignatureEmpty, result.error().code());
+ EXPECT_EQ(Error::Code::kCastV2SignatureEmpty, result.error().code());
}
TEST_F(CastAuthUtilTest, VerifyUnsupportedDigest) {
- std::string signed_data;
- AuthResponse auth_response = CreateAuthResponse(&signed_data, SHA1);
- certificate::DateTime now = {};
- ASSERT_TRUE(certificate::DateTimeFromSeconds(
- openscreen::platform::GetWallTimeSinceUnixEpoch().count(), &now));
+ std::vector<uint8_t> signed_data;
+ AuthResponse auth_response =
+ CreateAuthResponse(&signed_data, ::cast::channel::SHA1);
+ DateTime now = {};
+ ASSERT_TRUE(DateTimeFromSeconds(GetWallTimeSinceUnixEpoch().count(), &now));
ErrorOr<CastDeviceCertPolicy> result = VerifyCredentialsForTest(
- auth_response, signed_data, certificate::CRLPolicy::kCrlOptional, nullptr,
- nullptr, now, true);
+ auth_response, signed_data, CRLPolicy::kCrlOptional, nullptr, nullptr,
+ now, true);
EXPECT_FALSE(result);
- EXPECT_EQ(ErrorCode::kCastV2DigestUnsupported, result.error().code());
+ EXPECT_EQ(Error::Code::kCastV2DigestUnsupported, result.error().code());
}
TEST_F(CastAuthUtilTest, VerifyBackwardsCompatibleDigest) {
- std::string signed_data;
- AuthResponse auth_response = CreateAuthResponse(&signed_data, SHA1);
- certificate::DateTime now = {};
- ASSERT_TRUE(certificate::DateTimeFromSeconds(
- openscreen::platform::GetWallTimeSinceUnixEpoch().count(), &now));
- ErrorOr<CastDeviceCertPolicy> result = VerifyCredentialsForTest(
- auth_response, signed_data, certificate::CRLPolicy::kCrlOptional, nullptr,
- nullptr, now);
+ std::vector<uint8_t> signed_data;
+ AuthResponse auth_response =
+ CreateAuthResponse(&signed_data, ::cast::channel::SHA1);
+ DateTime now = {};
+ ASSERT_TRUE(DateTimeFromSeconds(GetWallTimeSinceUnixEpoch().count(), &now));
+ ErrorOr<CastDeviceCertPolicy> result =
+ VerifyCredentialsForTest(auth_response, signed_data,
+ CRLPolicy::kCrlOptional, nullptr, nullptr, now);
EXPECT_TRUE(result);
}
TEST_F(CastAuthUtilTest, VerifyBadPeerCert) {
- std::string signed_data;
- AuthResponse auth_response = CreateAuthResponse(&signed_data, SHA256);
- MangleString(&signed_data);
+ std::vector<uint8_t> signed_data;
+ AuthResponse auth_response =
+ CreateAuthResponse(&signed_data, ::cast::channel::SHA256);
+ MangleData(&signed_data);
ErrorOr<CastDeviceCertPolicy> result =
VerifyCredentials(auth_response, signed_data);
EXPECT_FALSE(result);
- EXPECT_EQ(ErrorCode::kCastV2SignedBlobsMismatch, result.error().code());
+ EXPECT_EQ(Error::Code::kCastV2SignedBlobsMismatch, result.error().code());
}
TEST_F(CastAuthUtilTest, VerifySenderNonceMatch) {
AuthContext context = AuthContext::Create();
- const openscreen::Error result =
- context.VerifySenderNonce(context.nonce(), true);
+ const Error result = context.VerifySenderNonce(context.nonce(), true);
EXPECT_TRUE(result.ok());
}
@@ -254,7 +272,7 @@ TEST_F(CastAuthUtilTest, VerifySenderNonceMismatch) {
ErrorOr<CastDeviceCertPolicy> result =
context.VerifySenderNonce(received_nonce, true);
EXPECT_FALSE(result);
- EXPECT_EQ(ErrorCode::kCastV2SenderNonceMismatch, result.error().code());
+ EXPECT_EQ(Error::Code::kCastV2SenderNonceMismatch, result.error().code());
}
TEST_F(CastAuthUtilTest, VerifySenderNonceMissing) {
@@ -264,40 +282,36 @@ TEST_F(CastAuthUtilTest, VerifySenderNonceMissing) {
ErrorOr<CastDeviceCertPolicy> result =
context.VerifySenderNonce(received_nonce, true);
EXPECT_FALSE(result);
- EXPECT_EQ(ErrorCode::kCastV2SenderNonceMismatch, result.error().code());
+ EXPECT_EQ(Error::Code::kCastV2SenderNonceMismatch, result.error().code());
}
TEST_F(CastAuthUtilTest, VerifyTLSCertificateSuccess) {
- std::vector<std::string> tls_cert_der =
- certificate::testing::ReadCertificatesFromPemFile(
- TEST_DATA_PREFIX "certificates/test_tls_cert.pem");
+ std::vector<std::string> tls_cert_der = testing::ReadCertificatesFromPemFile(
+ TEST_DATA_PREFIX "certificates/test_tls_cert.pem");
std::string& der_cert = tls_cert_der[0];
const uint8_t* data = (const uint8_t*)der_cert.data();
X509* tls_cert = d2i_X509(nullptr, &data, der_cert.size());
- certificate::DateTime not_before;
- certificate::DateTime not_after;
- ASSERT_TRUE(
- certificate::GetCertValidTimeRange(tls_cert, &not_before, &not_after));
+ DateTime not_before;
+ DateTime not_after;
+ ASSERT_TRUE(GetCertValidTimeRange(tls_cert, &not_before, &not_after));
uint64_t x;
ASSERT_TRUE(ConvertTimeSeconds(not_before, &x));
std::chrono::seconds s(x);
- const openscreen::Error result = VerifyTLSCertificateValidity(tls_cert, s);
+ const Error result = VerifyTLSCertificateValidity(tls_cert, s);
EXPECT_TRUE(result.ok());
X509_free(tls_cert);
}
TEST_F(CastAuthUtilTest, VerifyTLSCertificateTooEarly) {
- std::vector<std::string> tls_cert_der =
- certificate::testing::ReadCertificatesFromPemFile(
- TEST_DATA_PREFIX "certificates/test_tls_cert.pem");
+ std::vector<std::string> tls_cert_der = testing::ReadCertificatesFromPemFile(
+ TEST_DATA_PREFIX "certificates/test_tls_cert.pem");
std::string& der_cert = tls_cert_der[0];
const uint8_t* data = (const uint8_t*)der_cert.data();
X509* tls_cert = d2i_X509(nullptr, &data, der_cert.size());
- certificate::DateTime not_before;
- certificate::DateTime not_after;
- ASSERT_TRUE(
- certificate::GetCertValidTimeRange(tls_cert, &not_before, &not_after));
+ DateTime not_before;
+ DateTime not_after;
+ ASSERT_TRUE(GetCertValidTimeRange(tls_cert, &not_before, &not_after));
uint64_t x;
ASSERT_TRUE(ConvertTimeSeconds(not_before, &x));
std::chrono::seconds s(x - 1);
@@ -305,22 +319,20 @@ TEST_F(CastAuthUtilTest, VerifyTLSCertificateTooEarly) {
ErrorOr<CastDeviceCertPolicy> result =
VerifyTLSCertificateValidity(tls_cert, s);
EXPECT_FALSE(result);
- EXPECT_EQ(ErrorCode::kCastV2TlsCertValidStartDateInFuture,
+ EXPECT_EQ(Error::Code::kCastV2TlsCertValidStartDateInFuture,
result.error().code());
X509_free(tls_cert);
}
TEST_F(CastAuthUtilTest, VerifyTLSCertificateTooLate) {
- std::vector<std::string> tls_cert_der =
- certificate::testing::ReadCertificatesFromPemFile(
- TEST_DATA_PREFIX "certificates/test_tls_cert.pem");
+ std::vector<std::string> tls_cert_der = testing::ReadCertificatesFromPemFile(
+ TEST_DATA_PREFIX "certificates/test_tls_cert.pem");
std::string& der_cert = tls_cert_der[0];
const uint8_t* data = (const uint8_t*)der_cert.data();
X509* tls_cert = d2i_X509(nullptr, &data, der_cert.size());
- certificate::DateTime not_before;
- certificate::DateTime not_after;
- ASSERT_TRUE(
- certificate::GetCertValidTimeRange(tls_cert, &not_before, &not_after));
+ DateTime not_before;
+ DateTime not_after;
+ ASSERT_TRUE(GetCertValidTimeRange(tls_cert, &not_before, &not_after));
uint64_t x;
ASSERT_TRUE(ConvertTimeSeconds(not_after, &x));
std::chrono::seconds s(x + 2);
@@ -328,7 +340,7 @@ TEST_F(CastAuthUtilTest, VerifyTLSCertificateTooLate) {
ErrorOr<CastDeviceCertPolicy> result =
VerifyTLSCertificateValidity(tls_cert, s);
EXPECT_FALSE(result);
- EXPECT_EQ(ErrorCode::kCastV2TlsCertExpired, result.error().code());
+ EXPECT_EQ(Error::Code::kCastV2TlsCertExpired, result.error().code());
X509_free(tls_cert);
}
@@ -346,39 +358,40 @@ enum TestStepResult {
ErrorOr<CastDeviceCertPolicy> TestVerifyRevocation(
const std::vector<std::string>& certificate_chain,
const std::string& crl_bundle,
- const certificate::DateTime& verification_time,
+ const DateTime& verification_time,
bool crl_required,
- certificate::TrustStore* cast_trust_store,
- certificate::TrustStore* crl_trust_store) {
+ TrustStore* cast_trust_store,
+ TrustStore* crl_trust_store) {
AuthResponse response;
if (certificate_chain.size() > 0) {
response.set_client_auth_certificate(certificate_chain[0]);
- for (size_t i = 1; i < certificate_chain.size(); ++i)
+ for (size_t i = 1; i < certificate_chain.size(); ++i) {
response.add_intermediate_certificate(certificate_chain[i]);
+ }
}
response.set_crl(crl_bundle);
- certificate::CRLPolicy crl_policy = certificate::CRLPolicy::kCrlRequired;
+ CRLPolicy crl_policy = CRLPolicy::kCrlRequired;
if (!crl_required && crl_bundle.empty())
- crl_policy = certificate::CRLPolicy::kCrlOptional;
- ErrorOr<CastDeviceCertPolicy> result =
- VerifyCredentialsForTest(response, "", crl_policy, cast_trust_store,
- crl_trust_store, verification_time);
+ crl_policy = CRLPolicy::kCrlOptional;
+ ErrorOr<CastDeviceCertPolicy> result = VerifyCredentialsForTest(
+ response, std::vector<uint8_t>(), crl_policy, cast_trust_store,
+ crl_trust_store, verification_time);
// This test doesn't set the signature so it will just fail there.
EXPECT_FALSE(result);
return result;
}
// Runs a single test case.
-bool RunTest(const certificate::DeviceCertTest& test_case) {
- std::unique_ptr<certificate::TrustStore> crl_trust_store;
- std::unique_ptr<certificate::TrustStore> cast_trust_store;
+bool RunTest(const DeviceCertTest& test_case) {
+ std::unique_ptr<TrustStore> crl_trust_store;
+ std::unique_ptr<TrustStore> cast_trust_store;
if (test_case.use_test_trust_anchors()) {
- crl_trust_store = certificate::testing::CreateTrustStoreFromPemFile(
+ crl_trust_store = testing::CreateTrustStoreFromPemFile(
TEST_DATA_PREFIX "certificates/cast_crl_test_root_ca.pem");
- cast_trust_store = certificate::testing::CreateTrustStoreFromPemFile(
+ cast_trust_store = testing::CreateTrustStoreFromPemFile(
TEST_DATA_PREFIX "certificates/cast_test_root_ca.pem");
EXPECT_FALSE(crl_trust_store->certs.empty());
@@ -391,51 +404,49 @@ bool RunTest(const certificate::DeviceCertTest& test_case) {
}
// CastAuthUtil verifies the CRL at the same time as the certificate.
- certificate::DateTime verification_time;
+ DateTime verification_time;
uint64_t cert_verify_time = test_case.cert_verification_time_seconds();
if (!cert_verify_time) {
cert_verify_time = test_case.crl_verification_time_seconds();
}
- OSP_DCHECK(
- certificate::DateTimeFromSeconds(cert_verify_time, &verification_time));
+ OSP_DCHECK(DateTimeFromSeconds(cert_verify_time, &verification_time));
std::string crl_bundle = test_case.crl_bundle();
- ErrorOr<CastDeviceCertPolicy> result(
- certificate::CastDeviceCertPolicy::kUnrestricted);
+ ErrorOr<CastDeviceCertPolicy> result(CastDeviceCertPolicy::kUnrestricted);
switch (test_case.expected_result()) {
- case certificate::PATH_VERIFICATION_FAILED:
+ case ::cast::certificate::PATH_VERIFICATION_FAILED:
result = TestVerifyRevocation(
certificate_chain, crl_bundle, verification_time, false,
cast_trust_store.get(), crl_trust_store.get());
EXPECT_EQ(result.error().code(),
- ErrorCode::kCastV2CertNotSignedByTrustedCa);
+ Error::Code::kCastV2CertNotSignedByTrustedCa);
return result.error().code() ==
- ErrorCode::kCastV2CertNotSignedByTrustedCa;
- case certificate::CRL_VERIFICATION_FAILED:
+ Error::Code::kCastV2CertNotSignedByTrustedCa;
+ case ::cast::certificate::CRL_VERIFICATION_FAILED:
// Fall-through intended.
- case certificate::REVOCATION_CHECK_FAILED_WITHOUT_CRL:
+ case ::cast::certificate::REVOCATION_CHECK_FAILED_WITHOUT_CRL:
result = TestVerifyRevocation(
certificate_chain, crl_bundle, verification_time, true,
cast_trust_store.get(), crl_trust_store.get());
- EXPECT_EQ(result.error().code(), ErrorCode::kErrCrlInvalid);
- return result.error().code() == ErrorCode::kErrCrlInvalid;
- case certificate::CRL_EXPIRED_AFTER_INITIAL_VERIFICATION:
+ EXPECT_EQ(result.error().code(), Error::Code::kErrCrlInvalid);
+ return result.error().code() == Error::Code::kErrCrlInvalid;
+ case ::cast::certificate::CRL_EXPIRED_AFTER_INITIAL_VERIFICATION:
// By-pass this test because CRL is always verified at the time the
// certificate is verified.
return true;
- case certificate::REVOCATION_CHECK_FAILED:
+ case ::cast::certificate::REVOCATION_CHECK_FAILED:
result = TestVerifyRevocation(
certificate_chain, crl_bundle, verification_time, true,
cast_trust_store.get(), crl_trust_store.get());
- EXPECT_EQ(result.error().code(), ErrorCode::kErrCertsRevoked);
- return result.error().code() == ErrorCode::kErrCertsRevoked;
- case certificate::SUCCESS:
+ EXPECT_EQ(result.error().code(), Error::Code::kErrCertsRevoked);
+ return result.error().code() == Error::Code::kErrCertsRevoked;
+ case ::cast::certificate::SUCCESS:
result = TestVerifyRevocation(
certificate_chain, crl_bundle, verification_time, false,
cast_trust_store.get(), crl_trust_store.get());
- EXPECT_EQ(result.error().code(), ErrorCode::kCastV2SignedBlobsMismatch);
- return result.error().code() == ErrorCode::kCastV2SignedBlobsMismatch;
- case certificate::UNSPECIFIED:
+ EXPECT_EQ(result.error().code(), Error::Code::kCastV2SignedBlobsMismatch);
+ return result.error().code() == Error::Code::kCastV2SignedBlobsMismatch;
+ case ::cast::certificate::UNSPECIFIED:
return false;
}
return false;
@@ -446,9 +457,8 @@ bool RunTest(const certificate::DeviceCertTest& test_case) {
// To see the description of the test, execute the test.
// These tests are generated by a test generator in google3.
void RunTestSuite(const std::string& test_suite_file_name) {
- std::string testsuite_raw =
- certificate::testing::ReadEntireFileToString(test_suite_file_name);
- certificate::DeviceCertTestSuite test_suite;
+ std::string testsuite_raw = ReadEntireFileToString(test_suite_file_name);
+ DeviceCertTestSuite test_suite;
EXPECT_TRUE(test_suite.ParseFromString(testsuite_raw));
uint16_t successes = 0;
@@ -467,5 +477,5 @@ TEST_F(CastAuthUtilTest, CRLTestSuite) {
}
} // namespace
-} // namespace channel
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/sender/channel/message_util.cc b/chromium/third_party/openscreen/src/cast/sender/channel/message_util.cc
index ab3ed5d807d..6d96b730197 100644
--- a/chromium/third_party/openscreen/src/cast/sender/channel/message_util.cc
+++ b/chromium/third_party/openscreen/src/cast/sender/channel/message_util.cc
@@ -5,9 +5,14 @@
#include "cast/sender/channel/message_util.h"
#include "cast/sender/channel/cast_auth_util.h"
+#include "util/json/json_serialization.h"
+namespace openscreen {
namespace cast {
-namespace channel {
+
+using ::cast::channel::AuthChallenge;
+using ::cast::channel::CastMessage;
+using ::cast::channel::DeviceAuthMessage;
CastMessage CreateAuthChallengeMessage(const AuthContext& auth_context) {
CastMessage message;
@@ -15,7 +20,7 @@ CastMessage CreateAuthChallengeMessage(const AuthContext& auth_context) {
AuthChallenge* challenge = auth_message.mutable_challenge();
challenge->set_sender_nonce(auth_context.nonce());
- challenge->set_hash_algorithm(SHA256);
+ challenge->set_hash_algorithm(::cast::channel::SHA256);
std::string auth_message_string;
auth_message.SerializeToString(&auth_message_string);
@@ -24,11 +29,39 @@ CastMessage CreateAuthChallengeMessage(const AuthContext& auth_context) {
message.set_source_id(kPlatformSenderId);
message.set_destination_id(kPlatformReceiverId);
message.set_namespace_(kAuthNamespace);
- message.set_payload_type(CastMessage_PayloadType_BINARY);
+ message.set_payload_type(::cast::channel::CastMessage_PayloadType_BINARY);
message.set_payload_binary(auth_message_string);
return message;
}
-} // namespace channel
+ErrorOr<CastMessage> CreateAppAvailabilityRequest(const std::string& sender_id,
+ int request_id,
+ const std::string& app_id) {
+ Json::Value dict(Json::ValueType::objectValue);
+ dict[kMessageKeyType] = Json::Value(
+ CastMessageTypeToString(CastMessageType::kGetAppAvailability));
+ Json::Value app_id_value(Json::ValueType::arrayValue);
+ app_id_value.append(Json::Value(app_id));
+ dict[kMessageKeyAppId] = std::move(app_id_value);
+ dict[kMessageKeyRequestId] = Json::Value(request_id);
+
+ CastMessage message;
+ message.set_payload_type(::cast::channel::CastMessage_PayloadType_STRING);
+ ErrorOr<std::string> serialized = json::Stringify(dict);
+ if (serialized.is_error()) {
+ return serialized.error();
+ }
+ message.set_payload_utf8(serialized.value());
+
+ message.set_protocol_version(
+ ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
+ message.set_source_id(sender_id);
+ message.set_destination_id(kPlatformReceiverId);
+ message.set_namespace_(kReceiverNamespace);
+
+ return message;
+}
+
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/sender/channel/message_util.h b/chromium/third_party/openscreen/src/cast/sender/channel/message_util.h
index e2da0cd8215..1a8c7717e32 100644
--- a/chromium/third_party/openscreen/src/cast/sender/channel/message_util.h
+++ b/chromium/third_party/openscreen/src/cast/sender/channel/message_util.h
@@ -7,15 +7,23 @@
#include "cast/common/channel/message_util.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
+#include "platform/base/error.h"
+namespace openscreen {
namespace cast {
-namespace channel {
class AuthContext;
-CastMessage CreateAuthChallengeMessage(const AuthContext& auth_context);
+::cast::channel::CastMessage CreateAuthChallengeMessage(
+ const AuthContext& auth_context);
+
+// |request_id| must be unique for |sender_id|.
+ErrorOr<::cast::channel::CastMessage> CreateAppAvailabilityRequest(
+ const std::string& sender_id,
+ int request_id,
+ const std::string& app_id);
-} // namespace channel
} // namespace cast
+} // namespace openscreen
#endif // CAST_SENDER_CHANNEL_MESSAGE_UTIL_H_
diff --git a/chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.cc b/chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.cc
index 3e75c9776ee..bf89de88fae 100644
--- a/chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.cc
+++ b/chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.cc
@@ -5,34 +5,47 @@
#include "cast/sender/channel/sender_socket_factory.h"
#include "cast/common/channel/cast_socket.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
#include "cast/sender/channel/message_util.h"
#include "platform/base/tls_connect_options.h"
#include "util/crypto/certificate_utils.h"
+#include "util/logging.h"
-namespace cast {
-namespace channel {
+using ::cast::channel::CastMessage;
-using openscreen::platform::TlsConnectOptions;
+namespace openscreen {
+namespace cast {
bool operator<(const std::unique_ptr<SenderSocketFactory::PendingAuth>& a,
- uint32_t b) {
+ int b) {
return a && a->socket->socket_id() < b;
}
-bool operator<(uint32_t a,
+bool operator<(int a,
const std::unique_ptr<SenderSocketFactory::PendingAuth>& b) {
return b && a < b->socket->socket_id();
}
-SenderSocketFactory::SenderSocketFactory(Client* client) : client_(client) {
+SenderSocketFactory::SenderSocketFactory(Client* client,
+ TaskRunner* task_runner)
+ : client_(client), task_runner_(task_runner) {
OSP_DCHECK(client);
+ OSP_DCHECK(task_runner);
}
-SenderSocketFactory::~SenderSocketFactory() = default;
+SenderSocketFactory::~SenderSocketFactory() {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+}
+
+void SenderSocketFactory::set_factory(TlsConnectionFactory* factory) {
+ OSP_DCHECK(factory);
+ factory_ = factory;
+}
void SenderSocketFactory::Connect(const IPEndpoint& endpoint,
DeviceMediaPolicy media_policy,
CastSocket::Client* client) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DCHECK(client);
auto it = FindPendingConnection(endpoint);
if (it == pending_connections_.end()) {
@@ -64,15 +77,15 @@ void SenderSocketFactory::OnConnected(
CastSocket::Client* client = it->client;
pending_connections_.erase(it);
- ErrorOr<bssl::UniquePtr<X509>> peer_cert = openscreen::ImportCertificate(
- der_x509_peer_cert.data(), der_x509_peer_cert.size());
+ ErrorOr<bssl::UniquePtr<X509>> peer_cert =
+ ImportCertificate(der_x509_peer_cert.data(), der_x509_peer_cert.size());
if (!peer_cert) {
client_->OnError(this, endpoint, peer_cert.error());
return;
}
- auto socket = std::make_unique<CastSocket>(std::move(connection), this,
- GetNextSocketId());
+ auto socket =
+ MakeSerialDelete<CastSocket>(task_runner_, std::move(connection), this);
pending_auth_.emplace_back(
new PendingAuth{endpoint, media_policy, std::move(socket), client,
AuthContext::Create(), std::move(peer_cert.value())});
@@ -149,22 +162,27 @@ void SenderSocketFactory::OnMessage(CastSocket* socket, CastMessage message) {
}
ErrorOr<CastDeviceCertPolicy> policy_or_error = AuthenticateChallengeReply(
- message, (*it)->peer_cert.get(), (*it)->auth_context);
+ message, pending->peer_cert.get(), pending->auth_context);
if (policy_or_error.is_error()) {
+ OSP_DLOG_WARN << "Authentication failed for " << pending->endpoint
+ << " with error: " << policy_or_error.error();
client_->OnError(this, pending->endpoint, policy_or_error.error());
return;
}
if (policy_or_error.value() == CastDeviceCertPolicy::kAudioOnly &&
- pending->media_policy != DeviceMediaPolicy::kAudioOnly) {
+ pending->media_policy == DeviceMediaPolicy::kIncludesVideo) {
client_->OnError(this, pending->endpoint,
Error::Code::kCastV2ChannelPolicyMismatch);
return;
}
+ pending->socket->set_audio_only(policy_or_error.value() ==
+ CastDeviceCertPolicy::kAudioOnly);
pending->socket->SetClient(pending->client);
- client_->OnConnected(this, pending->endpoint, std::move(pending->socket));
+ client_->OnConnected(this, pending->endpoint,
+ std::unique_ptr<CastSocket>(pending->socket.release()));
}
-} // namespace channel
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.h b/chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.h
index 63998674bf9..7c788536187 100644
--- a/chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.h
+++ b/chromium/third_party/openscreen/src/cast/sender/channel/sender_socket_factory.h
@@ -12,19 +12,15 @@
#include <vector>
#include "cast/common/channel/cast_socket.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
#include "cast/sender/channel/cast_auth_util.h"
+#include "platform/api/task_runner.h"
#include "platform/api/tls_connection_factory.h"
#include "platform/base/ip_address.h"
-#include "util/logging.h"
+#include "util/serial_delete_ptr.h"
+namespace openscreen {
namespace cast {
-namespace channel {
-
-using openscreen::Error;
-using openscreen::IPEndpoint;
-using openscreen::IPEndpointComparator;
-using openscreen::platform::TlsConnection;
-using openscreen::platform::TlsConnectionFactory;
class SenderSocketFactory final : public TlsConnectionFactory::Client,
public CastSocket::Client {
@@ -40,19 +36,25 @@ class SenderSocketFactory final : public TlsConnectionFactory::Client,
};
enum class DeviceMediaPolicy {
+ kNone = 0,
kAudioOnly,
kIncludesVideo,
};
- // |client| must outlive |this|.
- explicit SenderSocketFactory(Client* client);
+ // |client| and |task_runner| must outlive |this|.
+ SenderSocketFactory(Client* client, TaskRunner* task_runner);
~SenderSocketFactory();
- void set_factory(TlsConnectionFactory* factory) {
- OSP_DCHECK(factory);
- factory_ = factory;
- }
+ // |factory| cannot be nullptr and must outlive |this|.
+ void set_factory(TlsConnectionFactory* factory);
+ // Begins connecting to a Cast device at |endpoint|. If a successful
+ // connection is made, including device authentication, the new CastSocket
+ // will be passed to |client_|'s OnConnected method. The new CastSocket will
+ // have its client set to |client|. If any part of the connection process
+ // fails, |client_|'s OnError method is called instead. This includes if the
+ // device's media policy, as determined by authentication, is audio-only and
+ // |media_policy| is kIncludesVideo.
void Connect(const IPEndpoint& endpoint,
DeviceMediaPolicy media_policy,
CastSocket::Client* client);
@@ -78,29 +80,31 @@ class SenderSocketFactory final : public TlsConnectionFactory::Client,
struct PendingAuth {
IPEndpoint endpoint;
DeviceMediaPolicy media_policy;
- std::unique_ptr<CastSocket> socket;
+ SerialDeletePtr<CastSocket> socket;
CastSocket::Client* client;
AuthContext auth_context;
bssl::UniquePtr<X509> peer_cert;
};
- friend bool operator<(const std::unique_ptr<PendingAuth>& a, uint32_t b);
- friend bool operator<(uint32_t a, const std::unique_ptr<PendingAuth>& b);
+ friend bool operator<(const std::unique_ptr<PendingAuth>& a, int b);
+ friend bool operator<(int a, const std::unique_ptr<PendingAuth>& b);
std::vector<PendingConnection>::iterator FindPendingConnection(
const IPEndpoint& endpoint);
// CastSocket::Client overrides.
void OnError(CastSocket* socket, Error error) override;
- void OnMessage(CastSocket* socket, CastMessage message) override;
+ void OnMessage(CastSocket* socket,
+ ::cast::channel::CastMessage message) override;
Client* const client_;
+ TaskRunner* const task_runner_;
TlsConnectionFactory* factory_ = nullptr;
std::vector<PendingConnection> pending_connections_;
std::vector<std::unique_ptr<PendingAuth>> pending_auth_;
};
-} // namespace channel
} // namespace cast
+} // namespace openscreen
#endif // CAST_SENDER_CHANNEL_SENDER_SOCKET_FACTORY_H_
diff --git a/chromium/third_party/openscreen/src/cast/sender/public/DEPS b/chromium/third_party/openscreen/src/cast/sender/public/DEPS
index 44de6584c7b..8a61a25853f 100644
--- a/chromium/third_party/openscreen/src/cast/sender/public/DEPS
+++ b/chromium/third_party/openscreen/src/cast/sender/public/DEPS
@@ -1,12 +1,9 @@
# -*- Mode: Python; -*-
include_rules = [
- # By default, openscreen implementation libraries should not be exposed
- # through public APIs.
- '-base',
- '-platform',
-
# Dependencies on the implementation are not allowed in public/.
'-cast/sender',
- '+cast/sender/public'
+ '+cast/sender/public',
+ '-cast/common',
+ '+cast/common/public',
]
diff --git a/chromium/third_party/openscreen/src/cast/sender/public/README.md b/chromium/third_party/openscreen/src/cast/sender/public/README.md
index ace86163571..b670d1100a6 100644
--- a/chromium/third_party/openscreen/src/cast/sender/public/README.md
+++ b/chromium/third_party/openscreen/src/cast/sender/public/README.md
@@ -1,4 +1,4 @@
-# cast/receiver/public
+# cast/sender/public
This module contains an implementation of the Cast "sender", i.e. the client
that discovers Cast devices on the LAN and launches apps on them.
diff --git a/chromium/third_party/openscreen/src/cast/sender/public/cast_app_discovery_service.cc b/chromium/third_party/openscreen/src/cast/sender/public/cast_app_discovery_service.cc
new file mode 100644
index 00000000000..c561b3865b7
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/sender/public/cast_app_discovery_service.cc
@@ -0,0 +1,48 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/sender/public/cast_app_discovery_service.h"
+
+namespace openscreen {
+namespace cast {
+
+CastAppDiscoveryService::Subscription::Subscription(
+ CastAppDiscoveryService* discovery_service,
+ uint32_t id)
+ : discovery_service_(discovery_service), id_(id) {}
+
+CastAppDiscoveryService::Subscription::Subscription(Subscription&& other)
+ : discovery_service_(other.discovery_service_), id_(other.id_) {
+ other.discovery_service_ = nullptr;
+}
+
+CastAppDiscoveryService::Subscription::~Subscription() {
+ Reset();
+}
+
+CastAppDiscoveryService::Subscription& CastAppDiscoveryService::Subscription::
+operator=(Subscription other) {
+ Swap(other);
+ return *this;
+}
+
+void CastAppDiscoveryService::Subscription::Reset() {
+ if (discovery_service_) {
+ discovery_service_->RemoveAvailabilityCallback(id_);
+ }
+ discovery_service_ = nullptr;
+}
+
+void CastAppDiscoveryService::Subscription::Swap(Subscription& other) {
+ CastAppDiscoveryService* service = other.discovery_service_;
+ other.discovery_service_ = discovery_service_;
+ discovery_service_ = service;
+
+ uint32_t id = other.id_;
+ other.id_ = id_;
+ id_ = id;
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/sender/public/cast_app_discovery_service.h b/chromium/third_party/openscreen/src/cast/sender/public/cast_app_discovery_service.h
new file mode 100644
index 00000000000..ccb586ff816
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/sender/public/cast_app_discovery_service.h
@@ -0,0 +1,75 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_SENDER_PUBLIC_CAST_APP_DISCOVERY_SERVICE_H_
+#define CAST_SENDER_PUBLIC_CAST_APP_DISCOVERY_SERVICE_H_
+
+#include <vector>
+
+#include "cast/common/public/service_info.h"
+
+namespace openscreen {
+namespace cast {
+
+class CastMediaSource;
+
+// Interface for app discovery for Cast devices.
+class CastAppDiscoveryService {
+ public:
+ using AvailabilityCallback =
+ std::function<void(const CastMediaSource& source,
+ const std::vector<ServiceInfo>& devices)>;
+
+ class Subscription {
+ public:
+ Subscription(Subscription&&);
+ ~Subscription();
+ Subscription& operator=(Subscription);
+
+ void Reset();
+
+ private:
+ friend class CastAppDiscoveryService;
+
+ Subscription(CastAppDiscoveryService* discovery_service, uint32_t id);
+
+ void Swap(Subscription& other);
+
+ CastAppDiscoveryService* discovery_service_;
+ uint32_t id_;
+ };
+
+ virtual ~CastAppDiscoveryService() = default;
+
+ // Adds an availability query for |source|. Results will be continuously
+ // returned via |callback| until the returned Subscription is destroyed by the
+ // caller. If there are cached results available, |callback| will be invoked
+ // before this method returns. |callback| may be invoked with an empty list
+ // if all devices respond to the respective queries with "unavailable" or
+ // don't respond before a timeout. |callback| may be invoked successively
+ // with the same list.
+ virtual Subscription StartObservingAvailability(
+ const CastMediaSource& source,
+ AvailabilityCallback callback) = 0;
+
+ // Refreshes the state of app discovery in the service. It is suitable to call
+ // this method when the user initiates a user gesture.
+ virtual void Refresh() = 0;
+
+ protected:
+ Subscription MakeSubscription(CastAppDiscoveryService* discovery_service,
+ uint32_t id) {
+ return Subscription(discovery_service, id);
+ }
+
+ private:
+ friend class Subscription;
+
+ virtual void RemoveAvailabilityCallback(uint32_t id) = 0;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_SENDER_PUBLIC_CAST_APP_DISCOVERY_SERVICE_H_
diff --git a/chromium/third_party/openscreen/src/cast/sender/public/cast_media_source.cc b/chromium/third_party/openscreen/src/cast/sender/public/cast_media_source.cc
new file mode 100644
index 00000000000..ebc9a3fe014
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/sender/public/cast_media_source.cc
@@ -0,0 +1,45 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/sender/public/cast_media_source.h"
+
+#include <algorithm>
+
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+
+// static
+ErrorOr<CastMediaSource> CastMediaSource::From(const std::string& source) {
+ // TODO(btolsch): Implement when we have URL parsing.
+ OSP_UNIMPLEMENTED();
+ return Error::Code::kUnknownError;
+}
+
+CastMediaSource::CastMediaSource(std::string source,
+ std::vector<std::string> app_ids)
+ : source_id_(std::move(source)), app_ids_(std::move(app_ids)) {}
+
+CastMediaSource::CastMediaSource(const CastMediaSource& other) = default;
+CastMediaSource::CastMediaSource(CastMediaSource&& other) = default;
+
+CastMediaSource::~CastMediaSource() = default;
+
+CastMediaSource& CastMediaSource::operator=(const CastMediaSource& other) =
+ default;
+CastMediaSource& CastMediaSource::operator=(CastMediaSource&& other) = default;
+
+bool CastMediaSource::ContainsAppId(const std::string& app_id) const {
+ return std::find(app_ids_.begin(), app_ids_.end(), app_id) != app_ids_.end();
+}
+
+bool CastMediaSource::ContainsAnyAppIdFrom(
+ const std::vector<std::string>& app_ids) const {
+ return std::find_first_of(app_ids_.begin(), app_ids_.end(), app_ids.begin(),
+ app_ids.end()) != app_ids_.end();
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/sender/public/cast_media_source.h b/chromium/third_party/openscreen/src/cast/sender/public/cast_media_source.h
new file mode 100644
index 00000000000..18af80f8694
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/sender/public/cast_media_source.h
@@ -0,0 +1,42 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_SENDER_PUBLIC_CAST_MEDIA_SOURCE_H_
+#define CAST_SENDER_PUBLIC_CAST_MEDIA_SOURCE_H_
+
+#include <string>
+#include <vector>
+
+#include "platform/base/error.h"
+
+namespace openscreen {
+namespace cast {
+
+class CastMediaSource {
+ public:
+ static ErrorOr<CastMediaSource> From(const std::string& source);
+
+ CastMediaSource(std::string source, std::vector<std::string> app_ids);
+ CastMediaSource(const CastMediaSource& other);
+ CastMediaSource(CastMediaSource&& other);
+ ~CastMediaSource();
+
+ CastMediaSource& operator=(const CastMediaSource& other);
+ CastMediaSource& operator=(CastMediaSource&& other);
+
+ bool ContainsAppId(const std::string& app_id) const;
+ bool ContainsAnyAppIdFrom(const std::vector<std::string>& app_ids) const;
+
+ const std::string& source_id() const { return source_id_; }
+ const std::vector<std::string>& app_ids() const { return app_ids_; }
+
+ private:
+ std::string source_id_;
+ std::vector<std::string> app_ids_;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_SENDER_PUBLIC_CAST_MEDIA_SOURCE_H_
diff --git a/chromium/third_party/openscreen/src/cast/sender/testing/DEPS b/chromium/third_party/openscreen/src/cast/sender/testing/DEPS
new file mode 100644
index 00000000000..99039c594a4
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/sender/testing/DEPS
@@ -0,0 +1,4 @@
+include_rules = [
+ # Sender tests can use receiver code for simulation/validation.
+ '+cast/receiver',
+]
diff --git a/chromium/third_party/openscreen/src/cast/sender/testing/test_helpers.cc b/chromium/third_party/openscreen/src/cast/sender/testing/test_helpers.cc
new file mode 100644
index 00000000000..ff9a45751ad
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/sender/testing/test_helpers.cc
@@ -0,0 +1,87 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/sender/testing/test_helpers.h"
+
+#include "cast/common/channel/message_util.h"
+#include "cast/receiver/channel/message_util.h"
+#include "cast/sender/channel/message_util.h"
+#include "gtest/gtest.h"
+#include "util/json/json_serialization.h"
+#include "util/json/json_value.h"
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+
+using ::cast::channel::CastMessage;
+
+void VerifyAppAvailabilityRequest(const CastMessage& message,
+ const std::string& expected_app_id,
+ int* request_id_out,
+ std::string* sender_id_out) {
+ std::string app_id_out;
+ VerifyAppAvailabilityRequest(message, &app_id_out, request_id_out,
+ sender_id_out);
+ EXPECT_EQ(app_id_out, expected_app_id);
+}
+
+void VerifyAppAvailabilityRequest(const CastMessage& message,
+ std::string* app_id_out,
+ int* request_id_out,
+ std::string* sender_id_out) {
+ EXPECT_EQ(message.namespace_(), kReceiverNamespace);
+ EXPECT_EQ(message.destination_id(), kPlatformReceiverId);
+ EXPECT_EQ(message.payload_type(),
+ ::cast::channel::CastMessage_PayloadType_STRING);
+ EXPECT_NE(message.source_id(), kPlatformSenderId);
+ *sender_id_out = message.source_id();
+
+ ErrorOr<Json::Value> maybe_value = json::Parse(message.payload_utf8());
+ ASSERT_TRUE(maybe_value);
+ Json::Value& value = maybe_value.value();
+
+ absl::optional<absl::string_view> maybe_type =
+ MaybeGetString(value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType));
+ ASSERT_TRUE(maybe_type);
+ EXPECT_EQ(maybe_type.value(),
+ CastMessageTypeToString(CastMessageType::kGetAppAvailability));
+
+ absl::optional<int> maybe_id =
+ MaybeGetInt(value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyRequestId));
+ ASSERT_TRUE(maybe_id);
+ *request_id_out = maybe_id.value();
+
+ const Json::Value* maybe_app_ids =
+ value.find(JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyAppId));
+ ASSERT_TRUE(maybe_app_ids);
+ ASSERT_TRUE(maybe_app_ids->isArray());
+ ASSERT_EQ(maybe_app_ids->size(), 1u);
+ Json::Value app_id_value = maybe_app_ids->get(0u, Json::Value(""));
+ absl::optional<absl::string_view> maybe_app_id = MaybeGetString(app_id_value);
+ ASSERT_TRUE(maybe_app_id);
+ *app_id_out =
+ std::string(maybe_app_id.value().begin(), maybe_app_id.value().end());
+}
+
+CastMessage CreateAppAvailableResponseChecked(int request_id,
+ const std::string& sender_id,
+ const std::string& app_id) {
+ ErrorOr<CastMessage> message =
+ CreateAppAvailableResponse(request_id, sender_id, app_id);
+ OSP_CHECK(message);
+ return std::move(message.value());
+}
+
+CastMessage CreateAppUnavailableResponseChecked(int request_id,
+ const std::string& sender_id,
+ const std::string& app_id) {
+ ErrorOr<CastMessage> message =
+ CreateAppUnavailableResponse(request_id, sender_id, app_id);
+ OSP_CHECK(message);
+ return std::move(message.value());
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/sender/testing/test_helpers.h b/chromium/third_party/openscreen/src/cast/sender/testing/test_helpers.h
new file mode 100644
index 00000000000..c9a68c20e2a
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/sender/testing/test_helpers.h
@@ -0,0 +1,43 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_SENDER_TESTING_TEST_HELPERS_H_
+#define CAST_SENDER_TESTING_TEST_HELPERS_H_
+
+#include <cstdint>
+#include <string>
+
+#include "cast/sender/channel/message_util.h"
+
+namespace cast {
+namespace channel {
+class CastMessage;
+} // namespace channel
+} // namespace cast
+
+namespace openscreen {
+namespace cast {
+
+void VerifyAppAvailabilityRequest(const ::cast::channel::CastMessage& message,
+ const std::string& expected_app_id,
+ int* request_id_out,
+ std::string* sender_id_out);
+void VerifyAppAvailabilityRequest(const ::cast::channel::CastMessage& message,
+ std::string* app_id_out,
+ int* request_id_out,
+ std::string* sender_id_out);
+
+::cast::channel::CastMessage CreateAppAvailableResponseChecked(
+ int request_id,
+ const std::string& sender_id,
+ const std::string& app_id);
+::cast::channel::CastMessage CreateAppUnavailableResponseChecked(
+ int request_id,
+ const std::string& sender_id,
+ const std::string& app_id);
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_SENDER_TESTING_TEST_HELPERS_H_
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/BUILD.gn b/chromium/third_party/openscreen/src/cast/standalone_receiver/BUILD.gn
index 696f732ba3b..ba54d33f414 100644
--- a/chromium/third_party/openscreen/src/cast/standalone_receiver/BUILD.gn
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/BUILD.gn
@@ -2,39 +2,38 @@
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
+import("//build/config/external_libraries.gni")
import("//build_overrides/build.gni")
-declare_args() {
- # These are only relevant for building the demo apps, which require external
- # headers/libraries be installed. Set them to true if your local system has
- # SDL2/FFMPEG installed. On Debian-like systems, the following should install
- # all the required headers and libraries:
- #
- # sudo apt-get install libsdl2-2.0 libsdl2-dev libavcodec libavcodec-dev \
- # libavformat libavformat-dev libavutil libavutil-dev
- have_sdl_for_demo_apps = false
- have_ffmpeg_for_demo_apps = false
-}
-
# Define the executable target only when the build is configured to use the
# standalone platform implementation; since this is itself a standalone
# application.
if (!build_with_chromium) {
executable("cast_receiver") {
sources = [
+ "cast_agent.cc",
+ "cast_agent.h",
+ "cast_socket_message_port.cc",
+ "cast_socket_message_port.h",
"main.cc",
- "memory_util.h",
+ "streaming_playback_controller.cc",
+ "streaming_playback_controller.h",
]
deps = [
"../../platform",
+ "../../third_party/jsoncpp",
+ "../common:public",
+ "../common/channel/proto:channel_proto",
+ "../receiver:channel",
"../streaming:receiver",
]
defines = []
include_dirs = []
+ lib_dirs = []
libs = []
- if (have_sdl_for_demo_apps && have_ffmpeg_for_demo_apps) {
- defines += [ "CAST_STREAMING_HAVE_EXTERNAL_LIBS_FOR_DEMO_APPS" ]
+ if (have_ffmpeg && have_libsdl2) {
+ defines += [ "CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS" ]
sources += [
"avcodec_glue.h",
"decoder.cc",
@@ -48,11 +47,9 @@ if (!build_with_chromium) {
"sdl_video_player.cc",
"sdl_video_player.h",
]
- libs += [
- "-lSDL2",
- "-lavcodec",
- "-lavutil",
- ]
+ include_dirs += ffmpeg_include_dirs + libsdl2_include_dirs
+ lib_dirs += ffmpeg_lib_dirs + libsdl2_lib_dirs
+ libs += ffmpeg_libs + libsdl2_libs
} else {
sources += [
"dummy_player.cc",
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/DEPS b/chromium/third_party/openscreen/src/cast/standalone_receiver/DEPS
index 788c38cba93..f2040e13bd3 100644
--- a/chromium/third_party/openscreen/src/cast/standalone_receiver/DEPS
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/DEPS
@@ -5,4 +5,6 @@
include_rules = [
'+cast',
'+platform/impl',
+ '+discovery/common',
+ '+discovery/public',
]
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/avcodec_glue.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/avcodec_glue.h
index d1dfae78075..aa516175d34 100644
--- a/chromium/third_party/openscreen/src/cast/standalone_receiver/avcodec_glue.h
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/avcodec_glue.h
@@ -14,8 +14,8 @@ extern "C" {
#include <libavutil/samplefmt.h>
}
+namespace openscreen {
namespace cast {
-namespace streaming {
// Macro that, for an AVFoo, generates code for:
//
@@ -50,7 +50,7 @@ DEFINE_AV_UNIQUE_PTR(AVFrame, av_frame_alloc, av_frame_free(&obj));
#undef DEFINE_AV_UNIQUE_PTR
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STANDALONE_RECEIVER_AVCODEC_GLUE_H_
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.cc
new file mode 100644
index 00000000000..b614fa80889
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.cc
@@ -0,0 +1,162 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/standalone_receiver/cast_agent.h"
+
+#include <fstream>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "cast/standalone_receiver/cast_socket_message_port.h"
+#include "cast/standalone_receiver/private_key_der.h"
+#include "cast/streaming/constants.h"
+#include "cast/streaming/offer_messages.h"
+#include "platform/base/tls_credentials.h"
+#include "platform/base/tls_listen_options.h"
+#include "util/crypto/certificate_utils.h"
+#include "util/json/json_serialization.h"
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+namespace {
+
+constexpr int kDefaultMaxBacklogSize = 64;
+const TlsListenOptions kDefaultListenOptions{kDefaultMaxBacklogSize};
+
+constexpr int kThreeDaysInSeconds = 3 * 24 * 60 * 60;
+constexpr auto kCertificateDuration = std::chrono::seconds(kThreeDaysInSeconds);
+
+// Generates a valid set of credentials for use with the TLS Server socket,
+// including a generated X509 certificate generated from the static private key
+// stored in private_key_der.h. The certificate is valid for
+// kCertificateDuration from when this function is called.
+ErrorOr<TlsCredentials> CreateCredentials(const IPEndpoint& endpoint) {
+ ErrorOr<bssl::UniquePtr<EVP_PKEY>> private_key =
+ ImportRSAPrivateKey(kPrivateKeyDer.data(), kPrivateKeyDer.size());
+ OSP_CHECK(private_key);
+
+ ErrorOr<bssl::UniquePtr<X509>> cert = CreateSelfSignedX509Certificate(
+ endpoint.ToString(), kCertificateDuration, *private_key.value());
+ if (!cert) {
+ return cert.error();
+ }
+
+ auto cert_bytes = ExportX509CertificateToDer(*cert.value());
+ if (!cert_bytes) {
+ return cert_bytes.error();
+ }
+
+ // TODO(jophba): either refactor the TLS server socket to use the public key
+ // and add a valid key here, or remove from the TlsCredentials struct.
+ return TlsCredentials(
+ std::vector<uint8_t>(kPrivateKeyDer.begin(), kPrivateKeyDer.end()),
+ std::vector<uint8_t>{}, std::move(cert_bytes.value()));
+}
+
+} // namespace
+
+CastAgent::CastAgent(TaskRunner* task_runner, InterfaceInfo interface)
+ : task_runner_(task_runner) {
+ // Create the Environment that holds the required injected dependencies
+ // (clock, task runner) used throughout the system, and owns the UDP socket
+ // over which all communication occurs with the Sender.
+ IPEndpoint receive_endpoint{IPAddress::kV4LoopbackAddress, kDefaultCastPort};
+ receive_endpoint.address = interface.GetIpAddressV4()
+ ? interface.GetIpAddressV4()
+ : interface.GetIpAddressV6();
+ OSP_DCHECK(receive_endpoint.address);
+ environment_ = std::make_unique<Environment>(&Clock::now, task_runner_,
+ receive_endpoint);
+ receive_endpoint_ = std::move(receive_endpoint);
+}
+
+CastAgent::~CastAgent() = default;
+
+Error CastAgent::Start() {
+ OSP_DCHECK(!current_session_);
+
+ task_runner_->PostTask(
+ [this] { this->wake_lock_ = ScopedWakeLock::Create(); });
+
+ // TODO(jophba): add command line argument for setting the private key.
+ ErrorOr<TlsCredentials> credentials = CreateCredentials(receive_endpoint_);
+ if (!credentials) {
+ return credentials.error();
+ }
+
+ // TODO(jophba, rwkeane): begin discovery process before creating TLS
+ // connection factory instance.
+ socket_factory_ =
+ std::make_unique<ReceiverSocketFactory>(this, &message_port_);
+ task_runner_->PostTask([this, creds = std::move(credentials.value())] {
+ connection_factory_ = TlsConnectionFactory::CreateFactory(
+ socket_factory_.get(), task_runner_);
+ connection_factory_->SetListenCredentials(creds);
+ connection_factory_->Listen(receive_endpoint_, kDefaultListenOptions);
+ });
+
+ OSP_LOG_INFO << "Listening for connections at: " << receive_endpoint_;
+ return Error::None();
+}
+
+Error CastAgent::Stop() {
+ controller_.reset();
+ current_session_.reset();
+ return Error::None();
+}
+
+void CastAgent::OnConnected(ReceiverSocketFactory* factory,
+ const IPEndpoint& endpoint,
+ std::unique_ptr<CastSocket> socket) {
+ OSP_DCHECK(factory);
+
+ if (current_session_) {
+ OSP_LOG_WARN << "Already connected, dropping peer at: " << endpoint;
+ return;
+ }
+
+ OSP_LOG_INFO << "Received connection from peer at: " << endpoint;
+ message_port_.SetSocket(std::move(socket));
+ controller_ =
+ std::make_unique<StreamingPlaybackController>(task_runner_, this);
+ current_session_ = std::make_unique<ReceiverSession>(
+ controller_.get(), environment_.get(), &message_port_,
+ ReceiverSession::Preferences{});
+}
+
+void CastAgent::OnError(ReceiverSocketFactory* factory, Error error) {
+ OSP_LOG_ERROR << "Cast agent received socket factory error: " << error;
+}
+
+// Currently we don't do anything with the receiver output--the session
+// is automatically linked to the playback controller when it is constructed, so
+// we don't actually have to interface with the receivers. If we end up caring
+// about the receiver configurations we will have to handle OnNegotiated here.
+void CastAgent::OnNegotiated(const ReceiverSession* session,
+ ReceiverSession::ConfiguredReceivers receivers) {
+ OSP_LOG_INFO << "Successfully negotiated with sender.";
+}
+
+void CastAgent::OnConfiguredReceiversDestroyed(const ReceiverSession* session) {
+ OSP_LOG_INFO << "Receiver instances destroyed.";
+}
+
+// Currently, we just kill the session if an error is encountered.
+void CastAgent::OnError(const ReceiverSession* session, Error error) {
+ OSP_LOG_ERROR << "Cast agent received receiver session error: " << error;
+ current_session_.reset();
+}
+
+void CastAgent::OnPlaybackError(StreamingPlaybackController* controller,
+ Error error) {
+ OSP_LOG_ERROR << "Cast agent received playback error: " << error;
+ current_session_.reset();
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.h
new file mode 100644
index 00000000000..873563e6604
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.h
@@ -0,0 +1,84 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_STANDALONE_RECEIVER_CAST_AGENT_H_
+#define CAST_STANDALONE_RECEIVER_CAST_AGENT_H_
+
+#include <openssl/x509.h>
+
+#include <memory>
+
+#include "cast/common/channel/cast_socket.h"
+#include "cast/receiver/channel/receiver_socket_factory.h"
+#include "cast/standalone_receiver/cast_socket_message_port.h"
+#include "cast/standalone_receiver/streaming_playback_controller.h"
+#include "cast/streaming/environment.h"
+#include "cast/streaming/receiver_session.h"
+#include "platform/api/scoped_wake_lock.h"
+#include "platform/base/error.h"
+#include "platform/base/interface_info.h"
+#include "platform/impl/task_runner.h"
+#include "util/serial_delete_ptr.h"
+
+namespace openscreen {
+namespace cast {
+
+// This class manages sender connections, starting with listening over TLS for
+// connection attempts, constructing ReceiverSessions when OFFER messages are
+// received, and linking Receivers to the output decoder and SDL visualizer.
+//
+// Consumers of this class are expected to provide a single threaded task runner
+// implementation, and a network interface information struct that will be used
+// both for TLS listening and UDP messaging.
+class CastAgent : public ReceiverSocketFactory::Client,
+ public ReceiverSession::Client,
+ public StreamingPlaybackController::Client {
+ public:
+ CastAgent(TaskRunner* task_runner, InterfaceInfo interface);
+ ~CastAgent();
+
+ // Initialization occurs as part of construction, however to actually bind
+ // for discovery and listening over TLS, the CastAgent must be started.
+ Error Start();
+ Error Stop();
+
+ // ReceiverSocketFactory::Client overrides.
+ void OnConnected(ReceiverSocketFactory* factory,
+ const IPEndpoint& endpoint,
+ std::unique_ptr<CastSocket> socket) override;
+ void OnError(ReceiverSocketFactory* factory, Error error) override;
+
+ // ReceiverSession::Client overrides.
+ void OnNegotiated(const ReceiverSession* session,
+ ReceiverSession::ConfiguredReceivers receivers) override;
+ void OnConfiguredReceiversDestroyed(const ReceiverSession* session) override;
+ void OnError(const ReceiverSession* session, Error error) override;
+
+ // StreamingPlaybackController::Client overrides
+ void OnPlaybackError(StreamingPlaybackController* controller,
+ Error error) override;
+
+ private:
+ // Member variables set as part of construction.
+ std::unique_ptr<Environment> environment_;
+ TaskRunner* const task_runner_;
+ IPEndpoint receive_endpoint_;
+ CastSocketMessagePort message_port_;
+
+ // Member variables set as part of starting up.
+ std::unique_ptr<TlsConnectionFactory> connection_factory_;
+ std::unique_ptr<ReceiverSocketFactory> socket_factory_;
+ std::unique_ptr<ScopedWakeLock> wake_lock_;
+
+ // Member variables set as part of a sender connection.
+ // NOTE: currently we only support a single sender connection and a
+ // single streaming session.
+ std::unique_ptr<ReceiverSession> current_session_;
+ std::unique_ptr<StreamingPlaybackController> controller_;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_STANDALONE_RECEIVER_CAST_AGENT_H_
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.cc
new file mode 100644
index 00000000000..706fd58ff6d
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.cc
@@ -0,0 +1,59 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/standalone_receiver/cast_socket_message_port.h"
+
+#include <utility>
+
+namespace openscreen {
+namespace cast {
+
+CastSocketMessagePort::CastSocketMessagePort() = default;
+CastSocketMessagePort::~CastSocketMessagePort() = default;
+
+// NOTE: we assume here that this message port is already the client for
+// the passed in socket, so leave the socket's client unchanged. However,
+// since sockets should map one to one with receiver sessions, we reset our
+// client. The consumer of this message port should call SetClient with the new
+// message port client after setting the socket.
+void CastSocketMessagePort::SetSocket(std::unique_ptr<CastSocket> socket) {
+ client_ = nullptr;
+ socket_ = std::move(socket);
+}
+
+void CastSocketMessagePort::SetClient(MessagePort::Client* client) {
+ client_ = client;
+}
+
+void CastSocketMessagePort::OnError(CastSocket* socket, Error error) {
+ if (client_) {
+ client_->OnError(error);
+ }
+}
+
+void CastSocketMessagePort::OnMessage(CastSocket* socket,
+ ::cast::channel::CastMessage message) {
+ if (client_) {
+ client_->OnMessage(message.source_id(), message.namespace_(),
+ message.payload_utf8());
+ }
+}
+
+void CastSocketMessagePort::PostMessage(absl::string_view sender_id,
+ absl::string_view message_namespace,
+ absl::string_view message) {
+ ::cast::channel::CastMessage cast_message;
+ cast_message.set_source_id(sender_id.data(), sender_id.size());
+ cast_message.set_namespace_(message_namespace.data(),
+ message_namespace.size());
+ cast_message.set_payload_utf8(message.data(), message.size());
+
+ Error error = socket_->SendMessage(cast_message);
+ if (!error.ok()) {
+ client_->OnError(error);
+ }
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.h
new file mode 100644
index 00000000000..c596f985705
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.h
@@ -0,0 +1,44 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_STANDALONE_RECEIVER_CAST_SOCKET_MESSAGE_PORT_H_
+#define CAST_STANDALONE_RECEIVER_CAST_SOCKET_MESSAGE_PORT_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "cast/common/channel/cast_socket.h"
+#include "cast/streaming/receiver_session.h"
+
+namespace openscreen {
+namespace cast {
+
+class CastSocketMessagePort : public MessagePort, public CastSocket::Client {
+ public:
+ CastSocketMessagePort();
+ ~CastSocketMessagePort() override;
+
+ void SetSocket(std::unique_ptr<CastSocket> socket);
+
+ // MessagePort overrides.
+ void SetClient(MessagePort::Client* client) override;
+ void PostMessage(absl::string_view sender_id,
+ absl::string_view message_namespace,
+ absl::string_view message) override;
+
+ // CastSocket::Client overrides.
+ void OnError(CastSocket* socket, Error error) override;
+ void OnMessage(CastSocket* socket,
+ ::cast::channel::CastMessage message) override;
+
+ private:
+ MessagePort::Client* client_ = nullptr;
+ std::unique_ptr<CastSocket> socket_;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_STANDALONE_RECEIVER_CAST_SOCKET_MESSAGE_PORT_H_
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.cc
index 8421cd0f2e8..5d84d6a77fc 100644
--- a/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.cc
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.cc
@@ -1,11 +1,18 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
#include "cast/standalone_receiver/decoder.h"
+#include <algorithm>
#include <sstream>
+#include <thread> // NOLINT
#include "util/logging.h"
+#include "util/trace_logging.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
Decoder::Buffer::Buffer() {
Resize(0);
@@ -36,16 +43,14 @@ absl::Span<uint8_t> Decoder::Buffer::GetSpan() {
Decoder::Client::Client() = default;
Decoder::Client::~Client() = default;
-Decoder::Decoder() = default;
+Decoder::Decoder(const std::string& codec_name) : codec_name_(codec_name) {}
+
Decoder::~Decoder() = default;
void Decoder::Decode(FrameId frame_id, const Decoder::Buffer& buffer) {
- if (!codec_) {
- InitFromFirstBuffer(buffer);
- if (!codec_) {
- OnError("unable to detect codec", AVERROR(EINVAL), frame_id);
- return;
- }
+ TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver);
+ if (!codec_ && !Initialize()) {
+ return;
}
// Parse the buffer for the required metadata and the packet to send to the
@@ -95,24 +100,64 @@ void Decoder::Decode(FrameId frame_id, const Decoder::Buffer& buffer) {
}
}
-void Decoder::InitFromFirstBuffer(const Buffer& buffer) {
- const AVCodecID codec_id = Detect(buffer);
- if (codec_id == AV_CODEC_ID_NONE) {
- return;
+bool Decoder::Initialize() {
+ TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver);
+ // NOTE: The codec_name values found in OFFER messages, such as "vp8" or
+ // "h264" or "opus" are valid input strings to FFMPEG's look-up function, so
+ // no translation is required here.
+ codec_ = avcodec_find_decoder_by_name(codec_name_.c_str());
+ if (!codec_) {
+ HandleInitializationError("codec not available", AVERROR(EINVAL));
+ return false;
+ }
+ OSP_LOG_INFO << "Found codec: " << codec_name_ << " (known to FFMPEG as "
+ << avcodec_get_name(codec_->id) << ')';
+
+ parser_ = MakeUniqueAVCodecParserContext(codec_->id);
+ if (!parser_) {
+ HandleInitializationError("failed to allocate parser context",
+ AVERROR(ENOMEM));
+ return false;
}
- codec_ = avcodec_find_decoder(codec_id);
- OSP_CHECK(codec_);
- parser_ = MakeUniqueAVCodecParserContext(codec_id);
- OSP_CHECK(parser_);
context_ = MakeUniqueAVCodecContext(codec_);
- OSP_CHECK(context_);
+ if (!context_) {
+ HandleInitializationError("failed to allocate codec context",
+ AVERROR(ENOMEM));
+ return false;
+ }
+
+ // This should always be greater than zero, so that decoding doesn't block the
+ // main thread of this receiver app and cause playback timing issues. The
+ // actual number should be tuned, based on the number of CPU cores.
+ //
+ // This should also be 16 or less, since the encoder implementations emit
+ // warnings about too many encode threads. FFMPEG's VP8 implementation
+ // actually silently freezes if this is 10 or more. Thus, 8 is used for the
+ // max here, just to be safe.
+ //
+ // TODO(jophba): determine a better number after running benchmarking.
+ context_->thread_count =
+ std::min(std::max<int>(std::thread::hardware_concurrency(), 1), 8);
const int open_result = avcodec_open2(context_.get(), codec_, nullptr);
- OSP_CHECK_EQ(open_result, 0);
+ if (open_result < 0) {
+ HandleInitializationError("failed to open codec", open_result);
+ return false;
+ }
+
packet_ = MakeUniqueAVPacket();
- OSP_CHECK(packet_);
+ if (!packet_) {
+ HandleInitializationError("failed to allocate AVPacket", AVERROR(ENOMEM));
+ return false;
+ }
+
decoded_frame_ = MakeUniqueAVFrame();
- OSP_CHECK(decoded_frame_);
+ if (!decoded_frame_) {
+ HandleInitializationError("failed to allocate AVFrame", AVERROR(ENOMEM));
+ return false;
+ }
+
+ return true;
}
FrameId Decoder::DidReceiveFrameFromDecoder() {
@@ -123,6 +168,26 @@ FrameId Decoder::DidReceiveFrameFromDecoder() {
return frame_id;
}
+void Decoder::HandleInitializationError(const char* what, int av_errnum) {
+ // If the codec was found, get FFMPEG's canonical name for it.
+ const char* const canonical_name =
+ codec_ ? avcodec_get_name(codec_->id) : nullptr;
+
+ codec_ = nullptr; // Set null to mean "not initialized."
+
+ if (!client_) {
+ return; // Nowhere to emit error to, so don't bother.
+ }
+
+ std::ostringstream error;
+ error << "Could not initialize codec " << codec_name_;
+ if (canonical_name) {
+ error << " (known to FFMPEG as " << canonical_name << ')';
+ }
+ error << " because " << what << " (" << av_err2str(av_errnum) << ").";
+ client_->OnFatalError(error.str());
+}
+
void Decoder::OnError(const char* what, int av_errnum, FrameId frame_id) {
if (!client_) {
return;
@@ -133,7 +198,11 @@ void Decoder::OnError(const char* what, int av_errnum, FrameId frame_id) {
if (!frame_id.is_null()) {
error << "frame: " << frame_id << "; ";
}
- error << "what: " << what << "; error: " << av_err2str(av_errnum);
+
+ char human_readable_error[AV_ERROR_MAX_STRING_SIZE]{0};
+ av_make_error_string(human_readable_error, AV_ERROR_MAX_STRING_SIZE,
+ av_errnum);
+ error << "what: " << what << "; error: " << human_readable_error;
// Dispatch to either the fatal error handler, or the one for decode errors,
// as appropriate.
@@ -149,52 +218,5 @@ void Decoder::OnError(const char* what, int av_errnum, FrameId frame_id) {
}
}
-// static
-AVCodecID Decoder::Detect(const Buffer& buffer) {
- static constexpr AVCodecID kCodecsToTry[] = {
- AV_CODEC_ID_VP8, AV_CODEC_ID_VP9, AV_CODEC_ID_H264,
- AV_CODEC_ID_H265, AV_CODEC_ID_OPUS, AV_CODEC_ID_FLAC,
- };
-
- const absl::Span<const uint8_t> input = buffer.GetSpan();
- for (AVCodecID codec_id : kCodecsToTry) {
- AVCodec* const codec = avcodec_find_decoder(codec_id);
- if (!codec) {
- OSP_LOG_WARN << "Video codec not available to try: "
- << avcodec_get_name(codec_id);
- continue;
- }
- const auto parser = MakeUniqueAVCodecParserContext(codec_id);
- if (!parser) {
- OSP_LOG_ERROR << "Failed to init parser for codec: "
- << avcodec_get_name(codec_id);
- continue;
- }
- const auto context = MakeUniqueAVCodecContext(codec);
- if (!context || avcodec_open2(context.get(), codec, nullptr) != 0) {
- OSP_LOG_ERROR << "Failed to open codec: " << avcodec_get_name(codec_id);
- continue;
- }
- const auto packet = MakeUniqueAVPacket();
- OSP_CHECK(packet);
- if (av_parser_parse2(parser.get(), context.get(), &packet->data,
- &packet->size, input.data(), input.size(),
- AV_NOPTS_VALUE, AV_NOPTS_VALUE, 0) < 0 ||
- !packet->data || packet->size == 0) {
- OSP_VLOG << "Does not parse as codec: " << avcodec_get_name(codec_id);
- continue;
- }
- if (avcodec_send_packet(context.get(), packet.get()) < 0) {
- OSP_VLOG << "avcodec_send_packet() failed, probably wrong codec version: "
- << avcodec_get_name(codec_id);
- continue;
- }
- OSP_VLOG << "Detected codec: " << avcodec_get_name(codec_id);
- return codec_id;
- }
-
- return AV_CODEC_ID_NONE;
-}
-
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.h
index 08d61e4335c..1d4d07917eb 100644
--- a/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.h
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.h
@@ -14,10 +14,10 @@
#include "cast/standalone_receiver/avcodec_glue.h"
#include "cast/streaming/frame_id.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
-// Wraps libavcodec to auto-detect and decode audio or video.
+// Wraps libavcodec to decode audio or video.
class Decoder {
public:
// A buffer backed by storage that is compatible with FFMPEG (i.e., includes
@@ -48,7 +48,8 @@ class Decoder {
Client();
};
- Decoder();
+ // |codec_name| should be the codec_name field from an OFFER message.
+ explicit Decoder(const std::string& codec_name);
~Decoder();
Client* client() const { return client_; }
@@ -62,21 +63,23 @@ class Decoder {
void Decode(FrameId frame_id, const Buffer& buffer);
private:
- // Helper to auto-detect the codec being used and initialize the FFMPEG
- // decoder; called for the first frame being decoded.
- void InitFromFirstBuffer(const Buffer& buffer);
+ // Helper to initialize the FFMPEG decoder and supporting objects. Returns
+ // false if this failed (and the Client was notified).
+ bool Initialize();
// Helper to get the FrameId that is associated with the next frame coming out
// of the FFMPEG decoder.
FrameId DidReceiveFrameFromDecoder();
- // Called when any transient or fatal error occurs, generating an
- // openscreen::Error and notifying the Client of it.
- void OnError(const char* what, int av_errnum, FrameId frame_id);
+ // Helper to handle a codec initialization error and notify the Client of the
+ // fatal error.
+ void HandleInitializationError(const char* what, int av_errnum);
- // Auto-detects the codec needed to decode the data in |buffer|.
- static AVCodecID Detect(const Buffer& buffer);
+ // Called when any transient or fatal error occurs, generating an Error and
+ // notifying the Client of it.
+ void OnError(const char* what, int av_errnum, FrameId frame_id);
+ const std::string codec_name_;
AVCodec* codec_ = nullptr;
AVCodecParserContextUniquePtr parser_;
AVCodecContextUniquePtr context_;
@@ -90,7 +93,7 @@ class Decoder {
std::vector<FrameId> frames_decoding_;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STANDALONE_RECEIVER_DECODER_H_
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.cc
index 8d2af1f74ba..cce96ee816b 100644
--- a/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.cc
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.cc
@@ -12,8 +12,8 @@
using std::chrono::microseconds;
+namespace openscreen {
namespace cast {
-namespace streaming {
DummyPlayer::DummyPlayer(Receiver* receiver) : receiver_(receiver) {
OSP_DCHECK(receiver_);
@@ -41,5 +41,5 @@ void DummyPlayer::OnFramesReady(int buffer_size) {
<< buffer_size << " bytes";
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.h
index 373d32329fd..e8db1bf7a49 100644
--- a/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.h
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.h
@@ -13,14 +13,14 @@
#include "platform/api/task_runner.h"
#include "platform/api/time.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Consumes frames from a Receiver, but does nothing other than OSP_LOG_INFO
// each one's FrameId, timestamp and size. This is only useful for confirming a
// Receiver is successfully receiving a stream, for platforms where
// SDLVideoPlayer cannot be built.
-class DummyPlayer : public Receiver::Consumer {
+class DummyPlayer final : public Receiver::Consumer {
public:
explicit DummyPlayer(Receiver* receiver);
@@ -34,7 +34,7 @@ class DummyPlayer : public Receiver::Consumer {
std::vector<uint8_t> buffer_;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STANDALONE_RECEIVER_DUMMY_PLAYER_H_
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/install_demo_deps_debian.sh b/chromium/third_party/openscreen/src/cast/standalone_receiver/install_demo_deps_debian.sh
new file mode 100755
index 00000000000..c082455cd41
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/install_demo_deps_debian.sh
@@ -0,0 +1,7 @@
+#!/usr/bin/env sh
+
+# Installs dependencies necessary for libSDL and libAVcodec on Debian systems.
+
+sudo apt-get install libsdl2-2.0 libsdl2-dev libavcodec libavcodec-dev \
+ libavformat libavformat-dev libavutil libavutil-dev \
+ libswresample libswresample-dev \ No newline at end of file
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/install_demo_deps_raspian.sh b/chromium/third_party/openscreen/src/cast/standalone_receiver/install_demo_deps_raspian.sh
new file mode 100755
index 00000000000..91acaaa6bfb
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/install_demo_deps_raspian.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env sh
+
+# Installs dependencies necessary for libSDL and libAVcodec on
+# Raspberry PI units running Raspian.
+
+sudo apt-get install libavcodec58=7:4.1.4* libavcodec-dev=7:4.1.4* \
+ libsdl2-2.0-0=2.0.9* libsdl2-dev=2.0.9* \
+ libavformat-dev=7:4.1.4*
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/main.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/main.cc
index 76e7a8d2c54..ed2dffd913d 100644
--- a/chromium/third_party/openscreen/src/cast/standalone_receiver/main.cc
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/main.cc
@@ -2,181 +2,193 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
+#include <getopt.h>
+
#include <array>
-#include <chrono>
-#include <thread>
+#include <chrono> // NOLINT
-#include "cast/streaming/constants.h"
-#include "cast/streaming/environment.h"
-#include "cast/streaming/receiver.h"
-#include "cast/streaming/receiver_packet_router.h"
+#include "cast/common/public/service_info.h"
+#include "cast/standalone_receiver/cast_agent.h"
#include "cast/streaming/ssrc.h"
+#include "discovery/common/config.h"
+#include "discovery/common/reporting_client.h"
+#include "discovery/public/dns_sd_service_factory.h"
+#include "discovery/public/dns_sd_service_publisher.h"
#include "platform/api/time.h"
#include "platform/api/udp_socket.h"
#include "platform/base/error.h"
#include "platform/base/ip_address.h"
#include "platform/impl/logging.h"
+#include "platform/impl/network_interface.h"
#include "platform/impl/platform_client_posix.h"
#include "platform/impl/task_runner.h"
+#include "platform/impl/text_trace_logging_platform.h"
+#include "util/stringprintf.h"
+#include "util/trace_logging.h"
-#if defined(CAST_STREAMING_HAVE_EXTERNAL_LIBS_FOR_DEMO_APPS)
-#include "cast/standalone_receiver/sdl_audio_player.h"
-#include "cast/standalone_receiver/sdl_glue.h"
-#include "cast/standalone_receiver/sdl_video_player.h"
-#else
-#include "cast/standalone_receiver/dummy_player.h"
-#endif // defined(CAST_STREAMING_HAVE_EXTERNAL_LIBS_FOR_DEMO_APPS)
-
-using openscreen::IPEndpoint;
-using openscreen::platform::Clock;
-using openscreen::platform::TaskRunner;
-using openscreen::platform::TaskRunnerImpl;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
-////////////////////////////////////////////////////////////////////////////////
-// Receiver Configuration
-//
-// The values defined here are constants that correspond to the Sender Demo app.
-// In a production environment, these should ABSOLUTELY NOT be fixed! Instead a
-// sender↔receiver OFFER/ANSWER exchange should establish them.
-
-const cast::streaming::SessionConfig kSampleAudioAnswerConfig{
- /* .sender_ssrc = */ 1,
- /* .receiver_ssrc = */ 2,
-
- // In a production environment, this would be set to the sampling rate of
- // the audio capture.
- /* .rtp_timebase = */ 48000,
- /* .channels = */ 2,
- /* .aes_secret_key = */
- {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
- 0x0c, 0x0d, 0x0e, 0x0f},
- /* .aes_iv_mask = */
- {0xf0, 0xe0, 0xd0, 0xc0, 0xb0, 0xa0, 0x90, 0x80, 0x70, 0x60, 0x50, 0x40,
- 0x30, 0x20, 0x10, 0x00},
+class DiscoveryReportingClient : public discovery::ReportingClient {
+ void OnFatalError(Error error) override {
+ OSP_LOG_FATAL << "Encountered fatal discovery error: " << error;
+ }
+
+ void OnRecoverableError(Error error) override {
+ OSP_LOG_ERROR << "Encountered recoverable discovery error: " << error;
+ }
};
-const cast::streaming::SessionConfig kSampleVideoAnswerConfig{
- /* .sender_ssrc = */ 50001,
- /* .receiver_ssrc = */ 50002,
- /* .rtp_timebase = */ static_cast<int>(kVideoTimebase::den),
- /* .channels = */ 1,
- /* .aes_secret_key = */
- {0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b,
- 0x1c, 0x1d, 0x1e, 0x1f},
- /* .aes_iv_mask = */
- {0xf1, 0xe1, 0xd1, 0xc1, 0xb1, 0xa1, 0x91, 0x81, 0x71, 0x61, 0x51, 0x41,
- 0x31, 0x21, 0x11, 0x01},
+struct DiscoveryState {
+ SerialDeletePtr<discovery::DnsSdService> service;
+ std::unique_ptr<DiscoveryReportingClient> reporting_client;
+ std::unique_ptr<discovery::DnsSdServicePublisher<ServiceInfo>> publisher;
};
-// In a production environment, this would start-out at some initial value
-// appropriate to the networking environment, and then be adjusted by the
-// application as: 1) the TYPE of the content changes (interactive, low-latency
-// versus smooth, higher-latency buffered video watching); and 2) the networking
-// environment reliability changes.
-constexpr std::chrono::milliseconds kDemoTargetPlayoutDelay{400};
-
-// The UDP socket port receiving packets from the Sender.
-constexpr int kCastStreamingPort = 2344;
-
-// End of Receiver Configuration.
-////////////////////////////////////////////////////////////////////////////////
-
-void DemoMain(TaskRunnerImpl* task_runner) {
- // Create the Environment that holds the required injected dependencies
- // (clock, task runner) used throughout the system, and owns the UDP socket
- // over which all communication occurs with the Sender.
- const IPEndpoint receive_endpoint{openscreen::IPAddress(),
- kCastStreamingPort};
- Environment env(&Clock::now, task_runner, receive_endpoint);
-
- // Create the packet router that allows both the Audio Receiver and the Video
- // Receiver to share the same UDP socket.
- ReceiverPacketRouter packet_router(&env);
-
- // Create the two Receivers.
- Receiver audio_receiver(&env, &packet_router, kSampleAudioAnswerConfig,
- kDemoTargetPlayoutDelay);
- Receiver video_receiver(&env, &packet_router, kSampleVideoAnswerConfig,
- kDemoTargetPlayoutDelay);
-
- OSP_LOG_INFO << "Awaiting first Cast Streaming packet at "
- << env.GetBoundLocalEndpoint() << "...";
-
-#if defined(CAST_STREAMING_HAVE_EXTERNAL_LIBS_FOR_DEMO_APPS)
-
- // Start the SDL event loop, using the task runner to poll/process events.
- const ScopedSDLSubSystem<SDL_INIT_AUDIO> sdl_audio_sub_system;
- const ScopedSDLSubSystem<SDL_INIT_VIDEO> sdl_video_sub_system;
- const SDLEventLoopProcessor sdl_event_loop(
- task_runner, [&] { task_runner->RequestStopSoon(); });
-
- // Create/Initialize the Audio Player and Video Player, which are responsible
- // for decoding and playing out the received media.
- constexpr int kDefaultWindowWidth = 1280;
- constexpr int kDefaultWindowHeight = 720;
- const SDLWindowUniquePtr window = MakeUniqueSDLWindow(
- "Cast Streaming Receiver Demo",
- SDL_WINDOWPOS_UNDEFINED /* initial X position */,
- SDL_WINDOWPOS_UNDEFINED /* initial Y position */, kDefaultWindowWidth,
- kDefaultWindowHeight, SDL_WINDOW_RESIZABLE);
- OSP_CHECK(window) << "Failed to create SDL window: " << SDL_GetError();
- const SDLRendererUniquePtr renderer =
- MakeUniqueSDLRenderer(window.get(), -1, 0);
- OSP_CHECK(renderer) << "Failed to create SDL renderer: " << SDL_GetError();
-
- const SDLAudioPlayer audio_player(
- &Clock::now, task_runner, &audio_receiver, [&] {
- OSP_LOG_ERROR << audio_player.error_status().message();
- task_runner->RequestStopSoon();
- });
- const SDLVideoPlayer video_player(
- &Clock::now, task_runner, &video_receiver, renderer.get(), [&] {
- OSP_LOG_ERROR << video_player.error_status().message();
- task_runner->RequestStopSoon();
- });
-
-#else
-
- const DummyPlayer audio_player(&audio_receiver);
- const DummyPlayer video_player(&video_receiver);
-
-#endif // defined(CAST_STREAMING_HAVE_EXTERNAL_LIBS_FOR_DEMO_APPS)
+ErrorOr<std::unique_ptr<DiscoveryState>> StartDiscovery(
+ TaskRunner* task_runner,
+ const InterfaceInfo& interface) {
+ discovery::Config config;
+
+ config.interface = interface;
+
+ auto state = std::make_unique<DiscoveryState>();
+ state->reporting_client = std::make_unique<DiscoveryReportingClient>();
+ state->service = discovery::CreateDnsSdService(
+ task_runner, state->reporting_client.get(), config);
+
+ // TODO(jophba): update after ServiceInfo update patch lands.
+ ServiceInfo info;
+ info.port = kDefaultCastPort;
+ if (interface.GetIpAddressV4()) {
+ info.v4_address = interface.GetIpAddressV4();
+ }
+ if (interface.GetIpAddressV6()) {
+ info.v6_address = interface.GetIpAddressV6();
+ }
+
+ OSP_CHECK(std::any_of(interface.hardware_address.begin(),
+ interface.hardware_address.end(),
+ [](int e) { return e > 0; }));
+ info.unique_id = HexEncode(interface.hardware_address);
+
+ // TODO(jophba): add command line arguments to set these fields.
+ info.model_name = "cast_standalone_receiver";
+ info.friendly_name = "Cast Standalone Receiver";
+
+ state->publisher =
+ std::make_unique<discovery::DnsSdServicePublisher<ServiceInfo>>(
+ state->service.get(), kCastV2ServiceId, ServiceInfoToDnsSdRecord);
+
+ auto error = state->publisher->Register(info);
+ if (!error.ok()) {
+ return error;
+ }
+ return state;
+}
+
+void RunStandaloneReceiver(TaskRunnerImpl* task_runner,
+ InterfaceInfo interface) {
+ CastAgent agent(task_runner, interface);
+ const auto error = agent.Start();
+ if (!error.ok()) {
+ OSP_LOG_ERROR << "Error occurred while starting agent: " << error;
+ return;
+ }
// Run the event loop until an exit is requested (e.g., the video player GUI
- // window is closed, a SIGTERM is intercepted, or whatever other appropriate
- // user indication that shutdown is requested).
- task_runner->RunUntilStopped();
+ // window is closed, a SIGINT or SIGTERM is received, or whatever other
+ // appropriate user indication that shutdown is requested).
+ task_runner->RunUntilSignaled();
}
} // namespace
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
+
+namespace {
+
+void LogUsage(const char* argv0) {
+ constexpr char kExecutableTag[] = "argv[0]";
+ constexpr char kUsageMessage[] = R"(
+ usage: argv[0] <options> <interface>
-int main(int argc, const char* argv[]) {
- using openscreen::platform::PlatformClientPosix;
-
- class PlatformClientExposingTaskRunner : public PlatformClientPosix {
- public:
- explicit PlatformClientExposingTaskRunner(
- std::unique_ptr<TaskRunner> task_runner)
- : PlatformClientPosix(Clock::duration{50},
- Clock::duration{50},
- std::move(task_runner)) {
- SetInstance(this);
+ options:
+ <interface>: Specify the network interface to bind to. The interface is
+ looked up from the system interface registry. This argument is
+ mandatory, as it must be known for publishing discovery.
+
+ -t, --tracing: Enable performance tracing logging.
+
+ -h, --help: Show this help message.
+ )";
+ std::string message = kUsageMessage;
+ message.replace(message.find(kExecutableTag), strlen(kExecutableTag), argv0);
+ OSP_LOG_INFO << message;
+}
+
+} // namespace
+
+int main(int argc, char* argv[]) {
+ // TODO(jophba): refactor into separate method and make main a one-liner.
+ using openscreen::Clock;
+ using openscreen::ErrorOr;
+ using openscreen::InterfaceInfo;
+ using openscreen::IPAddress;
+ using openscreen::IPEndpoint;
+ using openscreen::PlatformClientPosix;
+ using openscreen::TaskRunnerImpl;
+
+ openscreen::SetLogLevel(openscreen::LogLevel::kInfo);
+
+ const struct option argument_options[] = {
+ {"tracing", no_argument, nullptr, 't'},
+ {"help", no_argument, nullptr, 'h'},
+ {nullptr, 0, nullptr, 0}};
+
+ InterfaceInfo interface_info;
+ std::unique_ptr<openscreen::TextTraceLoggingPlatform> trace_logger;
+ int ch = -1;
+ while ((ch = getopt_long(argc, argv, "th", argument_options, nullptr)) !=
+ -1) {
+ switch (ch) {
+ case 't':
+ trace_logger = std::make_unique<openscreen::TextTraceLoggingPlatform>();
+ break;
+ case 'h':
+ LogUsage(argv[0]);
+ return 1;
}
- };
+ }
+ char* interface_argument = argv[optind];
+ OSP_CHECK(interface_argument != nullptr)
+ << "Missing mandatory argument: interface.";
+ std::vector<InterfaceInfo> network_interfaces =
+ openscreen::GetNetworkInterfaces();
+ for (auto& interface : network_interfaces) {
+ if (interface.name == interface_argument) {
+ interface_info = std::move(interface);
+ break;
+ }
+ }
+ OSP_CHECK(!interface_info.name.empty()) << "Invalid interface specified.";
+
+ auto* const task_runner = new TaskRunnerImpl(&Clock::now);
+ PlatformClientPosix::Create(Clock::duration{50}, Clock::duration{50},
+ std::unique_ptr<TaskRunnerImpl>(task_runner));
- openscreen::platform::SetLogLevel(openscreen::platform::LogLevel::kInfo);
- auto* const platform_client = new PlatformClientExposingTaskRunner(
- std::make_unique<TaskRunnerImpl>(&Clock::now));
+ auto discovery_state =
+ openscreen::cast::StartDiscovery(task_runner, interface_info);
+ OSP_CHECK(discovery_state.is_value()) << "Failed to start discovery.";
- cast::streaming::DemoMain(static_cast<TaskRunnerImpl*>(
- PlatformClientPosix::GetInstance()->GetTaskRunner()));
+ // Runs until the process is interrupted. Safe to pass |task_runner| as it
+ // will not be destroyed by ShutDown() until this exits.
+ openscreen::cast::RunStandaloneReceiver(task_runner, interface_info);
- platform_client->ShutDown(); // Deletes |platform_client|.
+ // The task runner must be deleted after all serial delete pointers, such
+ // as the one stored in the discovery state.
+ discovery_state.value().reset();
+ PlatformClientPosix::ShutDown();
return 0;
}
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/private.der b/chromium/third_party/openscreen/src/cast/standalone_receiver/private.der
new file mode 100644
index 00000000000..48aa17b0151
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/private.der
Binary files differ
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/private_key_der.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/private_key_der.h
new file mode 100644
index 00000000000..0d9a328f01c
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/private_key_der.h
@@ -0,0 +1,125 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_STANDALONE_RECEIVER_PRIVATE_KEY_DER_H_
+#define CAST_STANDALONE_RECEIVER_PRIVATE_KEY_DER_H_
+
+#include <array>
+
+namespace openscreen {
+namespace cast {
+
+// Important note about private keys and security: For example usage purposes,
+// we have checked in a default private key here. However, in a production
+// environment keys should never be checked into source control. This is an
+// example self-signed private key for TLS.
+//
+// Generated using the following command:
+// $ xxd -i <path/to/private_key.der>
+std::array<uint8_t, 1192> kPrivateKeyDer = {
+ 0x30, 0x82, 0x04, 0xa4, 0x02, 0x01, 0x00, 0x02, 0x82, 0x01, 0x01, 0x00,
+ 0xb8, 0x07, 0xbb, 0x3f, 0xde, 0x47, 0xba, 0xec, 0x7a, 0xc2, 0x6e, 0xe5,
+ 0x44, 0x4a, 0xa8, 0x50, 0xd9, 0xef, 0xc1, 0x67, 0xc1, 0x5e, 0x14, 0xc0,
+ 0x1a, 0x1f, 0x8e, 0x83, 0xb9, 0xb2, 0x5d, 0x2d, 0x74, 0x12, 0x4b, 0x43,
+ 0x3c, 0xa6, 0xfb, 0xc4, 0x6c, 0xb0, 0xab, 0xda, 0xfb, 0xfd, 0x51, 0x18,
+ 0xc1, 0xc1, 0x22, 0x05, 0xf5, 0x2b, 0x10, 0xe4, 0x84, 0x1b, 0xa7, 0xdc,
+ 0xe6, 0xd0, 0xf5, 0x64, 0xa7, 0xc7, 0x6f, 0x1a, 0x85, 0xf8, 0xc4, 0xfb,
+ 0xbf, 0x17, 0x54, 0xaa, 0x11, 0x29, 0xfe, 0xda, 0x33, 0x9e, 0xfa, 0xb1,
+ 0x86, 0x9b, 0xf4, 0xcd, 0xf5, 0xe1, 0x3f, 0x7b, 0x3b, 0xad, 0xd9, 0xe2,
+ 0xc7, 0x6e, 0x4f, 0x1e, 0xa8, 0x13, 0x22, 0xa2, 0x7a, 0xcf, 0xe1, 0x8a,
+ 0x06, 0xf3, 0x28, 0x3a, 0xdc, 0xd3, 0x8c, 0x24, 0xa6, 0xe0, 0xd3, 0x5a,
+ 0x23, 0x21, 0x53, 0x02, 0x7d, 0x08, 0x30, 0xcb, 0xf1, 0x21, 0xca, 0x72,
+ 0x69, 0x49, 0x6e, 0x0f, 0xbc, 0x03, 0x7e, 0x0e, 0x60, 0x5d, 0x92, 0x08,
+ 0x3f, 0x04, 0x76, 0x62, 0x2d, 0x4b, 0xeb, 0x61, 0xaa, 0xe6, 0xcd, 0x2f,
+ 0x28, 0x24, 0xb0, 0xe8, 0xa5, 0xfe, 0x89, 0x90, 0xb8, 0xa4, 0x60, 0x6e,
+ 0x4c, 0x8c, 0x2f, 0xa3, 0xae, 0x72, 0xf2, 0x42, 0xb9, 0xc9, 0xa2, 0x6f,
+ 0x91, 0xbc, 0x75, 0x3b, 0x35, 0xb7, 0xe6, 0x24, 0xcb, 0x80, 0x8a, 0x34,
+ 0xfa, 0x9d, 0xf1, 0x7c, 0x88, 0x98, 0x09, 0x7b, 0x50, 0x56, 0xa5, 0x84,
+ 0x9c, 0x5f, 0x6c, 0x6e, 0x10, 0xfa, 0x95, 0xb6, 0xbf, 0xe4, 0xb0, 0x55,
+ 0x29, 0x3d, 0xe6, 0x3d, 0x14, 0xd7, 0x70, 0x17, 0xd8, 0xd3, 0xaa, 0xaf,
+ 0x4f, 0x15, 0x99, 0x63, 0xd0, 0x74, 0xfc, 0xb0, 0x6b, 0x66, 0x28, 0x02,
+ 0xd1, 0xbb, 0x01, 0x57, 0x02, 0xfe, 0x52, 0xe2, 0x0b, 0xbd, 0x8c, 0x0a,
+ 0x87, 0x8b, 0x60, 0xe9, 0x02, 0x03, 0x01, 0x00, 0x01, 0x02, 0x82, 0x01,
+ 0x01, 0x00, 0xaa, 0x19, 0x7b, 0x5a, 0x6d, 0x7a, 0x9f, 0xac, 0x35, 0x4b,
+ 0xc2, 0x74, 0xe7, 0xca, 0x9a, 0x09, 0x21, 0x68, 0x1a, 0xbc, 0x6c, 0x5f,
+ 0x29, 0x8e, 0xe6, 0x96, 0x84, 0x83, 0xfd, 0x00, 0x80, 0x5f, 0xa3, 0x09,
+ 0xc5, 0xc7, 0x40, 0x28, 0x98, 0x4d, 0xd6, 0xa8, 0xf6, 0x30, 0x52, 0xfa,
+ 0xb2, 0x1a, 0xcf, 0xfc, 0x54, 0x16, 0x6d, 0xa6, 0x80, 0xd6, 0xb7, 0xc5,
+ 0x58, 0x43, 0x36, 0x95, 0xae, 0x3c, 0x7b, 0x58, 0x3b, 0xb9, 0xa8, 0x5b,
+ 0x68, 0xb7, 0xc8, 0xc9, 0x27, 0xd8, 0x8a, 0x44, 0xe6, 0xeb, 0x89, 0x0b,
+ 0x49, 0x6d, 0x0d, 0x9e, 0xd9, 0x88, 0x05, 0xdd, 0x4d, 0x6f, 0xfa, 0x99,
+ 0x96, 0xeb, 0xa6, 0xaa, 0xaf, 0x37, 0x06, 0xe3, 0xa8, 0xff, 0xc5, 0xc4,
+ 0xa0, 0x13, 0x94, 0x98, 0xec, 0x76, 0x7b, 0xe6, 0x8d, 0x82, 0xd3, 0x3c,
+ 0xbc, 0x1e, 0x74, 0x9a, 0x38, 0xbf, 0xf4, 0x11, 0xbe, 0x07, 0x32, 0x2d,
+ 0x16, 0x2c, 0xf2, 0x5d, 0x24, 0x38, 0x70, 0xfb, 0x90, 0x8a, 0x38, 0xd6,
+ 0x17, 0xe1, 0x66, 0x92, 0x38, 0x06, 0x97, 0xb3, 0x07, 0xfd, 0x77, 0xe2,
+ 0xe7, 0x49, 0xae, 0x5a, 0xbc, 0xe5, 0xa8, 0xca, 0xe1, 0x0f, 0xb6, 0x4c,
+ 0x05, 0x73, 0x3f, 0x11, 0xd0, 0xf9, 0x1e, 0xba, 0x53, 0x48, 0xf5, 0xaf,
+ 0x28, 0x5b, 0xea, 0x12, 0x63, 0xbc, 0x84, 0xa7, 0x5f, 0x2e, 0x1d, 0x3e,
+ 0x02, 0x54, 0x58, 0xed, 0x2b, 0x42, 0xf9, 0xc6, 0x0c, 0xd4, 0x24, 0x77,
+ 0x1a, 0x2c, 0xbf, 0x75, 0x92, 0xf7, 0xcb, 0xd4, 0x58, 0x2f, 0x88, 0x2d,
+ 0xe8, 0x16, 0xca, 0xe5, 0x25, 0xe8, 0x5b, 0xbd, 0x53, 0x26, 0x23, 0xe0,
+ 0xa9, 0x35, 0x4d, 0xdb, 0x51, 0x85, 0x63, 0x20, 0xad, 0x61, 0xd2, 0x6d,
+ 0xbf, 0x01, 0x7d, 0x04, 0x44, 0x02, 0x96, 0x92, 0x36, 0x19, 0xed, 0xd1,
+ 0xd8, 0x16, 0x86, 0x06, 0xd4, 0x81, 0x02, 0x81, 0x81, 0x00, 0xe1, 0xa6,
+ 0xca, 0xb3, 0xef, 0xfe, 0x9f, 0xd6, 0xac, 0x58, 0x5c, 0x17, 0x88, 0xaf,
+ 0x4d, 0x85, 0x29, 0x50, 0x1f, 0x66, 0x90, 0x9b, 0x81, 0xb6, 0x82, 0x0d,
+ 0xc3, 0x5a, 0xa8, 0x8a, 0x2b, 0x7f, 0x58, 0x9b, 0x07, 0xe6, 0x64, 0xf7,
+ 0x1c, 0x77, 0x9d, 0x53, 0x97, 0xa0, 0x33, 0x14, 0x6e, 0x77, 0x1e, 0xe3,
+ 0x00, 0x0f, 0xb2, 0xb1, 0x69, 0x25, 0x3d, 0x63, 0x3c, 0xe1, 0xbb, 0x41,
+ 0x74, 0x97, 0x2d, 0x5e, 0x14, 0x79, 0x93, 0x38, 0x15, 0xbe, 0x52, 0x74,
+ 0x64, 0xc0, 0xfd, 0x22, 0x8e, 0xd7, 0xc9, 0xfb, 0x66, 0x55, 0xce, 0x5b,
+ 0x6a, 0x6f, 0x00, 0xed, 0x03, 0x7e, 0x4b, 0x9c, 0x4b, 0x8b, 0x3a, 0x50,
+ 0x65, 0x0d, 0x70, 0x9b, 0xdb, 0xf7, 0x1f, 0xd7, 0x66, 0x7a, 0xd1, 0x1e,
+ 0xa0, 0x8f, 0xe6, 0x03, 0x12, 0x18, 0x52, 0x25, 0x41, 0xa7, 0xb9, 0x8e,
+ 0x75, 0x63, 0x11, 0xd2, 0x63, 0xd7, 0x02, 0x81, 0x81, 0x00, 0xd0, 0xc7,
+ 0xe9, 0x97, 0x38, 0x33, 0x95, 0xbd, 0x18, 0xa5, 0x0a, 0x68, 0xab, 0xba,
+ 0x5e, 0x3e, 0x1f, 0x16, 0x86, 0xc0, 0x50, 0x09, 0xab, 0x52, 0xb7, 0x62,
+ 0x4e, 0x34, 0xb1, 0xc1, 0xd3, 0xb5, 0xf4, 0xe0, 0x04, 0x30, 0xa6, 0xdd,
+ 0x4a, 0xba, 0x7c, 0x59, 0xed, 0xd7, 0x76, 0xd3, 0x02, 0xe7, 0x05, 0x18,
+ 0x00, 0xdb, 0x65, 0xf2, 0x82, 0xe4, 0xfa, 0xbf, 0x9d, 0xad, 0x1a, 0x56,
+ 0x7b, 0x5e, 0xef, 0xff, 0x9b, 0xe5, 0x2f, 0x7c, 0xdd, 0x50, 0x53, 0x2b,
+ 0x6b, 0xc0, 0xac, 0x7b, 0x21, 0x8d, 0xc3, 0x39, 0xfe, 0xd0, 0x1a, 0xed,
+ 0xd1, 0xb6, 0x56, 0xda, 0x9e, 0x87, 0x9a, 0x6a, 0x69, 0x81, 0x29, 0x81,
+ 0x75, 0x69, 0xa6, 0x25, 0xc2, 0xf7, 0x5a, 0x94, 0x97, 0x6a, 0x7a, 0xf9,
+ 0x6c, 0xbe, 0x43, 0x76, 0x34, 0xba, 0x0c, 0x50, 0x6d, 0x22, 0xe8, 0xa6,
+ 0x9c, 0x80, 0x62, 0x87, 0xc9, 0x3f, 0x02, 0x81, 0x80, 0x78, 0xaf, 0x47,
+ 0x1c, 0x63, 0x90, 0x30, 0x16, 0x95, 0x88, 0x90, 0x80, 0x79, 0xb7, 0x20,
+ 0x63, 0xc6, 0xcb, 0xb6, 0x6f, 0x99, 0x89, 0xc2, 0x1f, 0x45, 0x81, 0x6c,
+ 0xe9, 0x10, 0xd9, 0x0d, 0x18, 0x87, 0xe0, 0x2a, 0xa2, 0x7b, 0x7f, 0x7a,
+ 0x77, 0x32, 0xea, 0xa1, 0x5e, 0xa9, 0xd3, 0x14, 0x9d, 0x9b, 0x24, 0x57,
+ 0x45, 0x0e, 0x12, 0x3a, 0xa5, 0x13, 0x26, 0xff, 0x49, 0xcf, 0x67, 0xdb,
+ 0x9e, 0x7b, 0x42, 0x24, 0xfb, 0x3c, 0xd4, 0xb3, 0x34, 0x5e, 0x4f, 0x28,
+ 0x0f, 0xdb, 0x92, 0xdf, 0x08, 0xe4, 0x5b, 0x13, 0xc9, 0x72, 0x9b, 0x8b,
+ 0xda, 0x20, 0x89, 0xa2, 0xe3, 0xaa, 0x36, 0xc6, 0x64, 0x89, 0x64, 0xb4,
+ 0x17, 0x33, 0x11, 0xf8, 0xdc, 0x3b, 0xe8, 0x6d, 0x43, 0xe4, 0x92, 0x57,
+ 0xd7, 0x7e, 0x72, 0x47, 0xfc, 0x3f, 0xfa, 0xf3, 0x19, 0x6c, 0x71, 0x97,
+ 0xb0, 0xcb, 0xb8, 0x55, 0x73, 0x02, 0x81, 0x81, 0x00, 0x98, 0xf1, 0xfa,
+ 0x73, 0x67, 0x1e, 0x93, 0x11, 0x45, 0xde, 0x91, 0xb3, 0x80, 0x2a, 0x35,
+ 0x23, 0xf9, 0x0e, 0x3d, 0x84, 0xe0, 0x9d, 0x54, 0xbe, 0x71, 0xcd, 0x38,
+ 0x51, 0x6d, 0xee, 0xfa, 0x33, 0x0f, 0xc2, 0x94, 0x0f, 0x38, 0x0e, 0x60,
+ 0xd2, 0x20, 0x8a, 0x98, 0xac, 0x01, 0x46, 0x2f, 0x98, 0x21, 0xa9, 0x25,
+ 0xe7, 0x93, 0xd5, 0x86, 0x82, 0x4c, 0x16, 0xd7, 0x61, 0x9a, 0x2b, 0xc4,
+ 0x91, 0x15, 0xec, 0x00, 0xbe, 0x72, 0x7d, 0x5c, 0x7b, 0x9d, 0x91, 0xef,
+ 0x8b, 0xe4, 0x4f, 0x07, 0x93, 0x9c, 0x72, 0xfd, 0xf2, 0x61, 0xe7, 0xda,
+ 0x7b, 0x63, 0x41, 0x20, 0x65, 0x62, 0x7f, 0x95, 0xee, 0xa3, 0x03, 0x4d,
+ 0x8a, 0x29, 0xc6, 0xfb, 0xfe, 0xcc, 0x82, 0x92, 0x31, 0xd5, 0x08, 0xa7,
+ 0xda, 0xf1, 0xfc, 0xc4, 0x3f, 0x8f, 0x09, 0xd4, 0x09, 0x80, 0xb9, 0x9d,
+ 0x68, 0x87, 0xc5, 0xc5, 0x6d, 0x02, 0x81, 0x80, 0x1f, 0xd9, 0x20, 0xde,
+ 0xba, 0xcd, 0x63, 0x34, 0x4f, 0x9f, 0xbb, 0x05, 0x0a, 0x8d, 0x20, 0xe1,
+ 0x66, 0x41, 0x2f, 0xae, 0xc7, 0xfa, 0x5d, 0xfd, 0xb7, 0x2a, 0x0f, 0xa6,
+ 0x6d, 0xf3, 0xad, 0x65, 0x54, 0x75, 0x2c, 0x26, 0x1e, 0xac, 0x1f, 0x24,
+ 0x4c, 0x83, 0xe3, 0x28, 0x08, 0x60, 0x74, 0xfe, 0xa9, 0x53, 0x36, 0x1e,
+ 0xb3, 0x39, 0x9d, 0xe7, 0x49, 0x03, 0x66, 0x61, 0xe8, 0xd4, 0xf4, 0xd8,
+ 0x65, 0x57, 0x01, 0xed, 0xaa, 0x7b, 0x6b, 0x04, 0xa2, 0x5f, 0xe1, 0x67,
+ 0xe6, 0x06, 0x7c, 0x84, 0x2a, 0x7d, 0x53, 0x03, 0x1c, 0x9c, 0x82, 0x08,
+ 0x37, 0x07, 0xaf, 0x77, 0xe0, 0x99, 0x69, 0xce, 0x01, 0x5a, 0x85, 0x4b,
+ 0x27, 0xb9, 0xb2, 0x20, 0x8c, 0xa5, 0xb9, 0x42, 0x2f, 0xad, 0x56, 0xdd,
+ 0xb9, 0x0d, 0x23, 0x05, 0x53, 0x5a, 0x26, 0x3a, 0xe2, 0x17, 0x58, 0x79,
+ 0x96, 0x8c, 0x5a, 0x05};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_STANDALONE_RECEIVER_PRIVATE_KEY_DER_H_
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.cc
index 8a58870eeea..d87bbdf396f 100644
--- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.cc
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.cc
@@ -11,19 +11,14 @@
#include "cast/standalone_receiver/avcodec_glue.h"
#include "util/big_endian.h"
#include "util/logging.h"
+#include "util/trace_logging.h"
using std::chrono::duration_cast;
using std::chrono::milliseconds;
using std::chrono::seconds;
-using openscreen::Error;
-using openscreen::ErrorOr;
-using openscreen::platform::Clock;
-using openscreen::platform::ClockNowFunctionPtr;
-using openscreen::platform::TaskRunner;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
@@ -44,6 +39,7 @@ void InterleaveAudioSamples(const uint8_t* const planes[],
int num_channels,
int num_samples,
uint8_t* interleaved) {
+ TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver);
// Note: This could be optimized with SIMD intrinsics for much better
// performance.
auto* dest = reinterpret_cast<Element*>(interleaved);
@@ -61,10 +57,12 @@ void InterleaveAudioSamples(const uint8_t* const planes[],
SDLAudioPlayer::SDLAudioPlayer(ClockNowFunctionPtr now_function,
TaskRunner* task_runner,
Receiver* receiver,
+ const std::string& codec_name,
std::function<void()> error_callback)
: SDLPlayerBase(now_function,
task_runner,
receiver,
+ codec_name,
std::move(error_callback),
kAudioMediaType) {}
@@ -76,6 +74,7 @@ SDLAudioPlayer::~SDLAudioPlayer() {
ErrorOr<Clock::time_point> SDLAudioPlayer::RenderNextFrame(
const SDLPlayerBase::PresentableFrame& next_frame) {
+ TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver);
OSP_DCHECK(next_frame.decoded_frame);
const AVFrame& frame = *next_frame.decoded_frame;
@@ -171,6 +170,7 @@ bool SDLAudioPlayer::RenderWhileIdle(const PresentableFrame* frame) {
}
void SDLAudioPlayer::Present() {
+ TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver);
if (state() != kScheduledToPresent) {
// In all other states, just do nothing. The SDL audio buffer will underrun
// and result in silence.
@@ -207,8 +207,6 @@ void SDLAudioPlayer::Present() {
// static
SDL_AudioFormat SDLAudioPlayer::GetSDLAudioFormat(AVSampleFormat format) {
- using openscreen::IsBigEndianArchitecture;
-
switch (format) {
case AV_SAMPLE_FMT_U8P:
case AV_SAMPLE_FMT_U8:
@@ -234,5 +232,5 @@ SDL_AudioFormat SDLAudioPlayer::GetSDLAudioFormat(AVSampleFormat format) {
return kSDLAudioFormatUnknown;
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.h
index d7a4b0ea3e7..8788d1f5c05 100644
--- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.h
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.h
@@ -5,27 +5,31 @@
#ifndef CAST_STANDALONE_RECEIVER_SDL_AUDIO_PLAYER_H_
#define CAST_STANDALONE_RECEIVER_SDL_AUDIO_PLAYER_H_
+#include <string>
+#include <vector>
+
#include "cast/standalone_receiver/sdl_player_base.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Consumes frames from a Receiver, decodes them, and renders them to an
// internally-owned SDL audio device.
-class SDLAudioPlayer : public SDLPlayerBase {
+class SDLAudioPlayer final : public SDLPlayerBase {
public:
// |error_callback| is run only if a fatal error occurs, at which point the
// player has halted and set |error_status()|.
- SDLAudioPlayer(openscreen::platform::ClockNowFunctionPtr now_function,
- openscreen::platform::TaskRunner* task_runner,
+ SDLAudioPlayer(ClockNowFunctionPtr now_function,
+ TaskRunner* task_runner,
Receiver* receiver,
+ const std::string& codec_name,
std::function<void()> error_callback);
~SDLAudioPlayer() final;
private:
// SDLPlayerBase implementation.
- openscreen::ErrorOr<openscreen::platform::Clock::time_point> RenderNextFrame(
+ ErrorOr<Clock::time_point> RenderNextFrame(
const SDLPlayerBase::PresentableFrame& frame) final;
bool RenderWhileIdle(const SDLPlayerBase::PresentableFrame* frame) final;
void Present() final;
@@ -38,7 +42,7 @@ class SDLAudioPlayer : public SDLPlayerBase {
// The amount of time before a target presentation time to call Present(), to
// account for audio buffering (the latency until samples reach the hardware).
- openscreen::platform::Clock::duration approximate_lead_time_{};
+ Clock::duration approximate_lead_time_{};
// When the decoder provides planar data, this buffer is used for storing the
// interleaved conversion.
@@ -54,7 +58,7 @@ class SDLAudioPlayer : public SDLPlayerBase {
SDL_AudioSpec device_spec_{};
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STANDALONE_RECEIVER_SDL_AUDIO_PLAYER_H_
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.cc
index 4d12b0579d0..a77c1bf29c9 100644
--- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.cc
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.cc
@@ -8,18 +8,15 @@
#include "platform/api/time.h"
#include "util/logging.h"
-using openscreen::platform::Clock;
-using openscreen::platform::TaskRunner;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
SDLEventLoopProcessor::SDLEventLoopProcessor(
TaskRunner* task_runner,
std::function<void()> quit_callback)
: alarm_(&Clock::now, task_runner),
quit_callback_(std::move(quit_callback)) {
- alarm_.Schedule([this] { ProcessPendingEvents(); }, {});
+ alarm_.Schedule([this] { ProcessPendingEvents(); }, Alarm::kImmediately);
}
SDLEventLoopProcessor::~SDLEventLoopProcessor() = default;
@@ -38,9 +35,8 @@ void SDLEventLoopProcessor::ProcessPendingEvents() {
// Schedule a task to come back and process more pending events.
constexpr auto kEventPollPeriod = std::chrono::milliseconds(10);
- alarm_.Schedule([this] { ProcessPendingEvents(); },
- Clock::now() + kEventPollPeriod);
+ alarm_.ScheduleFromNow([this] { ProcessPendingEvents(); }, kEventPollPeriod);
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.h
index 2900bf57557..59a3a02066c 100644
--- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.h
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_glue.h
@@ -18,13 +18,10 @@
#include "util/alarm.h"
namespace openscreen {
-namespace platform {
+
class TaskRunner;
-} // namespace platform
-} // namespace openscreen
namespace cast {
-namespace streaming {
template <uint32_t subsystem>
class ScopedSDLSubSystem {
@@ -65,18 +62,18 @@ DEFINE_SDL_UNIQUE_PTR(Texture);
// event is received.
class SDLEventLoopProcessor {
public:
- SDLEventLoopProcessor(openscreen::platform::TaskRunner* task_runner,
+ SDLEventLoopProcessor(TaskRunner* task_runner,
std::function<void()> quit_callback);
~SDLEventLoopProcessor();
private:
void ProcessPendingEvents();
- openscreen::Alarm alarm_;
+ Alarm alarm_;
std::function<void()> quit_callback_;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STANDALONE_RECEIVER_SDL_GLUE_H_
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.cc
index 0c5e6c9c787..2d2f7b61db8 100644
--- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.cc
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.cc
@@ -12,28 +12,25 @@
#include "cast/streaming/encoded_frame.h"
#include "util/big_endian.h"
#include "util/logging.h"
+#include "util/trace_logging.h"
using std::chrono::duration_cast;
using std::chrono::milliseconds;
-using openscreen::Error;
-using openscreen::ErrorOr;
-using openscreen::platform::Clock;
-using openscreen::platform::ClockNowFunctionPtr;
-using openscreen::platform::TaskRunner;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
SDLPlayerBase::SDLPlayerBase(ClockNowFunctionPtr now_function,
TaskRunner* task_runner,
Receiver* receiver,
+ const std::string& codec_name,
std::function<void()> error_callback,
const char* media_type)
: now_(now_function),
receiver_(receiver),
error_callback_(std::move(error_callback)),
media_type_(media_type),
+ decoder_(codec_name),
decode_alarm_(now_, task_runner),
render_alarm_(now_, task_runner),
presentation_alarm_(now_, task_runner) {
@@ -69,35 +66,16 @@ void SDLPlayerBase::OnFatalError(std::string message) {
}
}
-void SDLPlayerBase::OnFramesReady(int buffer_size) {
- // Do not consume anything if there are too many frames in the pipeline
- // already.
- if (static_cast<int>(frames_to_render_.size()) > kMaxFramesInPipeline) {
- return;
- }
-
- // Consume the next frame.
- const Clock::time_point start_time = now_();
- buffer_.Resize(buffer_size);
- EncodedFrame frame = receiver_->ConsumeNextFrame(buffer_.GetSpan());
-
- // Create the tracking state for the frame in the player pipeline.
- OSP_DCHECK_EQ(frames_to_render_.count(frame.frame_id), 0);
- PendingFrame& pending_frame = frames_to_render_[frame.frame_id];
- pending_frame.start_time = start_time;
-
- // Determine the presentation time of the frame. Ideally, this will occur
- // based on the time progression of the media, given by the RTP timestamps.
- // However, if this falls too far out-of-sync with the system reference clock,
- // re-syrchronize, possibly causing user-visible "jank."
+Clock::time_point SDLPlayerBase::ResyncAndDeterminePresentationTime(
+ const EncodedFrame& frame) {
constexpr auto kMaxPlayoutDrift = milliseconds(100);
const auto media_time_since_last_sync =
(frame.rtp_timestamp - last_sync_rtp_timestamp_)
.ToDuration<Clock::duration>(receiver_->rtp_timebase());
- pending_frame.presentation_time =
+ Clock::time_point presentation_time =
last_sync_reference_time_ + media_time_since_last_sync;
- const auto drift = duration_cast<milliseconds>(
- frame.reference_time - pending_frame.presentation_time);
+ const auto drift =
+ duration_cast<milliseconds>(frame.reference_time - presentation_time);
if (drift > kMaxPlayoutDrift || drift < -kMaxPlayoutDrift) {
// Only log if not the very first frame.
OSP_LOG_IF(INFO, frame.frame_id != FrameId::first())
@@ -109,8 +87,30 @@ void SDLPlayerBase::OnFramesReady(int buffer_size) {
// back into sync over several frames.
last_sync_rtp_timestamp_ = frame.rtp_timestamp;
last_sync_reference_time_ = frame.reference_time;
- pending_frame.presentation_time = frame.reference_time;
+ presentation_time = frame.reference_time;
}
+ return presentation_time;
+}
+
+void SDLPlayerBase::OnFramesReady(int buffer_size) {
+ TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver);
+ // Do not consume anything if there are too many frames in the pipeline
+ // already.
+ if (static_cast<int>(frames_to_render_.size()) > kMaxFramesInPipeline) {
+ return;
+ }
+
+ // Consume the next frame.
+ const Clock::time_point start_time = now_();
+ buffer_.Resize(buffer_size);
+ EncodedFrame frame = receiver_->ConsumeNextFrame(buffer_.GetSpan());
+
+ // Create the tracking state for the frame in the player pipeline.
+ OSP_DCHECK_EQ(frames_to_render_.count(frame.frame_id), 0);
+ PendingFrame& pending_frame = frames_to_render_[frame.frame_id];
+ pending_frame.start_time = start_time;
+
+ pending_frame.presentation_time = ResyncAndDeterminePresentationTime(frame);
// Start decoding the frame. This call may synchronously call back into the
// AVCodecDecoder::Client methods in this class.
@@ -118,6 +118,7 @@ void SDLPlayerBase::OnFramesReady(int buffer_size) {
}
void SDLPlayerBase::OnFrameDecoded(FrameId frame_id, const AVFrame& frame) {
+ TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver);
const auto it = frames_to_render_.find(frame_id);
if (it == frames_to_render_.end()) {
return;
@@ -142,6 +143,7 @@ void SDLPlayerBase::OnDecodeError(FrameId frame_id, std::string message) {
}
void SDLPlayerBase::RenderAndSchedulePresentation() {
+ TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver);
// If something has already been scheduled to present at an exact time point,
// don't render anything new yet.
if (state_ == kScheduledToPresent) {
@@ -159,12 +161,12 @@ void SDLPlayerBase::RenderAndSchedulePresentation() {
// The interval here, is "lengthy" from the program's perspective, but
// reasonably "snappy" from the user's perspective.
constexpr auto kIdlePresentInterval = milliseconds(250);
- presentation_alarm_.Schedule(
+ presentation_alarm_.ScheduleFromNow(
[this] {
Present();
ResumeRendering();
},
- now_() + kIdlePresentInterval);
+ kIdlePresentInterval);
}
return;
}
@@ -227,11 +229,12 @@ void SDLPlayerBase::ResumeDecoding() {
OnFramesReady(buffer_size);
}
},
- {});
+ Alarm::kImmediately);
}
void SDLPlayerBase::ResumeRendering() {
- render_alarm_.Schedule([this] { RenderAndSchedulePresentation(); }, {});
+ render_alarm_.Schedule([this] { RenderAndSchedulePresentation(); },
+ Alarm::kImmediately);
}
// static
@@ -250,5 +253,5 @@ SDLPlayerBase::PendingFrame::PendingFrame(PendingFrame&&) noexcept = default;
SDLPlayerBase::PendingFrame& SDLPlayerBase::PendingFrame::operator=(
PendingFrame&&) noexcept = default;
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.h
index 34fb278e258..7338edab1b0 100644
--- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.h
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.h
@@ -18,8 +18,8 @@
#include "platform/api/time.h"
#include "platform/base/error.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Common base class that consumes frames from a Receiver, decodes them, and
// plays them out via the appropriate SDL subsystem. Subclasses implement the
@@ -29,7 +29,7 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client {
~SDLPlayerBase() override;
// Returns OK unless a fatal error has occurred.
- const openscreen::Error& error_status() const { return error_status_; }
+ const Error& error_status() const { return error_status_; }
protected:
// Current player state, which is used to determine what to render/present,
@@ -43,7 +43,7 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client {
// A decoded frame and its target presentation time.
struct PresentableFrame {
- openscreen::platform::Clock::time_point presentation_time;
+ Clock::time_point presentation_time;
AVFrameUniquePtr decoded_frame;
PresentableFrame();
@@ -55,9 +55,10 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client {
// |error_callback| is run only if a fatal error occurs, at which point the
// player has halted and set |error_status()|. |media_type| should be "audio"
// or "video" (only used when logging).
- SDLPlayerBase(openscreen::platform::ClockNowFunctionPtr now_function,
- openscreen::platform::TaskRunner* task_runner,
+ SDLPlayerBase(ClockNowFunctionPtr now_function,
+ TaskRunner* task_runner,
Receiver* receiver,
+ const std::string& codec_name,
std::function<void()> error_callback,
const char* media_type);
@@ -68,8 +69,8 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client {
void OnFatalError(std::string message) final;
// Renders the |frame| and returns its [possibly adjusted] presentation time.
- virtual openscreen::ErrorOr<openscreen::platform::Clock::time_point>
- RenderNextFrame(const PresentableFrame& frame) = 0;
+ virtual ErrorOr<Clock::time_point> RenderNextFrame(
+ const PresentableFrame& frame) = 0;
// Called to render when the player has no new content, and returns true if a
// Present() is necessary. |frame| may be null, if it is not available. This
@@ -82,9 +83,25 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client {
virtual void Present() = 0;
private:
+ struct PendingFrame : public PresentableFrame {
+ Clock::time_point start_time;
+
+ PendingFrame();
+ ~PendingFrame();
+ PendingFrame(PendingFrame&& other) noexcept;
+ PendingFrame& operator=(PendingFrame&& other) noexcept;
+ };
+
// Receiver::Consumer implementation.
void OnFramesReady(int next_frame_buffer_size) final;
+ // Determine the presentation time of the frame. Ideally, this will occur
+ // based on the time progression of the media, given by the RTP timestamps.
+ // However, if this falls too far out-of-sync with the system reference clock,
+ // re-synchronize, possibly causing user-visible "jank."
+ Clock::time_point ResyncAndDeterminePresentationTime(
+ const EncodedFrame& frame);
+
// AVCodecDecoder::Client implementation. These are called-back from
// |decoder_| to provide results.
void OnFrameDecoded(FrameId frame_id, const AVFrame& frame) final;
@@ -108,13 +125,13 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client {
// require rendering/presenting a different output.
void ResumeRendering();
- const openscreen::platform::ClockNowFunctionPtr now_;
+ const ClockNowFunctionPtr now_;
Receiver* const receiver_;
std::function<void()> error_callback_; // Run once by OnFatalError().
const char* const media_type_; // For logging only.
// Set to the error code that placed the player in a fatal error state.
- openscreen::Error error_status_;
+ Error error_status_;
// Current player state, which is used to determine what to render/present,
// and how frequently.
@@ -122,14 +139,7 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client {
// Queue of frames currently being decoded and decoded frames awaiting
// rendering.
- struct PendingFrame : public PresentableFrame {
- openscreen::platform::Clock::time_point start_time;
- PendingFrame();
- ~PendingFrame();
- PendingFrame(PendingFrame&& other) noexcept;
- PendingFrame& operator=(PendingFrame&& other) noexcept;
- };
std::map<FrameId, PendingFrame> frames_to_render_;
// Buffer for holding EncodedFrame::data.
@@ -139,7 +149,7 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client {
// whenever the media (RTP) timestamps drift too much away from the rate at
// which the local clock ticks. This is important for A/V synchronization.
RtpTimeTicks last_sync_rtp_timestamp_{};
- openscreen::platform::Clock::time_point last_sync_reference_time_{};
+ Clock::time_point last_sync_reference_time_{};
Decoder decoder_;
@@ -149,13 +159,13 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client {
// A cumulative moving average of recent single-frame processing times
// (consume + decode + render). This is passed to the Cast Receiver so that it
// can determine when to drop late frames.
- openscreen::platform::Clock::duration recent_processing_time_{};
+ Clock::duration recent_processing_time_{};
// Alarms that execute the various stages of the player pipeline at certain
// times.
- openscreen::Alarm decode_alarm_;
- openscreen::Alarm render_alarm_;
- openscreen::Alarm presentation_alarm_;
+ Alarm decode_alarm_;
+ Alarm render_alarm_;
+ Alarm presentation_alarm_;
// Maximum number of frames in the decode/render pipeline. This limit is about
// making sure the player uses resources efficiently: It is better for frames
@@ -164,7 +174,7 @@ class SDLPlayerBase : public Receiver::Consumer, public Decoder::Client {
static constexpr int kMaxFramesInPipeline = 8;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STANDALONE_RECEIVER_SDL_PLAYER_BASE_H_
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.cc
index 250a76c05fb..f0fff9fadf6 100644
--- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.cc
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.cc
@@ -8,15 +8,10 @@
#include "cast/standalone_receiver/avcodec_glue.h"
#include "util/logging.h"
+#include "util/trace_logging.h"
-using openscreen::Error;
-using openscreen::ErrorOr;
-using openscreen::platform::Clock;
-using openscreen::platform::ClockNowFunctionPtr;
-using openscreen::platform::TaskRunner;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
constexpr char kVideoMediaType[] = "video";
@@ -25,11 +20,13 @@ constexpr char kVideoMediaType[] = "video";
SDLVideoPlayer::SDLVideoPlayer(ClockNowFunctionPtr now_function,
TaskRunner* task_runner,
Receiver* receiver,
+ const std::string& codec_name,
SDL_Renderer* renderer,
std::function<void()> error_callback)
: SDLPlayerBase(now_function,
task_runner,
receiver,
+ codec_name,
std::move(error_callback),
kVideoMediaType),
renderer_(renderer) {
@@ -40,6 +37,7 @@ SDLVideoPlayer::~SDLVideoPlayer() = default;
bool SDLVideoPlayer::RenderWhileIdle(
const SDLPlayerBase::PresentableFrame* frame) {
+ TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver);
// Attempt to re-render the same content.
if (state() == kPresented && frame) {
const auto result = RenderNextFrame(*frame);
@@ -69,6 +67,7 @@ bool SDLVideoPlayer::RenderWhileIdle(
ErrorOr<Clock::time_point> SDLVideoPlayer::RenderNextFrame(
const SDLPlayerBase::PresentableFrame& frame) {
+ TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver);
OSP_DCHECK(frame.decoded_frame);
const AVFrame& picture = *frame.decoded_frame;
@@ -163,6 +162,7 @@ ErrorOr<Clock::time_point> SDLVideoPlayer::RenderNextFrame(
}
void SDLVideoPlayer::Present() {
+ TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver);
SDL_RenderPresent(renderer_);
}
@@ -201,5 +201,5 @@ uint32_t SDLVideoPlayer::GetSDLPixelFormat(const AVFrame& picture) {
return SDL_PIXELFORMAT_UNKNOWN;
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.h
index 915881eadc2..24b3496ccc0 100644
--- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.h
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_video_player.h
@@ -5,20 +5,23 @@
#ifndef CAST_STANDALONE_RECEIVER_SDL_VIDEO_PLAYER_H_
#define CAST_STANDALONE_RECEIVER_SDL_VIDEO_PLAYER_H_
+#include <string>
+
#include "cast/standalone_receiver/sdl_player_base.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Consumes frames from a Receiver, decodes them, and renders them to a
// SDL_Renderer.
-class SDLVideoPlayer : public SDLPlayerBase {
+class SDLVideoPlayer final : public SDLPlayerBase {
public:
// |error_callback| is run only if a fatal error occurs, at which point the
// player has halted and set |error_status()|.
- SDLVideoPlayer(openscreen::platform::ClockNowFunctionPtr now_function,
- openscreen::platform::TaskRunner* task_runner,
+ SDLVideoPlayer(ClockNowFunctionPtr now_function,
+ TaskRunner* task_runner,
Receiver* receiver,
+ const std::string& codec_name,
SDL_Renderer* renderer,
std::function<void()> error_callback);
@@ -32,7 +35,7 @@ class SDLVideoPlayer : public SDLPlayerBase {
// Uploads the decoded picture in |frame| to a SDL texture and draws it using
// the SDL |renderer_|.
- openscreen::ErrorOr<openscreen::platform::Clock::time_point> RenderNextFrame(
+ ErrorOr<Clock::time_point> RenderNextFrame(
const SDLPlayerBase::PresentableFrame& frame) final;
// Makes whatever is currently drawn to the SDL |renderer_| be presented
@@ -50,7 +53,7 @@ class SDLVideoPlayer : public SDLPlayerBase {
SDLTextureUniquePtr texture_;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STANDALONE_RECEIVER_SDL_VIDEO_PLAYER_H_
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/streaming_playback_controller.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/streaming_playback_controller.cc
new file mode 100644
index 00000000000..cf0990cd5ae
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/streaming_playback_controller.cc
@@ -0,0 +1,99 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/standalone_receiver/streaming_playback_controller.h"
+
+#if defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS)
+#include "cast/standalone_receiver/sdl_audio_player.h"
+#include "cast/standalone_receiver/sdl_glue.h"
+#include "cast/standalone_receiver/sdl_video_player.h"
+#else
+#include "cast/standalone_receiver/dummy_player.h"
+#endif // defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS)
+
+#include "util/trace_logging.h"
+
+namespace openscreen {
+namespace cast {
+
+#if defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS)
+StreamingPlaybackController::StreamingPlaybackController(
+ TaskRunner* task_runner,
+ StreamingPlaybackController::Client* client)
+ : task_runner_(task_runner),
+ client_(client),
+ sdl_event_loop_(task_runner_, [this] {
+ client_->OnPlaybackError(this,
+ Error{Error::Code::kOperationCancelled,
+ std::string("SDL event loop closed.")});
+ }) {
+ OSP_DCHECK(task_runner_ != nullptr);
+ OSP_DCHECK(client_ != nullptr);
+ constexpr int kDefaultWindowWidth = 1280;
+ constexpr int kDefaultWindowHeight = 720;
+ window_ = MakeUniqueSDLWindow(
+ "Cast Streaming Receiver Demo",
+ SDL_WINDOWPOS_UNDEFINED /* initial X position */,
+ SDL_WINDOWPOS_UNDEFINED /* initial Y position */, kDefaultWindowWidth,
+ kDefaultWindowHeight, SDL_WINDOW_RESIZABLE);
+ OSP_CHECK(window_) << "Failed to create SDL window: " << SDL_GetError();
+ renderer_ = MakeUniqueSDLRenderer(window_.get(), -1, 0);
+ OSP_CHECK(renderer_) << "Failed to create SDL renderer: " << SDL_GetError();
+}
+#else
+StreamingPlaybackController::StreamingPlaybackController(
+ TaskRunner* task_runner,
+ StreamingPlaybackController::Client* client)
+ : task_runner_(task_runner), client_(client) {
+ OSP_DCHECK(task_runner_ != nullptr);
+ OSP_DCHECK(client_ != nullptr);
+}
+#endif // defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS)
+
+// TODO(jophba): add async tracing to streaming implementation for exposing
+// how long the OFFER/ANSWER and receiver startup takes.
+void StreamingPlaybackController::OnNegotiated(
+ const ReceiverSession* session,
+ ReceiverSession::ConfiguredReceivers receivers) {
+ TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver);
+#if defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS)
+ if (receivers.audio) {
+ audio_player_ = std::make_unique<SDLAudioPlayer>(
+ &Clock::now, task_runner_, receivers.audio->receiver,
+ receivers.audio->selected_stream.stream.codec_name, [this] {
+ client_->OnPlaybackError(this, audio_player_->error_status());
+ });
+ }
+ if (receivers.video) {
+ video_player_ = std::make_unique<SDLVideoPlayer>(
+ &Clock::now, task_runner_, receivers.video->receiver,
+ receivers.video->selected_stream.stream.codec_name, renderer_.get(),
+ [this] {
+ client_->OnPlaybackError(this, video_player_->error_status());
+ });
+ }
+#else
+ if (receivers.audio) {
+ audio_player_ = std::make_unique<DummyPlayer>(receivers.audio->receiver);
+ }
+
+ if (receivers.video) {
+ video_player_ = std::make_unique<DummyPlayer>(receivers.video->receiver);
+ }
+#endif // defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS)
+}
+
+void StreamingPlaybackController::OnConfiguredReceiversDestroyed(
+ const ReceiverSession* session) {
+ audio_player_.reset();
+ video_player_.reset();
+}
+
+void StreamingPlaybackController::OnError(const ReceiverSession* session,
+ Error error) {
+ client_->OnPlaybackError(this, error);
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/streaming_playback_controller.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/streaming_playback_controller.h
new file mode 100644
index 00000000000..72f26a78c70
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/streaming_playback_controller.h
@@ -0,0 +1,68 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_STANDALONE_RECEIVER_STREAMING_PLAYBACK_CONTROLLER_H_
+#define CAST_STANDALONE_RECEIVER_STREAMING_PLAYBACK_CONTROLLER_H_
+
+#include <memory>
+
+#include "cast/streaming/receiver_session.h"
+#include "platform/impl/task_runner.h"
+
+#if defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS)
+#include "cast/standalone_receiver/sdl_audio_player.h"
+#include "cast/standalone_receiver/sdl_glue.h"
+#include "cast/standalone_receiver/sdl_video_player.h"
+#else
+#include "cast/standalone_receiver/dummy_player.h"
+#endif // defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS)
+
+namespace openscreen {
+namespace cast {
+
+class StreamingPlaybackController final : public ReceiverSession::Client {
+ public:
+ class Client {
+ public:
+ virtual void OnPlaybackError(StreamingPlaybackController* controller,
+ Error error) = 0;
+ };
+
+ StreamingPlaybackController(TaskRunner* task_runner,
+ StreamingPlaybackController::Client* client);
+
+ // ReceiverSession::Client overrides.
+ void OnNegotiated(const ReceiverSession* session,
+ ReceiverSession::ConfiguredReceivers receivers) override;
+
+ void OnConfiguredReceiversDestroyed(const ReceiverSession* session) override;
+
+ void OnError(const ReceiverSession* session, Error error) override;
+
+ private:
+ TaskRunner* const task_runner_;
+ StreamingPlaybackController::Client* client_;
+
+#if defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS)
+ // NOTE: member ordering is important, since the sub systems must be
+ // first-constructed, last-destroyed. Make sure any new SDL related
+ // members are added below the sub systems.
+ const ScopedSDLSubSystem<SDL_INIT_AUDIO> sdl_audio_sub_system_;
+ const ScopedSDLSubSystem<SDL_INIT_VIDEO> sdl_video_sub_system_;
+ const SDLEventLoopProcessor sdl_event_loop_;
+
+ SDLWindowUniquePtr window_;
+ SDLRendererUniquePtr renderer_;
+ std::unique_ptr<SDLAudioPlayer> audio_player_;
+ std::unique_ptr<SDLVideoPlayer> video_player_;
+#else
+ std::unique_ptr<DummyPlayer> audio_player_;
+ std::unique_ptr<DummyPlayer> video_player_;
+#endif // defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS)
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_STANDALONE_RECEIVER_STREAMING_PLAYBACK_CONTROLLER_H_
diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/BUILD.gn b/chromium/third_party/openscreen/src/cast/standalone_sender/BUILD.gn
new file mode 100644
index 00000000000..85e0f1b047b
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_sender/BUILD.gn
@@ -0,0 +1,44 @@
+# Copyright 2020 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+import("//build/config/external_libraries.gni")
+import("//build_overrides/build.gni")
+
+# Define the executable target only when the build is configured to use the
+# standalone platform implementation; since this is itself a standalone
+# application.
+if (!build_with_chromium) {
+ executable("cast_sender") {
+ sources = [
+ "main.cc",
+ ]
+ deps = [
+ "../../platform",
+ "../../util",
+ "../streaming:sender",
+ ]
+
+ defines = []
+ include_dirs = []
+ lib_dirs = []
+ libs = []
+ if (have_ffmpeg && have_libopus && have_libvpx) {
+ defines += [ "CAST_STANDALONE_SENDER_HAVE_EXTERNAL_LIBS" ]
+ sources += [
+ "ffmpeg_glue.cc",
+ "ffmpeg_glue.h",
+ "simulated_capturer.cc",
+ "simulated_capturer.h",
+ "streaming_opus_encoder.cc",
+ "streaming_opus_encoder.h",
+ "streaming_vp8_encoder.cc",
+ "streaming_vp8_encoder.h",
+ ]
+ include_dirs +=
+ ffmpeg_include_dirs + libopus_include_dirs + libvpx_include_dirs
+ lib_dirs += ffmpeg_lib_dirs + libopus_lib_dirs + libvpx_lib_dirs
+ libs += ffmpeg_libs + libopus_libs + libvpx_libs
+ }
+ }
+}
diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/DEPS b/chromium/third_party/openscreen/src/cast/standalone_sender/DEPS
new file mode 100644
index 00000000000..3074fec20e0
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_sender/DEPS
@@ -0,0 +1,8 @@
+# Copyright 2020 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+include_rules = [
+ '+cast',
+ '+platform/impl',
+]
diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/ffmpeg_glue.cc b/chromium/third_party/openscreen/src/cast/standalone_sender/ffmpeg_glue.cc
new file mode 100644
index 00000000000..c271f80b188
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_sender/ffmpeg_glue.cc
@@ -0,0 +1,32 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/standalone_sender/ffmpeg_glue.h"
+
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+namespace internal {
+
+AVFormatContext* CreateAVFormatContextForFile(const char* path) {
+ AVFormatContext* format_context = nullptr;
+ int result = avformat_open_input(&format_context, path, nullptr, nullptr);
+ if (result < 0) {
+ OSP_LOG_ERROR << "Cannot open " << path << ": " << av_err2str(result);
+ return nullptr;
+ }
+ result = avformat_find_stream_info(format_context, nullptr);
+ if (result < 0) {
+ avformat_close_input(&format_context);
+ OSP_LOG_ERROR << "Cannot find stream info in " << path << ": "
+ << av_err2str(result);
+ return nullptr;
+ }
+ return format_context;
+}
+
+} // namespace internal
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/ffmpeg_glue.h b/chromium/third_party/openscreen/src/cast/standalone_sender/ffmpeg_glue.h
new file mode 100644
index 00000000000..28084c0d7d8
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_sender/ffmpeg_glue.h
@@ -0,0 +1,71 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_STANDALONE_SENDER_FFMPEG_GLUE_H_
+#define CAST_STANDALONE_SENDER_FFMPEG_GLUE_H_
+
+extern "C" {
+#include <libavcodec/avcodec.h>
+#include <libavformat/avformat.h>
+#include <libavutil/channel_layout.h>
+#include <libavutil/common.h>
+#include <libavutil/imgutils.h>
+#include <libavutil/mathematics.h>
+#include <libavutil/pixfmt.h>
+#include <libavutil/samplefmt.h>
+#include <libswresample/swresample.h>
+}
+
+#include <memory>
+#include <utility>
+
+namespace openscreen {
+namespace cast {
+
+namespace internal {
+
+// Convenience allocator for a new AVFormatContext, given a file |path|. Returns
+// nullptr on error. Note: MakeUniqueAVFormatContext() is the public API.
+AVFormatContext* CreateAVFormatContextForFile(const char* path);
+
+} // namespace internal
+
+// Macro that, for an AVFoo, generates code for:
+//
+// using FooUniquePtr = std::unique_ptr<Foo, FooFreer>;
+// FooUniquePtr MakeUniqueFoo(...args...);
+#define DEFINE_AV_UNIQUE_PTR(name, create_func, free_func) \
+ namespace internal { \
+ struct name##Freer { \
+ void operator()(name* obj) const { \
+ if (obj) { \
+ free_func(&obj); \
+ } \
+ } \
+ }; \
+ } \
+ \
+ using name##UniquePtr = std::unique_ptr<name, internal::name##Freer>; \
+ \
+ template <typename... Args> \
+ name##UniquePtr MakeUnique##name(Args&&... args) { \
+ return name##UniquePtr(create_func(std::forward<Args>(args)...)); \
+ }
+
+DEFINE_AV_UNIQUE_PTR(AVFormatContext,
+ ::openscreen::cast::internal::CreateAVFormatContextForFile,
+ avformat_close_input);
+DEFINE_AV_UNIQUE_PTR(AVCodecContext,
+ avcodec_alloc_context3,
+ avcodec_free_context);
+DEFINE_AV_UNIQUE_PTR(AVPacket, av_packet_alloc, av_packet_free);
+DEFINE_AV_UNIQUE_PTR(AVFrame, av_frame_alloc, av_frame_free);
+DEFINE_AV_UNIQUE_PTR(SwrContext, swr_alloc, swr_free);
+
+#undef DEFINE_AV_UNIQUE_PTR
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_STANDALONE_SENDER_FFMPEG_GLUE_H_
diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/main.cc b/chromium/third_party/openscreen/src/cast/standalone_sender/main.cc
new file mode 100644
index 00000000000..5d6c1a2c007
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_sender/main.cc
@@ -0,0 +1,480 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <getopt.h>
+
+#include <chrono> // NOLINT
+#include <cinttypes>
+#include <csignal>
+#include <cstdio>
+#include <cstring>
+#include <sstream>
+
+#include "cast/streaming/constants.h"
+#include "cast/streaming/environment.h"
+#include "cast/streaming/sender.h"
+#include "cast/streaming/sender_packet_router.h"
+#include "cast/streaming/session_config.h"
+#include "cast/streaming/ssrc.h"
+#include "platform/api/time.h"
+#include "platform/base/error.h"
+#include "platform/base/ip_address.h"
+#include "platform/impl/logging.h"
+#include "platform/impl/platform_client_posix.h"
+#include "platform/impl/task_runner.h"
+#include "platform/impl/text_trace_logging_platform.h"
+#include "util/alarm.h"
+
+#if defined(CAST_STANDALONE_SENDER_HAVE_EXTERNAL_LIBS)
+#include "cast/standalone_sender/simulated_capturer.h"
+#include "cast/standalone_sender/streaming_opus_encoder.h"
+#include "cast/standalone_sender/streaming_vp8_encoder.h"
+#endif
+
+namespace openscreen {
+namespace cast {
+namespace {
+
+using std::chrono::duration_cast;
+using std::chrono::milliseconds;
+using std::chrono::seconds;
+
+////////////////////////////////////////////////////////////////////////////////
+// Sender Configuration
+//
+// The values defined here are constants that correspond to the standalone Cast
+// Receiver app. In a production environment, these should ABSOLUTELY NOT be
+// fixed! Instead a sender↔receiver OFFER/ANSWER exchange should establish them.
+
+// In a production environment, this would start-out at some initial value
+// appropriate to the networking environment, and then be adjusted by the
+// application as: 1) the TYPE of the content changes (interactive, low-latency
+// versus smooth, higher-latency buffered video watching); and 2) the networking
+// environment reliability changes.
+constexpr milliseconds kTargetPlayoutDelay = kDefaultTargetPlayoutDelay;
+
+const SessionConfig kSampleAudioAnswerConfig{
+ /* .sender_ssrc = */ 1,
+ /* .receiver_ssrc = */ 2,
+ /* .rtp_timebase = */ 48000,
+ /* .channels = */ 2,
+ /* .target_playout_delay */ kTargetPlayoutDelay,
+ /* .aes_secret_key = */
+ {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
+ 0x0c, 0x0d, 0x0e, 0x0f},
+ /* .aes_iv_mask = */
+ {0xf0, 0xe0, 0xd0, 0xc0, 0xb0, 0xa0, 0x90, 0x80, 0x70, 0x60, 0x50, 0x40,
+ 0x30, 0x20, 0x10, 0x00},
+};
+
+const SessionConfig kSampleVideoAnswerConfig{
+ /* .sender_ssrc = */ 50001,
+ /* .receiver_ssrc = */ 50002,
+ /* .rtp_timebase = */ static_cast<int>(kVideoTimebase::den),
+ /* .channels = */ 1,
+ /* .target_playout_delay */ kTargetPlayoutDelay,
+ /* .aes_secret_key = */
+ {0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b,
+ 0x1c, 0x1d, 0x1e, 0x1f},
+ /* .aes_iv_mask = */
+ {0xf1, 0xe1, 0xd1, 0xc1, 0xb1, 0xa1, 0x91, 0x81, 0x71, 0x61, 0x51, 0x41,
+ 0x31, 0x21, 0x11, 0x01},
+};
+
+// End of Sender Configuration.
+////////////////////////////////////////////////////////////////////////////////
+
+// What is the minimum amount of bandwidth required?
+constexpr int kMinRequiredBitrate = 384 << 10; // 384 kbps.
+
+// What is the default maximum bitrate setting?
+constexpr int kDefaultMaxBitrate = 5 << 20; // 5 Mbps.
+
+#if defined(CAST_STANDALONE_SENDER_HAVE_EXTERNAL_LIBS)
+
+// Above what available bandwidth should the high-quality audio bitrate be used?
+constexpr int kHighBandwidthThreshold = 5 << 20; // 5 Mbps.
+
+// How often should the file position (media timestamp) be updated on the
+// console?
+constexpr milliseconds kConsoleUpdateInterval{100};
+
+// How often should the congestion control logic re-evaluate the target encode
+// bitrates?
+constexpr milliseconds kCongestionCheckInterval{500};
+
+// Plays the media file at a given path over and over again, transcoding and
+// streaming its audio/video.
+class LoopingFileSender final : public SimulatedAudioCapturer::Client,
+ public SimulatedVideoCapturer::Client {
+ public:
+ LoopingFileSender(TaskRunner* task_runner,
+ const char* path,
+ const IPEndpoint& remote_endpoint,
+ int max_bitrate,
+ bool use_android_rtp_hack)
+ : env_(&Clock::now, task_runner, IPEndpoint{IPAddress(), 0}),
+ path_(path),
+ packet_router_(&env_),
+ max_bitrate_(max_bitrate),
+ audio_sender_(&env_,
+ &packet_router_,
+ kSampleAudioAnswerConfig,
+ use_android_rtp_hack
+ ? RtpPayloadType::kAudioHackForAndroidTV
+ : RtpPayloadType::kAudioOpus),
+ video_sender_(&env_,
+ &packet_router_,
+ kSampleVideoAnswerConfig,
+ use_android_rtp_hack
+ ? RtpPayloadType::kVideoHackForAndroidTV
+ : RtpPayloadType::kVideoVp8),
+ audio_encoder_(kSampleAudioAnswerConfig.channels,
+ StreamingOpusEncoder::kDefaultCastAudioFramesPerSecond,
+ &audio_sender_),
+ video_encoder_(StreamingVp8Encoder::Parameters{},
+ env_.task_runner(),
+ &video_sender_),
+ next_task_(env_.now_function(), env_.task_runner()),
+ console_update_task_(env_.now_function(), env_.task_runner()) {
+ env_.set_remote_endpoint(remote_endpoint);
+ OSP_LOG_INFO << "Streaming to " << remote_endpoint << "...";
+
+ if (use_android_rtp_hack) {
+ OSP_LOG_INFO << "Using RTP payload types for older Android TV receivers.";
+ }
+
+ OSP_LOG_INFO << "Max allowed media bitrate (audio + video) will be "
+ << max_bitrate_;
+ bandwidth_being_utilized_ = max_bitrate_ / 2;
+ UpdateEncoderBitrates();
+
+ next_task_.Schedule([this] { SendFileAgain(); }, Alarm::kImmediately);
+ }
+
+ ~LoopingFileSender() final = default;
+
+ private:
+ void UpdateEncoderBitrates() {
+ if (bandwidth_being_utilized_ >= kHighBandwidthThreshold) {
+ audio_encoder_.UseHighQuality();
+ } else {
+ audio_encoder_.UseStandardQuality();
+ }
+ video_encoder_.SetTargetBitrate(bandwidth_being_utilized_ -
+ audio_encoder_.GetBitrate());
+ }
+
+ void ControlForNetworkCongestion() {
+ bandwidth_estimate_ = packet_router_.ComputeNetworkBandwidth();
+ if (bandwidth_estimate_ > 0) {
+ // Don't ever try to use *all* of the network bandwidth! However, don't go
+ // below the absolute minimum requirement either.
+ constexpr double kGoodNetworkCitizenFactor = 0.8;
+ const int usable_bandwidth = std::max<int>(
+ kGoodNetworkCitizenFactor * bandwidth_estimate_, kMinRequiredBitrate);
+
+ // See "congestion control" discussion in the class header comments for
+ // BandwidthEstimator.
+ if (usable_bandwidth > bandwidth_being_utilized_) {
+ constexpr double kConservativeIncrease = 1.1;
+ bandwidth_being_utilized_ =
+ std::min<int>(bandwidth_being_utilized_ * kConservativeIncrease,
+ usable_bandwidth);
+ } else {
+ bandwidth_being_utilized_ = usable_bandwidth;
+ }
+
+ // Repsect the user's maximum bitrate setting.
+ bandwidth_being_utilized_ =
+ std::min(bandwidth_being_utilized_, max_bitrate_);
+
+ UpdateEncoderBitrates();
+ } else {
+ // There is no current bandwidth estimate. So, nothing should be adjusted.
+ }
+
+ next_task_.ScheduleFromNow([this] { ControlForNetworkCongestion(); },
+ kCongestionCheckInterval);
+ }
+
+ void SendFileAgain() {
+ OSP_LOG_INFO << "Sending " << path_ << " (starts in one second)...";
+
+ OSP_DCHECK_EQ(num_capturers_running_, 0);
+ num_capturers_running_ = 2;
+ capture_start_time_ = latest_frame_time_ = env_.now() + seconds(1);
+ audio_capturer_.emplace(&env_, path_, audio_encoder_.num_channels(),
+ audio_encoder_.sample_rate(), capture_start_time_,
+ this);
+ video_capturer_.emplace(&env_, path_, capture_start_time_, this);
+
+ next_task_.ScheduleFromNow([this] { ControlForNetworkCongestion(); },
+ kCongestionCheckInterval);
+ console_update_task_.Schedule([this] { UpdateStatusOnConsole(); },
+ capture_start_time_);
+ }
+
+ void OnAudioData(const float* interleaved_samples,
+ int num_samples,
+ Clock::time_point capture_time) final {
+ latest_frame_time_ = std::max(capture_time, latest_frame_time_);
+ audio_encoder_.EncodeAndSend(interleaved_samples, num_samples,
+ capture_time);
+ }
+
+ void OnVideoFrame(const AVFrame& av_frame,
+ Clock::time_point capture_time) final {
+ latest_frame_time_ = std::max(capture_time, latest_frame_time_);
+ StreamingVp8Encoder::VideoFrame frame{};
+ frame.width = av_frame.width - av_frame.crop_left - av_frame.crop_right;
+ frame.height = av_frame.height - av_frame.crop_top - av_frame.crop_bottom;
+ frame.yuv_planes[0] = av_frame.data[0] + av_frame.crop_left +
+ av_frame.linesize[0] * av_frame.crop_top;
+ frame.yuv_planes[1] = av_frame.data[1] + av_frame.crop_left / 2 +
+ av_frame.linesize[1] * av_frame.crop_top / 2;
+ frame.yuv_planes[2] = av_frame.data[2] + av_frame.crop_left / 2 +
+ av_frame.linesize[2] * av_frame.crop_top / 2;
+ for (int i = 0; i < 3; ++i) {
+ frame.yuv_strides[i] = av_frame.linesize[i];
+ }
+ // TODO(miu): Add performance metrics visual overlay (based on Stats
+ // callback).
+ video_encoder_.EncodeAndSend(frame, capture_time, {});
+ }
+
+ void UpdateStatusOnConsole() {
+ const Clock::duration elapsed = latest_frame_time_ - capture_start_time_;
+ const auto seconds_part = duration_cast<seconds>(elapsed);
+ const auto millis_part =
+ duration_cast<milliseconds>(elapsed - seconds_part);
+ // The control codes here attempt to erase the current line the cursor is
+ // on, and then print out the updated status text. If the terminal does not
+ // support simple ANSI escape codes, the following will still work, but
+ // there might sometimes be old status lines not getting erased (i.e., just
+ // partially overwritten).
+ fprintf(stdout,
+ "\r\x1b[2K\rAt %01" PRId64
+ ".%03ds in file (est. network bandwidth: %d kbps). ",
+ static_cast<int64_t>(seconds_part.count()),
+ static_cast<int>(millis_part.count()), bandwidth_estimate_ / 1024);
+ fflush(stdout);
+
+ console_update_task_.ScheduleFromNow([this] { UpdateStatusOnConsole(); },
+ kConsoleUpdateInterval);
+ }
+
+ void OnEndOfFile(SimulatedCapturer* capturer) final {
+ OSP_LOG_INFO << "The " << ToTrackName(capturer)
+ << " capturer has reached the end of the media stream.";
+ --num_capturers_running_;
+ if (num_capturers_running_ == 0) {
+ console_update_task_.Cancel();
+ next_task_.Schedule([this] { SendFileAgain(); }, Alarm::kImmediately);
+ }
+ }
+
+ void OnError(SimulatedCapturer* capturer, std::string message) final {
+ OSP_LOG_ERROR << "The " << ToTrackName(capturer)
+ << " has failed: " << message;
+ --num_capturers_running_;
+ // If both fail, the application just pauses. This accounts for things like
+ // "file not found" errors. However, if only one track fails, then keep
+ // going.
+ }
+
+ const char* ToTrackName(SimulatedCapturer* capturer) const {
+ const char* which;
+ if (capturer == &*audio_capturer_) {
+ which = "audio";
+ } else if (capturer == &*video_capturer_) {
+ which = "video";
+ } else {
+ OSP_NOTREACHED();
+ which = "";
+ }
+ return which;
+ }
+
+ // Holds the required injected dependencies (clock, task runner) used for Cast
+ // Streaming, and owns the UDP socket over which all communications occur with
+ // the remote's Receivers.
+ Environment env_;
+
+ // The path to the media file to stream over and over.
+ const char* const path_;
+
+ // The packet router allows both the Audio Sender and the Video Sender to
+ // share the same UDP socket.
+ SenderPacketRouter packet_router_;
+
+ const int max_bitrate_; // Passed by the user on the command line.
+ int bandwidth_estimate_ = 0;
+ int bandwidth_being_utilized_;
+
+ Sender audio_sender_;
+ Sender video_sender_;
+
+ StreamingOpusEncoder audio_encoder_;
+ StreamingVp8Encoder video_encoder_;
+
+ int num_capturers_running_ = 0;
+ Clock::time_point capture_start_time_{};
+ Clock::time_point latest_frame_time_{};
+ absl::optional<SimulatedAudioCapturer> audio_capturer_;
+ absl::optional<SimulatedVideoCapturer> video_capturer_;
+
+ Alarm next_task_;
+ Alarm console_update_task_;
+};
+
+#endif // defined(CAST_STANDALONE_SENDER_HAVE_EXTERNAL_LIBS)
+
+IPEndpoint GetDefaultEndpoint() {
+ return IPEndpoint{IPAddress::kV4LoopbackAddress, kDefaultCastStreamingPort};
+}
+
+void LogUsage(const char* argv0) {
+ const char kUsageMessageFormat[] = R"(
+ usage: %s <options> <media_file>
+
+ --remote=addr[:port]
+ Specify the destination (e.g., 192.168.1.22:9999 or [::1]:12345).
+
+ Default if not set: %s
+
+ --max-bitrate=N
+ Specifies the maximum bits per second for the media streams.
+
+ Default if not set: %d.
+
+ --android-hack:
+ Use the wrong RTP payload types, for compatibility with older Android
+ TV receivers.
+
+ --tracing: Enable performance tracing logging.
+ )";
+ // TODO(https://crbug.com/openscreen/122): absl::StreamFormat() would be much
+ // cleaner here. For example, all the code here could be replaced with:
+ //
+ // OSP_LOG_ERROR << absl::StreamFormat(kUsageMessageFormat, argv0,
+ // absl::FromatStreamed(endpoint),
+ // kDefaultMaxBitrate);
+ std::string endpoint;
+ {
+ std::ostringstream oss;
+ oss << GetDefaultEndpoint();
+ endpoint = oss.str();
+ }
+ const int formatted_length_with_nul =
+ snprintf(nullptr, 0, kUsageMessageFormat, argv0, endpoint.c_str(),
+ kDefaultMaxBitrate) +
+ 1;
+ const std::unique_ptr<char[]> usage_cstr(new char[formatted_length_with_nul]);
+ snprintf(usage_cstr.get(), formatted_length_with_nul, kUsageMessageFormat,
+ argv0, endpoint.c_str(), kDefaultMaxBitrate);
+ OSP_LOG_ERROR << usage_cstr.get();
+}
+
+int StandaloneSenderMain(int argc, char* argv[]) {
+ SetLogLevel(LogLevel::kInfo);
+
+ const struct option argument_options[] = {
+ {"remote", required_argument, nullptr, 'r'},
+ {"max-bitrate", required_argument, nullptr, 'm'},
+ {"android-hack", no_argument, nullptr, 'a'},
+ {"tracing", no_argument, nullptr, 't'},
+ {"help", no_argument, nullptr, 'h'},
+ {nullptr, 0, nullptr, 0}};
+
+ IPEndpoint remote_endpoint = GetDefaultEndpoint();
+ [[maybe_unused]] bool use_android_rtp_hack = false;
+ [[maybe_unused]] int max_bitrate = kDefaultMaxBitrate;
+ std::unique_ptr<TextTraceLoggingPlatform> trace_logger;
+ int ch = -1;
+ while ((ch = getopt_long(argc, argv, "r:ath", argument_options, nullptr)) !=
+ -1) {
+ switch (ch) {
+ case 'r': {
+ const ErrorOr<IPEndpoint> parsed_endpoint = IPEndpoint::Parse(optarg);
+ if (parsed_endpoint.is_value()) {
+ remote_endpoint = parsed_endpoint.value();
+ } else {
+ const ErrorOr<IPAddress> parsed_address = IPAddress::Parse(optarg);
+ if (parsed_address.is_value()) {
+ remote_endpoint.address = parsed_address.value();
+ } else {
+ OSP_LOG_ERROR << "Invalid --remote specified: " << optarg;
+ LogUsage(argv[0]);
+ return 1;
+ }
+ }
+ break;
+ }
+ case 'm':
+ max_bitrate = atoi(optarg);
+ if (max_bitrate < kMinRequiredBitrate) {
+ OSP_LOG_ERROR << "Invalid --max-bitrate specified: " << optarg
+ << " is less than " << kMinRequiredBitrate;
+ LogUsage(argv[0]);
+ return 1;
+ }
+ break;
+ case 'a':
+ use_android_rtp_hack = true;
+ break;
+ case 't':
+ trace_logger = std::make_unique<TextTraceLoggingPlatform>();
+ break;
+ case 'h':
+ LogUsage(argv[0]);
+ return 1;
+ }
+ }
+
+ // The last command line argument must be the path to the file.
+ const char* path = nullptr;
+ if (optind == (argc - 1)) {
+ path = argv[optind];
+ }
+
+ if (!path || !remote_endpoint.port) {
+ LogUsage(argv[0]);
+ return 1;
+ }
+
+#if defined(CAST_STANDALONE_SENDER_HAVE_EXTERNAL_LIBS)
+
+ auto* const task_runner = new TaskRunnerImpl(&Clock::now);
+ PlatformClientPosix::Create(Clock::duration{50}, Clock::duration{50},
+ std::unique_ptr<TaskRunnerImpl>(task_runner));
+
+ {
+ LoopingFileSender file_sender(task_runner, path, remote_endpoint,
+ max_bitrate, use_android_rtp_hack);
+ // Run the event loop until SIGINT (e.g., CTRL-C at the console) or SIGTERM
+ // are signaled.
+ task_runner->RunUntilSignaled();
+ }
+
+ PlatformClientPosix::ShutDown();
+
+#else
+
+ OSP_LOG_INFO
+ << "It compiled! However, you need to configure the build to point to "
+ "external libraries in order to build a useful app.";
+
+#endif // defined(CAST_STANDALONE_SENDER_HAVE_EXTERNAL_LIBS)
+
+ return 0;
+}
+
+} // namespace
+} // namespace cast
+} // namespace openscreen
+
+int main(int argc, char* argv[]) {
+ return openscreen::cast::StandaloneSenderMain(argc, argv);
+}
diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.cc b/chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.cc
new file mode 100644
index 00000000000..327c76f8376
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.cc
@@ -0,0 +1,378 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/standalone_sender/simulated_capturer.h"
+
+#include <algorithm>
+#include <chrono> // NOLINT
+#include <ratio> // NOLINT
+#include <sstream>
+#include <thread> // NOLINT
+
+#include "cast/streaming/environment.h"
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+
+using openscreen::operator<<; // To pretty-print chrono values.
+
+namespace {
+// Threshold at which a warning about media pausing should be logged.
+constexpr std::chrono::seconds kPauseWarningThreshold{3};
+} // namespace
+
+SimulatedCapturer::Observer::~Observer() = default;
+
+SimulatedCapturer::SimulatedCapturer(Environment* environment,
+ const char* path,
+ AVMediaType media_type,
+ Clock::time_point start_time,
+ Observer* observer)
+ : format_context_(MakeUniqueAVFormatContext(path)),
+ media_type_(media_type),
+ start_time_(start_time),
+ observer_(observer),
+ packet_(MakeUniqueAVPacket()),
+ decoded_frame_(MakeUniqueAVFrame()),
+ next_task_(environment->now_function(), environment->task_runner()) {
+ OSP_DCHECK(observer_);
+
+ if (!format_context_) {
+ OnError("MakeUniqueAVFormatContext", AVERROR_UNKNOWN);
+ return; // Capturer is halted (unable to start).
+ }
+
+ AVCodec* codec;
+ const int stream_result = av_find_best_stream(format_context_.get(),
+ media_type_, -1, -1, &codec, 0);
+ if (stream_result < 0) {
+ OnError("av_find_best_stream", stream_result);
+ return; // Capturer is halted (unable to start).
+ }
+
+ stream_index_ = stream_result;
+ decoder_context_ = MakeUniqueAVCodecContext(codec);
+ if (!decoder_context_) {
+ OnError("MakeUniqueAVCodecContext", AVERROR_BUG);
+ return; // Capturer is halted (unable to start).
+ }
+ // This should also be 16 or less, since the encoder implementations emit
+ // warnings about too many encode threads. FFMPEG's VP8 implementation
+ // actually silently freezes if this is 10 or more. Thus, 8 is used for the
+ // max here, just to be safe.
+ decoder_context_->thread_count =
+ std::min(std::max<int>(std::thread::hardware_concurrency(), 1), 8);
+ const int params_result = avcodec_parameters_to_context(
+ decoder_context_.get(),
+ format_context_->streams[stream_index_]->codecpar);
+ if (params_result < 0) {
+ OnError("avcodec_parameters_to_context", params_result);
+ return; // Capturer is halted (unable to start).
+ }
+ SetAdditionalDecoderParameters(decoder_context_.get());
+
+ const int open_result = avcodec_open2(decoder_context_.get(), codec, nullptr);
+ if (open_result < 0) {
+ OnError("avcodec_open2", open_result);
+ return; // Capturer is halted (unable to start).
+ }
+
+ next_task_.Schedule([this] { StartDecodingNextFrame(); },
+ Alarm::kImmediately);
+}
+
+SimulatedCapturer::~SimulatedCapturer() = default;
+
+void SimulatedCapturer::SetAdditionalDecoderParameters(
+ AVCodecContext* decoder_context) {}
+
+absl::optional<Clock::duration> SimulatedCapturer::ProcessDecodedFrame(
+ const AVFrame& frame) {
+ return Clock::duration::zero();
+}
+
+void SimulatedCapturer::OnError(const char* function_name, int av_errnum) {
+ // Make a human-readable string from the libavcodec error.
+ std::ostringstream error;
+ error << "For " << av_get_media_type_string(media_type_) << ", "
+ << function_name << " returned error: " << av_err2str(av_errnum);
+
+ // Deliver the error notification in a separate task since this method might
+ // have been called from the constructor.
+ next_task_.Schedule(
+ [this, error_string = error.str()] {
+ observer_->OnError(this, error_string);
+ // Capturer is now halted.
+ },
+ Alarm::kImmediately);
+}
+
+// static
+Clock::duration SimulatedCapturer::ToApproximateClockDuration(
+ int64_t ticks,
+ const AVRational& time_base) {
+ return Clock::duration(av_rescale_q(
+ ticks, time_base,
+ AVRational{Clock::duration::period::num, Clock::duration::period::den}));
+}
+
+void SimulatedCapturer::StartDecodingNextFrame() {
+ const int read_frame_result =
+ av_read_frame(format_context_.get(), packet_.get());
+ if (read_frame_result < 0) {
+ if (read_frame_result == AVERROR_EOF) {
+ // Insert a "flush request" into the decoder's pipeline, which will
+ // signal an EOF in ConsumeNextDecodedFrame() later.
+ avcodec_send_packet(decoder_context_.get(), nullptr);
+ next_task_.Schedule([this] { ConsumeNextDecodedFrame(); },
+ Alarm::kImmediately);
+ } else {
+ // All other error codes are fatal.
+ OnError("av_read_frame", read_frame_result);
+ // Capturer is now halted.
+ }
+ return;
+ }
+
+ if (packet_->stream_index != stream_index_) {
+ av_packet_unref(packet_.get());
+ next_task_.Schedule([this] { StartDecodingNextFrame(); },
+ Alarm::kImmediately);
+ return;
+ }
+
+ const int send_packet_result =
+ avcodec_send_packet(decoder_context_.get(), packet_.get());
+ av_packet_unref(packet_.get());
+ if (send_packet_result < 0) {
+ // Note: AVERROR(EAGAIN) is also treated as fatal here because
+ // avcodec_receive_frame() will be called repeatedly until its result code
+ // indicates avcodec_send_packet() must be called again.
+ OnError("avcodec_send_packet", send_packet_result);
+ return; // Capturer is now halted.
+ }
+
+ next_task_.Schedule([this] { ConsumeNextDecodedFrame(); },
+ Alarm::kImmediately);
+}
+
+void SimulatedCapturer::ConsumeNextDecodedFrame() {
+ const int receive_frame_result =
+ avcodec_receive_frame(decoder_context_.get(), decoded_frame_.get());
+ if (receive_frame_result < 0) {
+ switch (receive_frame_result) {
+ case AVERROR(EAGAIN):
+ // This result code, according to libavcodec documentation, means more
+ // data should be fed into the decoder (e.g., interframe dependencies).
+ next_task_.Schedule([this] { StartDecodingNextFrame(); },
+ Alarm::kImmediately);
+ return;
+ case AVERROR_EOF:
+ observer_->OnEndOfFile(this);
+ return; // Capturer is now halted.
+ default:
+ OnError("avcodec_receive_frame", receive_frame_result);
+ return; // Capturer is now halted.
+ }
+ }
+
+ const Clock::duration frame_timestamp = ToApproximateClockDuration(
+ decoded_frame_->best_effort_timestamp,
+ format_context_->streams[stream_index_]->time_base);
+ if (last_frame_timestamp_) {
+ const Clock::duration delta = frame_timestamp - *last_frame_timestamp_;
+ if (delta <= Clock::duration::zero()) {
+ OSP_LOG_WARN << "Dropping " << av_get_media_type_string(media_type_)
+ << " frame with illegal timestamp (delta from last frame: "
+ << delta << "). Bad media file!";
+ av_frame_unref(decoded_frame_.get());
+ next_task_.Schedule([this] { ConsumeNextDecodedFrame(); },
+ Alarm::kImmediately);
+ return;
+ } else if (delta >= kPauseWarningThreshold) {
+ OSP_LOG_INFO << "For " << av_get_media_type_string(media_type_)
+ << ", encountered a media pause (" << delta
+ << ") in the file.";
+ }
+ }
+ last_frame_timestamp_ = frame_timestamp;
+
+ Clock::time_point capture_time = start_time_ + frame_timestamp;
+ const auto delay_adjustment_or_null = ProcessDecodedFrame(*decoded_frame_);
+ if (!delay_adjustment_or_null) {
+ av_frame_unref(decoded_frame_.get());
+ return; // Stop. Fatal error occurred.
+ }
+ capture_time += *delay_adjustment_or_null;
+
+ next_task_.Schedule(
+ [this, capture_time] {
+ DeliverDataToClient(*decoded_frame_, capture_time);
+ av_frame_unref(decoded_frame_.get());
+ ConsumeNextDecodedFrame();
+ },
+ capture_time);
+}
+
+SimulatedAudioCapturer::Client::~Client() = default;
+
+SimulatedAudioCapturer::SimulatedAudioCapturer(Environment* environment,
+ const char* path,
+ int num_channels,
+ int sample_rate,
+ Clock::time_point start_time,
+ Client* client)
+ : SimulatedCapturer(environment,
+ path,
+ AVMEDIA_TYPE_AUDIO,
+ start_time,
+ client),
+ num_channels_(num_channels),
+ sample_rate_(sample_rate),
+ client_(client),
+ resampler_(MakeUniqueSwrContext()) {
+ OSP_DCHECK_GT(num_channels_, 0);
+ OSP_DCHECK_GT(sample_rate_, 0);
+}
+
+SimulatedAudioCapturer::~SimulatedAudioCapturer() {
+ if (swr_is_initialized(resampler_.get())) {
+ swr_close(resampler_.get());
+ }
+}
+
+bool SimulatedAudioCapturer::EnsureResamplerIsInitializedFor(
+ const AVFrame& frame) {
+ if (swr_is_initialized(resampler_.get())) {
+ if (input_sample_format_ == static_cast<AVSampleFormat>(frame.format) &&
+ input_sample_rate_ == frame.sample_rate &&
+ input_channel_layout_ == frame.channel_layout) {
+ return true;
+ }
+
+ // Note: Usually, the resampler should be flushed before being destroyed.
+ // However, because of the way SimulatedAudioCapturer uses the API, only one
+ // audio sample should be dropped in the worst case. Log what's being
+ // dropped, just in case libswresample is behaving differently than
+ // expected.
+ const std::chrono::microseconds amount(
+ swr_get_delay(resampler_.get(), std::micro::den));
+ OSP_LOG_INFO << "Discarding " << amount
+ << " of audio from the resampler before re-init.";
+ }
+
+ input_sample_format_ = AV_SAMPLE_FMT_NONE;
+
+ // Create a fake output frame to hold the output audio parameters, because the
+ // resampler API is weird that way.
+ const auto fake_output_frame = MakeUniqueAVFrame();
+ fake_output_frame->channel_layout =
+ av_get_default_channel_layout(num_channels_);
+ fake_output_frame->format = AV_SAMPLE_FMT_FLT;
+ fake_output_frame->sample_rate = sample_rate_;
+ const int config_result =
+ swr_config_frame(resampler_.get(), fake_output_frame.get(), &frame);
+ if (config_result < 0) {
+ OnError("swr_config_frame", config_result);
+ return false; // Capturer is now halted.
+ }
+
+ const int init_result = swr_init(resampler_.get());
+ if (init_result < 0) {
+ OnError("swr_init", init_result);
+ return false; // Capturer is now halted.
+ }
+
+ input_sample_format_ = static_cast<AVSampleFormat>(frame.format);
+ input_sample_rate_ = frame.sample_rate;
+ input_channel_layout_ = frame.channel_layout;
+ return true;
+}
+
+absl::optional<Clock::duration> SimulatedAudioCapturer::ProcessDecodedFrame(
+ const AVFrame& frame) {
+ if (!EnsureResamplerIsInitializedFor(frame)) {
+ return absl::nullopt;
+ }
+
+ const int64_t num_leftover_input_samples =
+ swr_get_delay(resampler_.get(), input_sample_rate_);
+ OSP_DCHECK_GE(num_leftover_input_samples, 0);
+ const Clock::duration capture_time_adjustment = -ToApproximateClockDuration(
+ num_leftover_input_samples, AVRational{1, input_sample_rate_});
+
+ const int64_t num_output_samples_desired =
+ av_rescale_rnd(num_leftover_input_samples + frame.nb_samples,
+ sample_rate_, input_sample_rate_, AV_ROUND_ZERO);
+ OSP_DCHECK_GE(num_output_samples_desired, 0);
+ resampled_audio_.resize(num_channels_ * num_output_samples_desired);
+ uint8_t* output_argument[1] = {
+ reinterpret_cast<uint8_t*>(resampled_audio_.data())};
+ const int num_samples_converted_or_error = swr_convert(
+ resampler_.get(), output_argument, num_output_samples_desired,
+ const_cast<const uint8_t**>(frame.extended_data), frame.nb_samples);
+ if (num_samples_converted_or_error < 0) {
+ resampled_audio_.clear();
+ swr_close(resampler_.get());
+ OnError("swr_convert", num_samples_converted_or_error);
+ return absl::nullopt; // Capturer is now halted.
+ }
+ resampled_audio_.resize(num_channels_ * num_samples_converted_or_error);
+
+ return capture_time_adjustment;
+}
+
+void SimulatedAudioCapturer::DeliverDataToClient(
+ const AVFrame& unused,
+ Clock::time_point capture_time) {
+ if (resampled_audio_.empty()) {
+ return;
+ }
+ client_->OnAudioData(resampled_audio_.data(),
+ resampled_audio_.size() / num_channels_, capture_time);
+ resampled_audio_.clear();
+}
+
+SimulatedVideoCapturer::Client::~Client() = default;
+
+SimulatedVideoCapturer::SimulatedVideoCapturer(Environment* environment,
+ const char* path,
+ Clock::time_point start_time,
+ Client* client)
+ : SimulatedCapturer(environment,
+ path,
+ AVMEDIA_TYPE_VIDEO,
+ start_time,
+ client),
+ client_(client) {}
+
+SimulatedVideoCapturer::~SimulatedVideoCapturer() = default;
+
+void SimulatedVideoCapturer::SetAdditionalDecoderParameters(
+ AVCodecContext* decoder_context) {
+ // Require the I420 planar format for video.
+ decoder_context->get_format = [](struct AVCodecContext* s,
+ const enum AVPixelFormat* formats) {
+ // Return AV_PIX_FMT_YUV420P if it's in the provided list of supported
+ // formats. Otherwise, return AV_PIX_FMT_NONE.
+ //
+ // |formats| is a NONE-terminated array.
+ for (; *formats != AV_PIX_FMT_NONE; ++formats) {
+ if (*formats == AV_PIX_FMT_YUV420P) {
+ break;
+ }
+ }
+ return *formats;
+ };
+}
+
+void SimulatedVideoCapturer::DeliverDataToClient(
+ const AVFrame& frame,
+ Clock::time_point capture_time) {
+ client_->OnVideoFrame(frame, capture_time);
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.h b/chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.h
new file mode 100644
index 00000000000..8d32085aa73
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.h
@@ -0,0 +1,205 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_STANDALONE_SENDER_SIMULATED_CAPTURER_H_
+#define CAST_STANDALONE_SENDER_SIMULATED_CAPTURER_H_
+
+#include <stdint.h>
+
+#include <string>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "cast/standalone_sender/ffmpeg_glue.h"
+#include "platform/api/time.h"
+#include "util/alarm.h"
+
+namespace openscreen {
+namespace cast {
+
+class Environment;
+
+// Simulates live media capture by demuxing, decoding, and emitting a stream of
+// frames from a file at normal (1X) speed. This is a base class containing
+// common functionality. Typical usage: Instantiate one SimulatedAudioCapturer
+// and one FileVideoStreamCapturer.
+class SimulatedCapturer {
+ public:
+ // Interface for receiving end-of-stream and fatal error notifications.
+ class Observer {
+ public:
+ // Called once the end of the file has been reached and the |capturer| has
+ // halted.
+ virtual void OnEndOfFile(SimulatedCapturer* capturer) = 0;
+
+ // Called if a non-recoverable error occurs and the |capturer| has halted.
+ virtual void OnError(SimulatedCapturer* capturer, std::string message) = 0;
+
+ protected:
+ virtual ~Observer();
+ };
+
+ protected:
+ SimulatedCapturer(Environment* environment,
+ const char* path,
+ AVMediaType media_type,
+ Clock::time_point start_time,
+ Observer* observer);
+
+ virtual ~SimulatedCapturer();
+
+ // Optionally overridden, to apply additional decoder context settings before
+ // avcodec_open2() is called.
+ virtual void SetAdditionalDecoderParameters(AVCodecContext* decoder_context);
+
+ // Performs any additional processing on the decoded frame (e.g., audio
+ // resampling), and returns any adjustments to the frame's capture time (e.g.,
+ // to account for any buffering). If a fatal error occurs, absl::nullopt is
+ // returned. The default implementation does nothing.
+ //
+ // Mutating the |decoded_frame| is not allowed. If a subclass implementation
+ // wants to deliver different data (e.g., resampled audio), it must stash the
+ // data itself for the next DeliverDataToClient() call.
+ virtual absl::optional<Clock::duration> ProcessDecodedFrame(
+ const AVFrame& decoded_frame);
+
+ // Delivers the decoded frame data to the client.
+ virtual void DeliverDataToClient(const AVFrame& decoded_frame,
+ Clock::time_point capture_time) = 0;
+
+ // Called when any transient or fatal error occurs, generating an Error and
+ // scheduling a task to notify the Observer of it soon.
+ void OnError(const char* what, int av_errnum);
+
+ // Converts the given FFMPEG tick count into an approximate Clock::duration.
+ static Clock::duration ToApproximateClockDuration(
+ int64_t ticks,
+ const AVRational& time_base);
+
+ private:
+ // Reads the next frame from the file, sends it to the decoder, and schedules
+ // a future ConsumeNextDecodedFrame() call to continue processing.
+ void StartDecodingNextFrame();
+
+ // Receives the next decoded frame and schedules media delivery to the client,
+ // and/or calls Observer::OnEndOfFile() if there are no more frames in the
+ // file.
+ void ConsumeNextDecodedFrame();
+
+ const AVFormatContextUniquePtr format_context_;
+ const AVMediaType media_type_; // Audio or Video.
+ const Clock::time_point start_time_;
+ Observer* const observer_;
+ const AVPacketUniquePtr packet_; // Decoder input buffer.
+ const AVFrameUniquePtr decoded_frame_; // Decoder output frame.
+ int stream_index_ = -1; // Selected stream from the file.
+ AVCodecContextUniquePtr decoder_context_;
+
+ // The last frame's stream timestamp. This is used to detect bad stream
+ // timestamps in the file.
+ absl::optional<Clock::duration> last_frame_timestamp_;
+
+ // Used to schedule the next task to execute and when it should execute. There
+ // is only ever one task scheduled/running at any time.
+ Alarm next_task_;
+};
+
+// Emits the primary audio stream from a file.
+class SimulatedAudioCapturer final : public SimulatedCapturer {
+ public:
+ class Client : public SimulatedCapturer::Observer {
+ public:
+ // Called to deliver more audio data as |interleaved_samples|, which
+ // contains |num_samples| tuples (i.e., multiply by the number of channels
+ // to determine the number of array elements). |capture_time| is used to
+ // synchronize the play-out of the first audio sample with respect to video
+ // frames.
+ virtual void OnAudioData(const float* interleaved_samples,
+ int num_samples,
+ Clock::time_point capture_time) = 0;
+
+ protected:
+ ~Client() override;
+ };
+
+ // Constructor: |num_channels| and |sample_rate| specify the required audio
+ // format. If necessary, audio from the file will be resampled to match the
+ // required format.
+ SimulatedAudioCapturer(Environment* environment,
+ const char* path,
+ int num_channels,
+ int sample_rate,
+ Clock::time_point start_time,
+ Client* client);
+
+ ~SimulatedAudioCapturer() final;
+
+ private:
+ // Examines the audio format of the given |frame|, and ensures the resampler
+ // is initialized to take that as input.
+ bool EnsureResamplerIsInitializedFor(const AVFrame& frame);
+
+ // Resamples the current |SimulatedCapturer::decoded_frame()| into the
+ // required output format/channels/rate. The result is stored in
+ // |resampled_audio_| for the next DeliverDataToClient() call.
+ absl::optional<Clock::duration> ProcessDecodedFrame(
+ const AVFrame& decoded_frame) final;
+
+ // Called at the moment Client::OnAudioData() should be called to pass the
+ // |resampled_audio_|.
+ void DeliverDataToClient(const AVFrame& decoded_frame,
+ Clock::time_point capture_time) final;
+
+ const int num_channels_; // Output number of channels.
+ const int sample_rate_; // Output sample rate.
+ Client* const client_;
+
+ const SwrContextUniquePtr resampler_;
+
+ // Current resampler input audio parameters.
+ AVSampleFormat input_sample_format_ = AV_SAMPLE_FMT_NONE;
+ int input_sample_rate_;
+ uint64_t input_channel_layout_; // Opaque value used by resampler library.
+
+ std::vector<float> resampled_audio_;
+};
+
+// Emits the primary video stream from a file.
+class SimulatedVideoCapturer final : public SimulatedCapturer {
+ public:
+ class Client : public SimulatedCapturer::Observer {
+ public:
+ // Called to deliver the next video |frame|, which is always in I420 format.
+ // |capture_time| is used to synchronize the play-out of the video frame
+ // with respect to the audio track.
+ virtual void OnVideoFrame(const AVFrame& frame,
+ Clock::time_point capture_time) = 0;
+
+ protected:
+ ~Client() override;
+ };
+
+ SimulatedVideoCapturer(Environment* environment,
+ const char* path,
+ Clock::time_point start_time,
+ Client* client);
+
+ ~SimulatedVideoCapturer() final;
+
+ private:
+ Client* const client_;
+
+ // Sets up the decoder to produce I420 format output.
+ void SetAdditionalDecoderParameters(AVCodecContext* decoder_context) final;
+
+ // Called at the moment Client::OnVideoFrame() should be called to provide the
+ // next video frame.
+ void DeliverDataToClient(const AVFrame& decoded_frame,
+ Clock::time_point capture_time) final;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_STANDALONE_SENDER_SIMULATED_CAPTURER_H_
diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.cc b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.cc
new file mode 100644
index 00000000000..07aeaaefd13
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.cc
@@ -0,0 +1,224 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/standalone_sender/streaming_opus_encoder.h"
+
+#include <opus/opus.h>
+
+#include <algorithm>
+#include <chrono> // NOLINT
+
+namespace openscreen {
+namespace cast {
+
+using std::chrono::duration_cast;
+using std::chrono::microseconds;
+using std::chrono::seconds;
+
+using openscreen::operator<<; // To pretty-print chrono values.
+
+namespace {
+
+// The bitrate at which virtually all stereo audio can be encoded and decoded
+// without human-perceivable artifacts. Source:
+// https://wiki.hydrogenaud.io/index.php?title=Opus#Bitrate_performance
+constexpr opus_int32 kTransparentBitrate = 160000;
+
+// The maximum number of Cast audio frames the encoder may fall behind by before
+// skipping-ahead the RTP timestamps to compensate.
+constexpr int kMaxCastFramesBeforeSkip = 3;
+
+} // namespace
+
+StreamingOpusEncoder::StreamingOpusEncoder(int num_channels,
+ int cast_frames_per_second,
+ Sender* sender)
+ : num_channels_(num_channels),
+ sender_(sender),
+ samples_per_cast_frame_(sample_rate() / cast_frames_per_second),
+ approximate_cast_frame_duration_(
+ duration_cast<Clock::duration>(seconds(1)) / cast_frames_per_second),
+ encoder_storage_(new uint8_t[opus_encoder_get_size(num_channels_)]),
+ input_(new float[num_channels_ * samples_per_cast_frame_]),
+ output_(new uint8_t[kOpusMaxPayloadSize]) {
+ OSP_CHECK_GT(cast_frames_per_second, 0);
+ OSP_DCHECK(sender_);
+ OSP_CHECK_GT(samples_per_cast_frame_, 0);
+ OSP_CHECK_EQ(sample_rate() % cast_frames_per_second, 0);
+ OSP_CHECK(approximate_cast_frame_duration_ > Clock::duration::zero());
+
+ frame_.dependency = EncodedFrame::KEY_FRAME;
+
+ const auto init_result = opus_encoder_init(
+ encoder(), sample_rate(), num_channels_, OPUS_APPLICATION_AUDIO);
+ OSP_CHECK_EQ(init_result, OPUS_OK);
+
+ UseStandardQuality();
+}
+
+StreamingOpusEncoder::~StreamingOpusEncoder() = default;
+
+int StreamingOpusEncoder::GetBitrate() const {
+ opus_int32 bitrate;
+ const auto ctl_result =
+ opus_encoder_ctl(encoder(), OPUS_GET_BITRATE(&bitrate));
+ OSP_CHECK_EQ(ctl_result, OPUS_OK);
+ return bitrate;
+}
+
+void StreamingOpusEncoder::UseStandardQuality() {
+ const auto ctl_result =
+ opus_encoder_ctl(encoder(), OPUS_SET_BITRATE(OPUS_AUTO));
+ OSP_CHECK_EQ(ctl_result, OPUS_OK);
+ UpdateCodecDelay();
+}
+
+void StreamingOpusEncoder::UseHighQuality() {
+ // kTransparentBitrate assumes stereo audio. Scale it by the actual number of
+ // channels.
+ const opus_int32 bitrate = kTransparentBitrate * num_channels_ / 2;
+ const auto ctl_result =
+ opus_encoder_ctl(encoder(), OPUS_SET_BITRATE(bitrate));
+ OSP_CHECK_EQ(ctl_result, OPUS_OK);
+ UpdateCodecDelay();
+}
+
+void StreamingOpusEncoder::EncodeAndSend(const float* interleaved_samples,
+ int num_samples,
+ Clock::time_point reference_time) {
+ OSP_DCHECK(interleaved_samples);
+ OSP_DCHECK_GT(num_samples, 0);
+
+ ResolveTimestampsAndMaybeSkip(reference_time);
+
+ while (num_samples > 0) {
+ const int samples_copied =
+ FillInputBuffer(interleaved_samples, num_samples);
+ num_samples -= samples_copied;
+ interleaved_samples += num_channels_ * samples_copied;
+
+ if (num_samples_queued_ < samples_per_cast_frame_) {
+ return; // Not enough yet for a full Cast audio frame.
+ }
+
+ const opus_int32 packet_size_or_error =
+ opus_encode_float(encoder(), input_.get(), num_samples_queued_,
+ output_.get(), kOpusMaxPayloadSize);
+ num_samples_queued_ = 0;
+ if (packet_size_or_error < 0) {
+ OSP_LOG_FATAL << "AUDIO[" << sender_->ssrc()
+ << "] Error code from opus_encode_float(): "
+ << packet_size_or_error;
+ return;
+ }
+
+ frame_.frame_id = sender_->GetNextFrameId();
+ frame_.referenced_frame_id = frame_.frame_id;
+ // Note: It's possible for Opus to encode a zero byte packet. Send a Cast
+ // audio frame anyway, to represent the passage of silence and to send other
+ // stream metadata.
+ frame_.data = absl::Span<uint8_t>(output_.get(), packet_size_or_error);
+ last_sent_frame_reference_time_ = frame_.reference_time;
+ switch (sender_->EnqueueFrame(frame_)) {
+ case Sender::OK:
+ break;
+ case Sender::PAYLOAD_TOO_LARGE:
+ OSP_NOTREACHED(); // The Opus packet cannot possibly be too large.
+ break;
+ case Sender::REACHED_ID_SPAN_LIMIT:
+ OSP_LOG_WARN << "AUDIO[" << sender_->ssrc()
+ << "] Dropping: FrameId span limit reached.";
+ break;
+ case Sender::MAX_DURATION_IN_FLIGHT:
+ OSP_LOG_INFO << "AUDIO[" << sender_->ssrc()
+ << "] Dropping: In-flight duration would be too high.";
+ break;
+ }
+
+ frame_.rtp_timestamp += RtpTimeDelta::FromTicks(samples_per_cast_frame_);
+ frame_.reference_time += approximate_cast_frame_duration_;
+ }
+}
+
+void StreamingOpusEncoder::UpdateCodecDelay() {
+ opus_int32 lookahead = 0;
+ const auto ctl_result =
+ opus_encoder_ctl(encoder(), OPUS_GET_LOOKAHEAD(&lookahead));
+ OSP_CHECK_EQ(ctl_result, OPUS_OK);
+ codec_delay_ = RtpTimeDelta::FromTicks(lookahead).ToDuration<Clock::duration>(
+ sample_rate());
+}
+
+void StreamingOpusEncoder::ResolveTimestampsAndMaybeSkip(
+ Clock::time_point reference_time) {
+ // Back-track the reference time to account for the audio delay introduced by
+ // the codec.
+ reference_time -= codec_delay_;
+
+ // Special case: Nothing special for the first frame's timestamps.
+ if (start_time_ == Clock::time_point::min()) {
+ frame_.rtp_timestamp = RtpTimeTicks();
+ frame_.reference_time = start_time_ = reference_time;
+ last_sent_frame_reference_time_ =
+ reference_time - approximate_cast_frame_duration_;
+ return;
+ }
+
+ const RtpTimeTicks current_position =
+ frame_.rtp_timestamp + RtpTimeDelta::FromTicks(num_samples_queued_);
+ const RtpTimeTicks reference_position = RtpTimeTicks::FromTimeSinceOrigin(
+ reference_time - start_time_, sample_rate());
+ const RtpTimeDelta rtp_advancement = reference_position - current_position;
+ const RtpTimeDelta skip_threshold =
+ RtpTimeDelta::FromTicks(samples_per_cast_frame_) *
+ kMaxCastFramesBeforeSkip;
+ if (rtp_advancement > skip_threshold) {
+ OSP_LOG_WARN << "Detected audio gap "
+ << rtp_advancement.ToDuration<microseconds>(sample_rate())
+ << ", skipping ahead...";
+ num_samples_queued_ = 0;
+ frame_.rtp_timestamp = reference_position;
+ }
+
+ // Further back-track the reference time to account for the already-queued
+ // samples.
+ reference_time -= RtpTimeDelta::FromTicks(num_samples_queued_)
+ .ToDuration<Clock::duration>(sample_rate());
+
+ // Frame reference times must be monotonically increasing. A little noise in
+ // the negative direction is okay to cap-off. Log a warning if there's a
+ // bigger problem (at the source).
+ const Clock::time_point lower_bound =
+ last_sent_frame_reference_time_ +
+ RtpTimeDelta::FromTicks(1).ToDuration<Clock::duration>(sample_rate());
+ if (reference_time < lower_bound) {
+ const Clock::duration backwards_amount =
+ last_sent_frame_reference_time_ - reference_time;
+ OSP_LOG_IF(WARN, backwards_amount >= approximate_cast_frame_duration_)
+ << "Reference time went *backwards* too much (" << backwards_amount
+ << " in wrong direction). A/V sync may suffer at the Receiver!";
+ reference_time = lower_bound;
+ }
+
+ frame_.reference_time = reference_time;
+}
+
+int StreamingOpusEncoder::FillInputBuffer(const float* interleaved_samples,
+ int num_samples) {
+ const int samples_needed = samples_per_cast_frame_ - num_samples_queued_;
+ const int samples_to_copy = std::min(num_samples, samples_needed);
+ std::copy(interleaved_samples,
+ interleaved_samples + num_channels_ * samples_to_copy,
+ input_.get() + num_channels_ * num_samples_queued_);
+ num_samples_queued_ += samples_to_copy;
+ return samples_to_copy;
+}
+
+// static
+constexpr int StreamingOpusEncoder::kDefaultCastAudioFramesPerSecond;
+// static
+constexpr int StreamingOpusEncoder::kOpusMaxPayloadSize;
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.h b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.h
new file mode 100644
index 00000000000..0620e3a4588
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.h
@@ -0,0 +1,123 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_STANDALONE_SENDER_STREAMING_OPUS_ENCODER_H_
+#define CAST_STANDALONE_SENDER_STREAMING_OPUS_ENCODER_H_
+
+#include <stdint.h>
+
+#include <memory>
+
+#include "cast/streaming/encoded_frame.h"
+#include "cast/streaming/sender.h"
+#include "platform/api/time.h"
+
+extern "C" {
+struct OpusEncoder;
+}
+
+namespace openscreen {
+namespace cast {
+
+// Wraps the libopus encoder so that the application can stream
+// interleaved-floats audio samples to a Sender. Either mono or stereo sound is
+// supported.
+class StreamingOpusEncoder {
+ public:
+ // Constructs the encoder for mono or stereo sound, dividing the stream of
+ // audio samples up into chunks as determined by the given
+ // |cast_frames_per_second|, and for EncodedFrame output to the given
+ // |sender|. The sample rate of the audio is assumed to be the Sender's fixed
+ // |rtp_timebase()|.
+ StreamingOpusEncoder(int num_channels,
+ int cast_frames_per_second,
+ Sender* sender);
+
+ ~StreamingOpusEncoder();
+
+ int num_channels() const { return num_channels_; }
+ int sample_rate() const { return sender_->rtp_timebase(); }
+
+ int GetBitrate() const;
+
+ // Sets the encoder back to its "AUTO" bitrate setting, for standard quality.
+ // This and UseHighQuality() may be called as often as needed as conditions
+ // change.
+ //
+ // Note: As of 2020-01-21, the encoder in "auto bitrate" mode would use a
+ // bitrate of 102kbps for 2-channel, 48 kHz audio and a 10 ms frame size.
+ void UseStandardQuality();
+
+ // Sets the encoder to use a high bitrate (virtually no artifacts), when
+ // plenty of network bandwidth is available. This and UseStandardQuality() may
+ // be called as often as needed as conditions change.
+ void UseHighQuality();
+
+ // Encode and send the given |interleaved_samples|, which contains
+ // |num_samples| tuples (i.e., multiply by the number of channels to determine
+ // the number of array elements). The audio is assumed to have been captured
+ // at the required |sample_rate()|. |reference_time| refers to the first
+ // sample.
+ void EncodeAndSend(const float* interleaved_samples,
+ int num_samples,
+ Clock::time_point reference_time);
+
+ static constexpr int kDefaultCastAudioFramesPerSecond =
+ 100; // 10 ms frame duration.
+
+ private:
+ OpusEncoder* encoder() const {
+ return reinterpret_cast<OpusEncoder*>(encoder_storage_.get());
+ }
+
+ // Updates the |codec_delay_| based on the current encoder settings.
+ void UpdateCodecDelay();
+
+ // Sets the next frame's reference time, accounting for codec buffering delay.
+ // Also, checks whether the reference time has drifted too far forwards, and
+ // skips if necessary.
+ void ResolveTimestampsAndMaybeSkip(Clock::time_point reference_time);
+
+ // Fills the input buffer as much as possible from the given source data, and
+ // returns the number of samples copied into the buffer.
+ int FillInputBuffer(const float* interleaved_samples, int num_samples);
+
+ const int num_channels_;
+ Sender* const sender_;
+ const int samples_per_cast_frame_;
+ const Clock::duration approximate_cast_frame_duration_;
+ const std::unique_ptr<uint8_t[]> encoder_storage_;
+ const std::unique_ptr<float[]> input_; // Interleaved audio samples.
+ const std::unique_ptr<uint8_t[]> output_; // Opus-encoded packet.
+
+ // The audio delay introduced by the codec.
+ Clock::duration codec_delay_{};
+
+ // The number of mono/stereo tuples currently queued in the |input_| buffer.
+ // Multiply by |num_channels_| to get the number of array elements.
+ int num_samples_queued_ = 0;
+
+ // The reference time of the first frame passed to EncodeAndSend(), offset by
+ // the codec delay.
+ Clock::time_point start_time_ = Clock::time_point::min();
+
+ // Initialized and used by EncodeAndSend() to hold the metadata and data
+ // pointer for each frame being sent.
+ EncodedFrame frame_;
+
+ // The |reference_time| for the last sent frame. This is used to check that
+ // the reference times are monotonically increasing. If they have [illegally]
+ // gone backwards too much, warnings will be logged.
+ Clock::time_point last_sent_frame_reference_time_;
+
+ // This is the recommended value, according to documentation in
+ // src/include/opus.h in libopus, so that the Opus encoder does not degrade
+ // the audio due to memory constraints.
+ static constexpr int kOpusMaxPayloadSize = 4000;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_STANDALONE_SENDER_STREAMING_OPUS_ENCODER_H_
diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.cc b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.cc
new file mode 100644
index 00000000000..178564d656d
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.cc
@@ -0,0 +1,492 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/standalone_sender/streaming_vp8_encoder.h"
+
+#include <stdint.h>
+#include <string.h>
+#include <vpx/vp8cx.h>
+
+#include <cmath>
+#include <utility>
+
+#include "cast/streaming/encoded_frame.h"
+#include "cast/streaming/environment.h"
+#include "cast/streaming/sender.h"
+#include "util/logging.h"
+#include "util/saturate_cast.h"
+
+namespace openscreen {
+namespace cast {
+
+using std::chrono::duration_cast;
+using std::chrono::milliseconds;
+using std::chrono::seconds;
+
+// TODO(https://crbug.com/openscreen/123): Fix the declarations and then remove
+// this:
+using openscreen::operator<<; // For std::chrono::duration pretty-printing.
+
+namespace {
+
+constexpr int kBytesPerKilobyte = 1024;
+
+// Lower and upper bounds to the frame duration passed to vpx_codec_encode(), to
+// ensure sanity. Note that the upper-bound is especially important in cases
+// where the video paused for some lengthy amount of time.
+constexpr Clock::duration kMinFrameDuration = milliseconds(1);
+constexpr Clock::duration kMaxFrameDuration = milliseconds(125);
+
+// Highest/lowest allowed encoding speed set to the encoder. The valid range is
+// [4, 16], but experiments show that with speed higher than 12, the saving of
+// the encoding time is not worth the dropping of the quality. And, with speed
+// lower than 6, the increasing amount of quality is not worth the increasing
+// amount of encoding time.
+constexpr int kHighestEncodingSpeed = 12;
+constexpr int kLowestEncodingSpeed = 6;
+
+// This is the equivalent change in encoding speed per one quantizer step.
+constexpr double kEquivalentEncodingSpeedStepPerQuantizerStep = 1 / 20.0;
+
+} // namespace
+
+StreamingVp8Encoder::StreamingVp8Encoder(const Parameters& params,
+ TaskRunner* task_runner,
+ Sender* sender)
+ : params_(params),
+ main_task_runner_(task_runner),
+ sender_(sender),
+ ideal_speed_setting_(kHighestEncodingSpeed),
+ encode_thread_([this] { ProcessWorkUnitsUntilTimeToQuit(); }) {
+ OSP_DCHECK_LE(1, params_.num_encode_threads);
+ OSP_DCHECK_LE(kMinQuantizer, params_.min_quantizer);
+ OSP_DCHECK_LE(params_.min_quantizer, params_.max_cpu_saver_quantizer);
+ OSP_DCHECK_LE(params_.max_cpu_saver_quantizer, params_.max_quantizer);
+ OSP_DCHECK_LE(params_.max_quantizer, kMaxQuantizer);
+ OSP_DCHECK_LT(0.0, params_.max_time_utilization);
+ OSP_DCHECK_LE(params_.max_time_utilization, 1.0);
+ OSP_DCHECK(main_task_runner_);
+ OSP_DCHECK(sender_);
+
+ const auto result =
+ vpx_codec_enc_config_default(vpx_codec_vp8_cx(), &config_, 0);
+ OSP_CHECK_EQ(result, VPX_CODEC_OK);
+
+ // This is set to non-zero in ConfigureForNewFrameSize() later, to flag that
+ // the encoder has been initialized.
+ config_.g_threads = 0;
+
+ // Set the timebase to match that of openscreen::Clock::duration.
+ config_.g_timebase.num = Clock::duration::period::num;
+ config_.g_timebase.den = Clock::duration::period::den;
+
+ // |g_pass| and |g_lag_in_frames| must be "one pass" and zero, respectively,
+ // because of the way the libvpx API is used.
+ config_.g_pass = VPX_RC_ONE_PASS;
+ config_.g_lag_in_frames = 0;
+
+ // Rate control settings.
+ config_.rc_dropframe_thresh = 0; // The encoder may not drop any frames.
+ config_.rc_resize_allowed = 0;
+ config_.rc_end_usage = VPX_CBR;
+ config_.rc_target_bitrate = target_bitrate_ / kBytesPerKilobyte;
+ config_.rc_min_quantizer = params_.min_quantizer;
+ config_.rc_max_quantizer = params_.max_quantizer;
+
+ // The reasons for the values chosen here (rc_*shoot_pct and rc_buf_*_sz) are
+ // lost in history. They were brought-over from the legacy Chrome Cast
+ // Streaming Sender implemenation.
+ config_.rc_undershoot_pct = 100;
+ config_.rc_overshoot_pct = 15;
+ config_.rc_buf_initial_sz = 500;
+ config_.rc_buf_optimal_sz = 600;
+ config_.rc_buf_sz = 1000;
+
+ config_.kf_mode = VPX_KF_DISABLED;
+}
+
+StreamingVp8Encoder::~StreamingVp8Encoder() {
+ {
+ std::unique_lock<std::mutex> lock(mutex_);
+ target_bitrate_ = 0;
+ cv_.notify_one();
+ }
+ encode_thread_.join();
+}
+
+int StreamingVp8Encoder::GetTargetBitrate() const {
+ // Note: No need to lock the |mutex_| since this method should be called on
+ // the same thread as SetTargetBitrate().
+ return target_bitrate_;
+}
+
+void StreamingVp8Encoder::SetTargetBitrate(int new_bitrate) {
+ // Ensure that, when bps is converted to kbps downstream, that the encoder
+ // bitrate will not be zero.
+ new_bitrate = std::max(new_bitrate, kBytesPerKilobyte);
+
+ std::unique_lock<std::mutex> lock(mutex_);
+ // Only assign the new target bitrate if |target_bitrate_| has not yet been
+ // used to signal the |encode_thread_| to end.
+ if (target_bitrate_ > 0) {
+ target_bitrate_ = new_bitrate;
+ }
+}
+
+void StreamingVp8Encoder::EncodeAndSend(
+ const VideoFrame& frame,
+ Clock::time_point reference_time,
+ std::function<void(Stats)> stats_callback) {
+ WorkUnit work_unit;
+
+ // TODO(miu): The |VideoFrame| struct should provide the media timestamp,
+ // instead of this code inferring it from the reference timestamps, since: 1)
+ // the video capturer's clock may tick at a different rate than the system
+ // clock; and 2) to reduce jitter.
+ if (start_time_ == Clock::time_point::min()) {
+ start_time_ = reference_time;
+ work_unit.rtp_timestamp = RtpTimeTicks();
+ } else {
+ work_unit.rtp_timestamp = RtpTimeTicks::FromTimeSinceOrigin(
+ reference_time - start_time_, sender_->rtp_timebase());
+ if (work_unit.rtp_timestamp <= last_enqueued_rtp_timestamp_) {
+ OSP_LOG_WARN << "VIDEO[" << sender_->ssrc()
+ << "] Dropping: RTP timestamp is not monotonically "
+ "increasing from last frame.";
+ return;
+ }
+ }
+ if (sender_->GetInFlightMediaDuration(work_unit.rtp_timestamp) >
+ sender_->GetMaxInFlightMediaDuration()) {
+ OSP_LOG_WARN << "VIDEO[" << sender_->ssrc()
+ << "] Dropping: In-flight media duration would be too high.";
+ return;
+ }
+
+ Clock::duration frame_duration = frame.duration;
+ if (frame_duration <= Clock::duration::zero()) {
+ // The caller did not provide the frame duration in |frame|.
+ if (reference_time == start_time_) {
+ // Use the max for the first frame so libvpx will spend extra effort on
+ // its quality.
+ frame_duration = kMaxFrameDuration;
+ } else {
+ // Use the actual amount of time between the current and previous frame as
+ // a prediction for the next frame's duration.
+ frame_duration =
+ (work_unit.rtp_timestamp - last_enqueued_rtp_timestamp_)
+ .ToDuration<Clock::duration>(sender_->rtp_timebase());
+ }
+ }
+ work_unit.duration =
+ std::max(std::min(frame_duration, kMaxFrameDuration), kMinFrameDuration);
+
+ last_enqueued_rtp_timestamp_ = work_unit.rtp_timestamp;
+
+ work_unit.image = CloneAsVpxImage(frame);
+ work_unit.reference_time = reference_time;
+ work_unit.stats_callback = std::move(stats_callback);
+ const bool force_key_frame = sender_->NeedsKeyFrame();
+ {
+ std::unique_lock<std::mutex> lock(mutex_);
+ needs_key_frame_ |= force_key_frame;
+ encode_queue_.push(std::move(work_unit));
+ cv_.notify_one();
+ }
+}
+
+void StreamingVp8Encoder::DestroyEncoder() {
+ OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id());
+
+ if (is_encoder_initialized()) {
+ vpx_codec_destroy(&encoder_);
+ // Flag that the encoder is not initialized. See header comments for
+ // is_encoder_initialized().
+ config_.g_threads = 0;
+ }
+}
+
+void StreamingVp8Encoder::ProcessWorkUnitsUntilTimeToQuit() {
+ OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id());
+
+ for (;;) {
+ WorkUnitWithResults work_unit{};
+ bool force_key_frame;
+ int target_bitrate;
+ {
+ std::unique_lock<std::mutex> lock(mutex_);
+ if (target_bitrate_ <= 0) {
+ break; // Time to end this thread.
+ }
+ if (encode_queue_.empty()) {
+ cv_.wait(lock);
+ if (encode_queue_.empty()) {
+ continue;
+ }
+ }
+ static_cast<WorkUnit&>(work_unit) = std::move(encode_queue_.front());
+ encode_queue_.pop();
+ force_key_frame = needs_key_frame_;
+ target_bitrate = target_bitrate_;
+ }
+
+ // Clock::now() is being called directly, instead of using a
+ // dependency-injected "now function," since actual wall time is being
+ // measured.
+ const Clock::time_point encode_start_time = Clock::now();
+ PrepareEncoder(work_unit.image->d_w, work_unit.image->d_h, target_bitrate);
+ EncodeFrame(force_key_frame, &work_unit);
+ ComputeFrameEncodeStats(Clock::now() - encode_start_time, target_bitrate,
+ &work_unit);
+ UpdateSpeedSettingForNextFrame(work_unit.stats);
+
+ main_task_runner_->PostTask(
+ [this, results = std::move(work_unit)]() mutable {
+ SendEncodedFrame(std::move(results));
+ });
+ }
+
+ DestroyEncoder();
+}
+
+void StreamingVp8Encoder::PrepareEncoder(int width,
+ int height,
+ int target_bitrate) {
+ OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id());
+
+ const int target_kbps = target_bitrate / kBytesPerKilobyte;
+
+ // Translate the |ideal_speed_setting_| into the VP8E_SET_CPUUSED setting and
+ // the minimum quantizer to use.
+ int speed;
+ int min_quantizer;
+ if (ideal_speed_setting_ > kHighestEncodingSpeed) {
+ speed = kHighestEncodingSpeed;
+ const double remainder = ideal_speed_setting_ - speed;
+ min_quantizer = rounded_saturate_cast<int>(
+ remainder / kEquivalentEncodingSpeedStepPerQuantizerStep +
+ params_.min_quantizer);
+ min_quantizer = std::min(min_quantizer, params_.max_cpu_saver_quantizer);
+ } else {
+ speed = std::max(rounded_saturate_cast<int>(ideal_speed_setting_),
+ kLowestEncodingSpeed);
+ min_quantizer = params_.min_quantizer;
+ }
+
+ if (static_cast<int>(config_.g_w) != width ||
+ static_cast<int>(config_.g_h) != height) {
+ DestroyEncoder();
+ }
+
+ if (!is_encoder_initialized()) {
+ config_.g_threads = params_.num_encode_threads;
+ config_.g_w = width;
+ config_.g_h = height;
+ config_.rc_target_bitrate = target_kbps;
+ config_.rc_min_quantizer = min_quantizer;
+
+ encoder_ = {};
+ const vpx_codec_flags_t flags = 0;
+ const auto init_result =
+ vpx_codec_enc_init(&encoder_, vpx_codec_vp8_cx(), &config_, flags);
+ OSP_CHECK_EQ(init_result, VPX_CODEC_OK);
+
+ // Raise the threshold for considering macroblocks as static. The default is
+ // zero, so this setting makes the encoder less sensitive to motion. This
+ // lowers the probability of needing to utilize more CPU to search for
+ // motion vectors.
+ const auto ctl_result =
+ vpx_codec_control(&encoder_, VP8E_SET_STATIC_THRESHOLD, 1);
+ OSP_CHECK_EQ(ctl_result, VPX_CODEC_OK);
+
+ // Ensure the speed will be set (below).
+ current_speed_setting_ = ~speed;
+ } else if (static_cast<int>(config_.rc_target_bitrate) != target_kbps ||
+ static_cast<int>(config_.rc_min_quantizer) != min_quantizer) {
+ config_.rc_target_bitrate = target_kbps;
+ config_.rc_min_quantizer = min_quantizer;
+ const auto update_config_result =
+ vpx_codec_enc_config_set(&encoder_, &config_);
+ OSP_CHECK_EQ(update_config_result, VPX_CODEC_OK);
+ }
+
+ if (current_speed_setting_ != speed) {
+ // Pass the |speed| as a negative value to turn off VP8's automatic speed
+ // selection logic and force the exact setting.
+ const auto ctl_result =
+ vpx_codec_control(&encoder_, VP8E_SET_CPUUSED, -speed);
+ OSP_CHECK_EQ(ctl_result, VPX_CODEC_OK);
+ current_speed_setting_ = speed;
+ }
+}
+
+void StreamingVp8Encoder::EncodeFrame(bool force_key_frame,
+ WorkUnitWithResults* work_unit) {
+ OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id());
+
+ // The presentation timestamp argument here is fixed to zero to force the
+ // encoder to base its single-frame bandwidth calculations entirely on
+ // |frame_duration| and the target bitrate setting.
+ const vpx_codec_pts_t pts = 0;
+ const vpx_enc_frame_flags_t flags = force_key_frame ? VPX_EFLAG_FORCE_KF : 0;
+ const auto encode_result =
+ vpx_codec_encode(&encoder_, work_unit->image.get(), pts,
+ work_unit->duration.count(), flags, VPX_DL_REALTIME);
+ OSP_CHECK_EQ(encode_result, VPX_CODEC_OK);
+
+ const vpx_codec_cx_pkt_t* pkt;
+ for (vpx_codec_iter_t iter = nullptr;;) {
+ pkt = vpx_codec_get_cx_data(&encoder_, &iter);
+ // vpx_codec_get_cx_data() returns null once the "iteration" is complete.
+ // However, that point should never be reached because a
+ // VPX_CODEC_CX_FRAME_PKT must be encountered before that.
+ OSP_CHECK(pkt);
+ if (pkt->kind == VPX_CODEC_CX_FRAME_PKT) {
+ break;
+ }
+ }
+
+ // A copy of the payload data is being made here. That's okay since it has to
+ // be copied at some point anyway, to be passed back to the main thread.
+ auto* const begin = static_cast<const uint8_t*>(pkt->data.frame.buf);
+ auto* const end = begin + pkt->data.frame.sz;
+ work_unit->payload.assign(begin, end);
+ work_unit->is_key_frame = !!(pkt->data.frame.flags & VPX_FRAME_IS_KEY);
+}
+
+void StreamingVp8Encoder::ComputeFrameEncodeStats(
+ Clock::duration encode_wall_time,
+ int target_bitrate,
+ WorkUnitWithResults* work_unit) {
+ OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id());
+
+ Stats& stats = work_unit->stats;
+
+ // Note: stats.frame_id is set later, in SendEncodedFrame().
+ stats.rtp_timestamp = work_unit->rtp_timestamp;
+ stats.encode_wall_time = encode_wall_time;
+ stats.frame_duration = work_unit->duration;
+ stats.encoded_size = work_unit->payload.size();
+
+ constexpr double kBytesPerBit = 1.0 / CHAR_BIT;
+ constexpr double kSecondsPerClockTick =
+ 1.0 / duration_cast<Clock::duration>(seconds(1)).count();
+ const double target_bytes_per_clock_tick =
+ target_bitrate * (kBytesPerBit * kSecondsPerClockTick);
+ stats.target_size = target_bytes_per_clock_tick * work_unit->duration.count();
+
+ // The quantizer the encoder used. This is the result of the VP8 encoder
+ // taking a guess at what quantizer value would produce an encoded frame size
+ // as close to the target as possible.
+ const auto get_quantizer_result = vpx_codec_control(
+ &encoder_, VP8E_GET_LAST_QUANTIZER_64, &stats.quantizer);
+ OSP_CHECK_EQ(get_quantizer_result, VPX_CODEC_OK);
+
+ // Now that the frame has been encoded and the number of bytes is known, the
+ // perfect quantizer value (i.e., the one that should have been used) can be
+ // determined.
+ stats.perfect_quantizer = stats.quantizer * stats.space_utilization();
+}
+
+void StreamingVp8Encoder::UpdateSpeedSettingForNextFrame(const Stats& stats) {
+ OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id());
+
+ // Combine the speed setting that was used to encode the last frame, and the
+ // quantizer the encoder chose into a single speed metric.
+ const double speed = current_speed_setting_ +
+ kEquivalentEncodingSpeedStepPerQuantizerStep *
+ std::max(0, stats.quantizer - params_.min_quantizer);
+
+ // Like |Stats::perfect_quantizer|, this computes a "hindsight" speed setting
+ // for the last frame, one that may have potentially allowed for a
+ // better-quality quantizer choice by the encoder, while also keeping CPU
+ // utilization within budget.
+ const double perfect_speed =
+ speed * stats.time_utilization() / params_.max_time_utilization;
+
+ // Update the ideal speed setting, to be used for the next frame. An
+ // exponentially-decaying weighted average is used here to smooth-out noise.
+ // The weight is based on the duration of the frame that was encoded.
+ constexpr Clock::duration kDecayHalfLife = milliseconds(120);
+ const double ticks = stats.frame_duration.count();
+ const double weight = ticks / (ticks + kDecayHalfLife.count());
+ ideal_speed_setting_ =
+ weight * perfect_speed + (1.0 - weight) * ideal_speed_setting_;
+ OSP_DCHECK(std::isfinite(ideal_speed_setting_));
+}
+
+void StreamingVp8Encoder::SendEncodedFrame(WorkUnitWithResults results) {
+ OSP_DCHECK(main_task_runner_->IsRunningOnTaskRunner());
+
+ EncodedFrame frame;
+ frame.frame_id = sender_->GetNextFrameId();
+ if (results.is_key_frame) {
+ frame.dependency = EncodedFrame::KEY_FRAME;
+ frame.referenced_frame_id = frame.frame_id;
+ } else {
+ frame.dependency = EncodedFrame::DEPENDS_ON_ANOTHER;
+ frame.referenced_frame_id = frame.frame_id - 1;
+ }
+ frame.rtp_timestamp = results.rtp_timestamp;
+ frame.reference_time = results.reference_time;
+ frame.data = absl::Span<uint8_t>(results.payload);
+
+ if (sender_->EnqueueFrame(frame) != Sender::OK) {
+ // Since the frame will not be sent, the encoder's frame dependency chain
+ // has been broken. Force a key frame for the next frame.
+ std::unique_lock<std::mutex> lock(mutex_);
+ needs_key_frame_ = true;
+ }
+
+ if (results.stats_callback) {
+ results.stats.frame_id = frame.frame_id;
+ results.stats_callback(results.stats);
+ }
+}
+
+namespace {
+void CopyPlane(const uint8_t* src,
+ int src_stride,
+ int num_rows,
+ uint8_t* dst,
+ int dst_stride) {
+ if (src_stride == dst_stride) {
+ memcpy(dst, src, src_stride * num_rows);
+ return;
+ }
+ const int bytes_per_row = std::min(src_stride, dst_stride);
+ while (--num_rows >= 0) {
+ memcpy(dst, src, bytes_per_row);
+ dst += dst_stride;
+ src += src_stride;
+ }
+}
+} // namespace
+
+// static
+StreamingVp8Encoder::VpxImageUniquePtr StreamingVp8Encoder::CloneAsVpxImage(
+ const VideoFrame& frame) {
+ OSP_DCHECK_GE(frame.width, 0);
+ OSP_DCHECK_GE(frame.height, 0);
+ OSP_DCHECK_GE(frame.yuv_strides[0], 0);
+ OSP_DCHECK_GE(frame.yuv_strides[1], 0);
+ OSP_DCHECK_GE(frame.yuv_strides[2], 0);
+
+ constexpr int kAlignment = 32;
+ VpxImageUniquePtr image(vpx_img_alloc(nullptr, VPX_IMG_FMT_I420, frame.width,
+ frame.height, kAlignment));
+ OSP_CHECK(image);
+
+ CopyPlane(frame.yuv_planes[0], frame.yuv_strides[0], frame.height,
+ image->planes[VPX_PLANE_Y], image->stride[VPX_PLANE_Y]);
+ CopyPlane(frame.yuv_planes[1], frame.yuv_strides[1], (frame.height + 1) / 2,
+ image->planes[VPX_PLANE_U], image->stride[VPX_PLANE_U]);
+ CopyPlane(frame.yuv_planes[2], frame.yuv_strides[2], (frame.height + 1) / 2,
+ image->planes[VPX_PLANE_V], image->stride[VPX_PLANE_V]);
+
+ return image;
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.h b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.h
new file mode 100644
index 00000000000..1c64cafc5b0
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.h
@@ -0,0 +1,302 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_STANDALONE_SENDER_STREAMING_VP8_ENCODER_H_
+#define CAST_STANDALONE_SENDER_STREAMING_VP8_ENCODER_H_
+
+#include <vpx/vpx_encoder.h>
+#include <vpx/vpx_image.h>
+
+#include <algorithm>
+#include <condition_variable> // NOLINT
+#include <functional>
+#include <memory>
+#include <mutex> // NOLINT
+#include <queue>
+#include <thread> // NOLINT
+#include <vector>
+
+#include "absl/base/thread_annotations.h"
+#include "cast/streaming/frame_id.h"
+#include "cast/streaming/rtp_time.h"
+#include "platform/api/task_runner.h"
+#include "platform/api/time.h"
+
+namespace openscreen {
+
+class TaskRunner;
+
+namespace cast {
+
+class Sender;
+
+// Uses libvpx to encode VP8 video and streams it to a Sender. Includes
+// extensive logic for fine-tuning the encoder parameters in real-time, to
+// provide the best quality results given external, uncontrollable factors:
+// CPU/network availability, and the complexity of the video frame content.
+//
+// Internally, a separate encode thread is created and used to prevent blocking
+// the main thread while frames are being encoded. All public API methods are
+// assumed to be called on the same sequence/thread as the main TaskRunner
+// (injected via the constructor).
+//
+// Usage:
+//
+// 1. EncodeAndSend() is used to queue-up video frames for encoding and sending,
+// which will be done on a best-effort basis.
+//
+// 2. The client is expected to call SetTargetBitrate() frequently based on its
+// own bandwidth estimates and congestion control logic. In addition, a client
+// may provide a callback for each frame's encode statistics, which can be used
+// to further optimize the user experience. For example, the stats can be used
+// as a signal to reduce the data volume (i.e., resolution and/or frame rate)
+// coming from the video capture source.
+class StreamingVp8Encoder {
+ public:
+ // Configurable parameters passed to the StreamingVp8Encoder constructor.
+ struct Parameters {
+ // Number of threads to parallelize frame encoding. This should be set based
+ // on the number of CPU cores available for encoding, but no more than 8.
+ int num_encode_threads =
+ std::min(std::max<int>(std::thread::hardware_concurrency(), 1), 8);
+
+ // Best-quality quantizer (lower is better quality). Range: [0,63]
+ int min_quantizer = 4;
+
+ // Worst-quality quantizer (lower is better quality). Range: [0,63]
+ int max_quantizer = 63;
+
+ // Worst-quality quantizer to use when the CPU is extremely constrained.
+ // Range: [min_quantizer,max_quantizer]
+ int max_cpu_saver_quantizer = 25;
+
+ // Maximum amount of wall-time a frame's encode can take, relative to the
+ // frame's duration, before the CPU-saver logic is activated. The default
+ // (70%) is appropriate for systems with four or more cores, but should be
+ // reduced (e.g., 50%) for systems with fewer than three cores.
+ //
+ // Example: For 30 FPS (continuous) video, the frame duration is ~33.3ms,
+ // and a value of 0.5 here would mean that the CPU-saver logic starts
+ // sacrificing quality when frame encodes start taking longer than ~16.7ms.
+ double max_time_utilization = 0.7;
+ };
+
+ // Represents an input VideoFrame, passed to EncodeAndSend().
+ struct VideoFrame {
+ // Image width and height.
+ int width;
+ int height;
+
+ // I420 format image pointers and row strides (the number of bytes between
+ // the start of successive rows). The pointers only need to remain valid
+ // until the EncodeAndSend() call returns.
+ const uint8_t* yuv_planes[3];
+ int yuv_strides[3];
+
+ // How long this frame will be held before the next frame will be displayed,
+ // or zero if unknown. The frame duration is passed to the VP8 codec,
+ // affecting a number of important behaviors, including: per-frame
+ // bandwidth, CPU time spent encoding, temporal quality trade-offs, and
+ // key/golden/alt-ref frame generation intervals.
+ Clock::duration duration;
+ };
+
+ // Performance statistics for a single frame's encode.
+ //
+ // For full details on how to use these stats in an end-to-end system, see:
+ // https://www.chromium.org/developers/design-documents/
+ // auto-throttled-screen-capture-and-mirroring
+ // and https://source.chromium.org/chromium/chromium/src/+/master:
+ // media/cast/sender/performance_metrics_overlay.h
+ struct Stats {
+ // The Cast Streaming ID that was assigned to the frame.
+ FrameId frame_id;
+
+ // The RTP timestamp of the frame.
+ RtpTimeTicks rtp_timestamp;
+
+ // How long the frame took to encode. This is wall time, not CPU time or
+ // some other load metric.
+ Clock::duration encode_wall_time;
+
+ // The frame's predicted duration; or, the actual duration if it was
+ // provided in the VideoFrame.
+ Clock::duration frame_duration;
+
+ // The encoded frame's size in bytes.
+ int encoded_size;
+
+ // The average size of an encoded frame in bytes, having this
+ // |frame_duration| and current target bitrate.
+ double target_size;
+
+ // The actual quantizer the VP8 encoder used, in the range [0,63].
+ int quantizer;
+
+ // The "hindsight" quantizer value that would have produced the best quality
+ // encoding of the frame at the current target bitrate. The nominal range is
+ // [0.0,63.0]. If it is larger than 63.0, then it was impossible for VP8 to
+ // encode the frame within the current target bitrate (e.g., too much
+ // "entropy" in the image, or too low a target bitrate).
+ double perfect_quantizer;
+
+ // Utilization feedback metrics. The nominal range for each of these is
+ // [0.0,1.0] where 1.0 means "the entire budget available for the frame was
+ // exhausted." Going above 1.0 is okay for one or a few frames, since it's
+ // the average over many frames that matters before the system is considered
+ // "redlining."
+ //
+ // The max of these three provides an overall utilization control signal.
+ // The usual approach is for upstream control logic to increase/decrease the
+ // data volume (e.g., video resolution and/or frame rate) to maintain a good
+ // target point.
+ double time_utilization() const {
+ return static_cast<double>(encode_wall_time.count()) /
+ frame_duration.count();
+ }
+ double space_utilization() const { return encoded_size / target_size; }
+ double entropy_utilization() const {
+ return perfect_quantizer / kMaxQuantizer;
+ }
+ };
+
+ StreamingVp8Encoder(const Parameters& params,
+ TaskRunner* task_runner,
+ Sender* sender);
+
+ ~StreamingVp8Encoder();
+
+ // Get/Set the target bitrate. This may be changed at any time, as frequently
+ // as desired, and it will take effect internally as soon as possible.
+ int GetTargetBitrate() const;
+ void SetTargetBitrate(int new_bitrate);
+
+ // Encode |frame| using the VP8 encoder, assemble an EncodedFrame, and enqueue
+ // into the Sender. The frame may be dropped if too many frames are in-flight.
+ // If provided, the |stats_callback| is run after the frame is enqueued in the
+ // Sender (via the main TaskRunner).
+ void EncodeAndSend(const VideoFrame& frame,
+ Clock::time_point reference_time,
+ std::function<void(Stats)> stats_callback);
+
+ static constexpr int kMinQuantizer = 0;
+ static constexpr int kMaxQuantizer = 63;
+
+ private:
+ // Syntactic convenience to wrap the vpx_image_t alloc/free API in a smart
+ // pointer.
+ struct VpxImageDeleter {
+ void operator()(vpx_image_t* ptr) const { vpx_img_free(ptr); }
+ };
+ using VpxImageUniquePtr = std::unique_ptr<vpx_image_t, VpxImageDeleter>;
+
+ // Represents the state of one frame encode. This is created in
+ // EncodeAndSend(), and passed to the encode thread via the |encode_queue_|.
+ struct WorkUnit {
+ VpxImageUniquePtr image;
+ Clock::duration duration;
+ Clock::time_point reference_time;
+ RtpTimeTicks rtp_timestamp;
+ std::function<void(Stats)> stats_callback;
+ };
+
+ // Same as WorkUnit, but with additional fields to carry the encode results.
+ struct WorkUnitWithResults : public WorkUnit {
+ std::vector<uint8_t> payload;
+ bool is_key_frame;
+ Stats stats;
+ };
+
+ bool is_encoder_initialized() const { return config_.g_threads != 0; }
+
+ // Destroys the VP8 encoder context if it has been initialized.
+ void DestroyEncoder();
+
+ // The procedure for the |encode_thread_| that loops, processing work units
+ // from the |encode_queue_| by calling Encode() until it's time to end the
+ // thread.
+ void ProcessWorkUnitsUntilTimeToQuit();
+
+ // If the |encoder_| is live, attempt reconfiguration to allow it to encode
+ // frames at a new frame size, target bitrate, or "CPU encoding speed." If
+ // reconfiguration is not possible, destroy the existing instance and
+ // re-create a new |encoder_| instance.
+ void PrepareEncoder(int width, int height, int target_bitrate);
+
+ // Wraps the complex libvpx vpx_codec_encode() call using inputs from
+ // |work_unit| and populating results there.
+ void EncodeFrame(bool force_key_frame, WorkUnitWithResults* work_unit);
+
+ // Computes and populates |work_unit.stats| after the last call to
+ // EncodeFrame().
+ void ComputeFrameEncodeStats(Clock::duration encode_wall_time,
+ int target_bitrate,
+ WorkUnitWithResults* work_unit);
+
+ // Updates the |ideal_speed_setting_|, to take effect with the next frame
+ // encode, based on the given performance |stats|.
+ void UpdateSpeedSettingForNextFrame(const Stats& stats);
+
+ // Assembles and enqueues an EncodedFrame with the Sender on the main thread.
+ void SendEncodedFrame(WorkUnitWithResults results);
+
+ // Allocates a vpx_image_t and copies the content from |frame| to it.
+ static VpxImageUniquePtr CloneAsVpxImage(const VideoFrame& frame);
+
+ const Parameters params_;
+ TaskRunner* const main_task_runner_;
+ Sender* const sender_;
+
+ // The reference time of the first frame passed to EncodeAndSend().
+ Clock::time_point start_time_ = Clock::time_point::min();
+
+ // The RTP timestamp of the last frame that was pushed into the
+ // |encode_queue_| by EncodeAndSend(). This is used to check whether
+ // timestamps are monotonically increasing.
+ RtpTimeTicks last_enqueued_rtp_timestamp_;
+
+ // Guards a few members shared by both the main and encode threads.
+ std::mutex mutex_;
+
+ // Used by the encode thread to sleep until more work is available.
+ std::condition_variable cv_ ABSL_GUARDED_BY(mutex_);
+
+ // These encode parameters not passed in the WorkUnit struct because it is
+ // desirable for them to be applied as soon as possible, with the very next
+ // WorkUnit popped from the |encode_queue_| on the encode thread, and not to
+ // wait until some later WorkUnit is processed.
+ bool needs_key_frame_ ABSL_GUARDED_BY(mutex_) = true;
+ int target_bitrate_ ABSL_GUARDED_BY(mutex_) = 2 << 20; // Default: 2 Mbps.
+
+ // The queue of frame encodes. The size of this queue is implicitly bounded by
+ // EncodeAndSend(), where it checks for the total in-flight media duration and
+ // maybe drops a frame.
+ std::queue<WorkUnit> encode_queue_ ABSL_GUARDED_BY(mutex_);
+
+ // Current VP8 encoder configuration. Most of the fields are unchanging, and
+ // are populated in the ctor; but thereafter, only the encode thread accesses
+ // this struct.
+ //
+ // The speed setting is controlled via a separate libvpx API (see members
+ // below).
+ vpx_codec_enc_cfg_t config_{};
+
+ // These represent the magnitude of the VP8 speed setting, where larger values
+ // (i.e., faster speed) request less CPU usage but will provide lower video
+ // quality. Only the encode thread accesses these.
+ double ideal_speed_setting_; // A time-weighted average, from measurements.
+ int current_speed_setting_; // Current |encoder_| speed setting.
+
+ // libvpx VP8 encoder instance. Only the encode thread accesses this.
+ vpx_codec_ctx_t encoder_;
+
+ // This member should be last in the class since the thread should not start
+ // until all above members have been initialized by the constructor.
+ std::thread encode_thread_;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_STANDALONE_SENDER_STREAMING_VP8_ENCODER_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/BUILD.gn b/chromium/third_party/openscreen/src/cast/streaming/BUILD.gn
index d45ee925a78..688a453f1ce 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/BUILD.gn
+++ b/chromium/third_party/openscreen/src/cast/streaming/BUILD.gn
@@ -3,6 +3,7 @@
# found in the LICENSE file.
import("//build_overrides/build.gni")
+import("../../testing/libfuzzer/fuzzer_test.gni")
source_set("common") {
sources = [
@@ -51,10 +52,14 @@ source_set("common") {
source_set("receiver") {
sources = [
+ "answer_messages.cc",
+ "answer_messages.h",
"compound_rtcp_builder.cc",
"compound_rtcp_builder.h",
"frame_collector.cc",
"frame_collector.h",
+ "offer_messages.cc",
+ "offer_messages.h",
"packet_receive_stats_tracker.cc",
"packet_receive_stats_tracker.h",
"receiver.cc",
@@ -72,14 +77,24 @@ source_set("receiver") {
public_deps = [
":common",
]
+
+ deps = [
+ "../../util",
+ ]
}
source_set("sender") {
sources = [
+ "bandwidth_estimator.cc",
+ "bandwidth_estimator.h",
"compound_rtcp_parser.cc",
"compound_rtcp_parser.h",
"rtp_packetizer.cc",
"rtp_packetizer.h",
+ "sender.cc",
+ "sender.h",
+ "sender_packet_router.cc",
+ "sender_packet_router.h",
"sender_report_builder.cc",
"sender_report_builder.h",
]
@@ -93,21 +108,29 @@ source_set("unittests") {
testonly = true
sources = [
+ "answer_messages_unittest.cc",
+ "bandwidth_estimator_unittest.cc",
"compound_rtcp_builder_unittest.cc",
"compound_rtcp_parser_unittest.cc",
"expanded_value_base_unittest.cc",
"frame_collector_unittest.cc",
"frame_crypto_unittest.cc",
"mock_compound_rtcp_parser_client.h",
+ "mock_environment.cc",
+ "mock_environment.h",
"ntp_time_unittest.cc",
+ "offer_messages_unittest.cc",
"packet_receive_stats_tracker_unittest.cc",
"packet_util_unittest.cc",
+ "receiver_session_unittest.cc",
"receiver_unittest.cc",
"rtcp_common_unittest.cc",
"rtp_packet_parser_unittest.cc",
"rtp_packetizer_unittest.cc",
"rtp_time_unittest.cc",
+ "sender_packet_router_unittest.cc",
"sender_report_unittest.cc",
+ "sender_unittest.cc",
"ssrc_unittest.cc",
]
@@ -116,82 +139,54 @@ source_set("unittests") {
":sender",
"../../third_party/googletest:gmock",
"../../third_party/googletest:gtest",
+ "../../util",
+ ]
+}
+
+openscreen_fuzzer_test("compound_rtcp_parser_fuzzer") {
+ sources = [
+ "compound_rtcp_parser_fuzzer.cc",
+ ]
+
+ deps = [
+ ":sender",
+ "../../third_party/abseil",
+ ]
+
+ seed_corpus = "compound_rtcp_parser_fuzzer_seeds"
+
+ # Note: 1500 is approx. kMaxRtpPacketSize in rtp_defines.h.
+ libfuzzer_options = [ "max_len=1500" ]
+}
+
+openscreen_fuzzer_test("rtp_packet_parser_fuzzer") {
+ sources = [
+ "rtp_packet_parser_fuzzer.cc",
]
+
+ deps = [
+ ":receiver",
+ "../../third_party/abseil",
+ ]
+
+ seed_corpus = "rtp_packet_parser_fuzzer_seeds"
+
+ # Note: 1500 is approx. kMaxRtpPacketSize in rtp_defines.h.
+ libfuzzer_options = [ "max_len=1500" ]
}
-if (build_with_chromium) {
- import("//testing/libfuzzer/fuzzer_test.gni")
-
- fuzzer_test("compound_rtcp_parser_fuzzer") {
- sources = [
- "compound_rtcp_parser_fuzzer.cc",
- ]
-
- deps = [
- ":sender",
- "../../third_party/abseil",
- ]
-
- seed_corpus = "compound_rtcp_parser_fuzzer_seeds"
-
- # Note: 1500 is approx. kMaxRtpPacketSize in rtp_defines.h.
- libfuzzer_options = [ "max_len=1500" ]
- }
-
- fuzzer_test("rtp_packet_parser_fuzzer") {
- sources = [
- "rtp_packet_parser_fuzzer.cc",
- ]
-
- deps = [
- ":receiver",
- "../../third_party/abseil",
- ]
-
- seed_corpus = "rtp_packet_parser_fuzzer_seeds"
-
- # Note: 1500 is approx. kMaxRtpPacketSize in rtp_defines.h.
- libfuzzer_options = [ "max_len=1500" ]
- }
-
- fuzzer_test("sender_report_parser_fuzzer") {
- sources = [
- "sender_report_parser_fuzzer.cc",
- ]
-
- deps = [
- ":receiver",
- "../../third_party/abseil",
- ]
-
- seed_corpus = "sender_report_parser_fuzzer_seeds"
-
- # Note: 1500 is approx. kMaxRtpPacketSize in rtp_defines.h.
- libfuzzer_options = [ "max_len=1500" ]
- }
-} else {
- # Note: The following is commented out because, as of this writing, the LLVM
- # toolchain we pull does not include libclang_rt.fuzzer-x86_64.a, the
- # libFuzzer library *with* a main() to drive everything. Thus, the only way to
- # get things working is to specify an exact path to the fuzzer_no_main variant
- # of the library that *is* avalable, and then provide our own main(). In
- # summary, what you see below demonstrates how to get it working specifically
- # for Clang 9.0.0 on Linux x86_64. One need only modify the "libs = [...]" for
- # a different Clang, OS, or architecture.
- # if (is_clang) {
- # executable("rtp_packet_parser_fuzzer") {
- # testonly = true
- # defines = [ "NEEDS_MAIN_TO_CALL_FUZZER_DRIVER" ]
- # sources = [
- # "rtp_packet_parser_fuzzer.cc",
- # ]
- # cflags_cc = [ "-fsanitize=address,fuzzer-no-link,undefined" ]
- # ldflags = [ "-fsanitize=address,undefined" ]
- # libs = [ "$clang_base_path/lib/clang/9.0.0/lib/linux/libclang_rt.fuzzer_no_main-x86_64.a" ]
- # deps = [
- # ":receiver",
- # "../../third_party/abseil",
- # ]
- # }
- # }
+openscreen_fuzzer_test("sender_report_parser_fuzzer") {
+ sources = [
+ "sender_report_parser_fuzzer.cc",
+ ]
+
+ deps = [
+ ":receiver",
+ "../../third_party/abseil",
+ ]
+
+ seed_corpus = "sender_report_parser_fuzzer_seeds"
+
+ # Note: 1500 is approx. kMaxRtpPacketSize in rtp_defines.h.
+ libfuzzer_options = [ "max_len=1500" ]
}
diff --git a/chromium/third_party/openscreen/src/cast/streaming/DEPS b/chromium/third_party/openscreen/src/cast/streaming/DEPS
index 03a77eeaa00..de4027bf0d5 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/DEPS
+++ b/chromium/third_party/openscreen/src/cast/streaming/DEPS
@@ -6,5 +6,6 @@ include_rules = [
'+cast/common',
'+cast/receiver',
'+cast/sender',
- '+openssl'
+ '+openssl',
+ '+json',
]
diff --git a/chromium/third_party/openscreen/src/cast/streaming/answer_messages.cc b/chromium/third_party/openscreen/src/cast/streaming/answer_messages.cc
new file mode 100644
index 00000000000..bd60a58c11f
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/answer_messages.cc
@@ -0,0 +1,208 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/streaming/answer_messages.h"
+
+#include <utility>
+
+#include "absl/strings/str_cat.h"
+#include "cast/streaming/message_util.h"
+#include "platform/base/error.h"
+#include "util/logging.h"
+
+namespace openscreen {
+namespace cast {
+
+namespace {
+
+static constexpr char kMessageKeyType[] = "type";
+static constexpr char kMessageTypeAnswer[] = "ANSWER";
+
+// List of ANSWER message fields.
+static constexpr char kAnswerMessageBody[] = "answer";
+static constexpr char kResult[] = "result";
+static constexpr char kResultOk[] = "ok";
+static constexpr char kResultError[] = "error";
+static constexpr char kErrorMessageBody[] = "error";
+static constexpr char kErrorCode[] = "code";
+static constexpr char kErrorDescription[] = "description";
+
+Json::Value AspectRatioConstraintToJson(AspectRatioConstraint aspect_ratio) {
+ switch (aspect_ratio) {
+ case AspectRatioConstraint::kVariable:
+ return Json::Value("receiver");
+ case AspectRatioConstraint::kFixed:
+ default:
+ return Json::Value("sender");
+ }
+}
+
+template <typename T>
+Json::Value PrimitiveVectorToJson(const std::vector<T>& vec) {
+ Json::Value array(Json::ValueType::arrayValue);
+ array.resize(vec.size());
+
+ for (Json::Value::ArrayIndex i = 0; i < vec.size(); ++i) {
+ array[i] = Json::Value(vec[i]);
+ }
+
+ return array;
+}
+
+} // namespace
+
+ErrorOr<Json::Value> AudioConstraints::ToJson() const {
+ if (max_sample_rate <= 0 || max_channels <= 0 || min_bit_rate <= 0 ||
+ max_bit_rate < min_bit_rate) {
+ return CreateParameterError("AudioConstraints");
+ }
+
+ Json::Value root;
+ root["maxSampleRate"] = max_sample_rate;
+ root["maxChannels"] = max_channels;
+ root["minBitRate"] = min_bit_rate;
+ root["maxBitRate"] = max_bit_rate;
+ root["maxDelay"] = Json::Value::Int64(max_delay.count());
+ return root;
+}
+
+ErrorOr<Json::Value> Dimensions::ToJson() const {
+ if (width <= 0 || height <= 0 || !frame_rate.is_defined() ||
+ !frame_rate.is_positive()) {
+ return CreateParameterError("Dimensions");
+ }
+
+ Json::Value root;
+ root["width"] = width;
+ root["height"] = height;
+ root["frameRate"] = frame_rate.ToString();
+ return root;
+}
+
+ErrorOr<Json::Value> VideoConstraints::ToJson() const {
+ if (max_pixels_per_second <= 0 || min_bit_rate <= 0 ||
+ max_bit_rate < min_bit_rate || max_delay.count() <= 0) {
+ return CreateParameterError("VideoConstraints");
+ }
+
+ auto error_or_min_dim = min_dimensions.ToJson();
+ if (error_or_min_dim.is_error()) {
+ return error_or_min_dim.error();
+ }
+
+ auto error_or_max_dim = max_dimensions.ToJson();
+ if (error_or_max_dim.is_error()) {
+ return error_or_max_dim.error();
+ }
+
+ Json::Value root;
+ root["maxPixelsPerSecond"] = max_pixels_per_second;
+ root["minDimensions"] = error_or_min_dim.value();
+ root["maxDimensions"] = error_or_max_dim.value();
+ root["minBitRate"] = min_bit_rate;
+ root["maxBitRate"] = max_bit_rate;
+ root["maxDelay"] = Json::Value::Int64(max_delay.count());
+ return root;
+}
+
+ErrorOr<Json::Value> Constraints::ToJson() const {
+ auto audio_or_error = audio.ToJson();
+ if (audio_or_error.is_error()) {
+ return audio_or_error.error();
+ }
+
+ auto video_or_error = video.ToJson();
+ if (video_or_error.is_error()) {
+ return video_or_error.error();
+ }
+
+ Json::Value root;
+ root["audio"] = audio_or_error.value();
+ root["video"] = video_or_error.value();
+ return root;
+}
+
+ErrorOr<Json::Value> DisplayDescription::ToJson() const {
+ if (aspect_ratio.width < 1 || aspect_ratio.height < 1) {
+ return CreateParameterError("DisplayDescription");
+ }
+
+ auto dimensions_or_error = dimensions.ToJson();
+ if (dimensions_or_error.is_error()) {
+ return dimensions_or_error.error();
+ }
+
+ Json::Value root;
+ root["dimensions"] = dimensions_or_error.value();
+ root["aspectRatio"] =
+ absl::StrCat(aspect_ratio.width, ":", aspect_ratio.height);
+ root["scaling"] = AspectRatioConstraintToJson(aspect_ratio_constraint);
+ return root;
+}
+
+ErrorOr<Json::Value> Answer::ToJson() const {
+ if (udp_port <= 0 || udp_port > 65535) {
+ return CreateParameterError("Answer - UDP Port number");
+ }
+
+ Json::Value root;
+ if (constraints) {
+ auto constraints_or_error = constraints.value().ToJson();
+ if (constraints_or_error.is_error()) {
+ return constraints_or_error.error();
+ }
+ root["constraints"] = constraints_or_error.value();
+ }
+
+ if (display) {
+ auto display_or_error = display.value().ToJson();
+ if (display_or_error.is_error()) {
+ return display_or_error.error();
+ }
+ root["display"] = display_or_error.value();
+ }
+
+ root["castMode"] = cast_mode.ToString();
+ root["udpPort"] = udp_port;
+ root["receiverGetStatus"] = supports_wifi_status_reporting;
+ root["sendIndexes"] = PrimitiveVectorToJson(send_indexes);
+ root["ssrcs"] = PrimitiveVectorToJson(ssrcs);
+ if (!receiver_rtcp_event_log.empty()) {
+ root["receiverRtcpEventLog"] =
+ PrimitiveVectorToJson(receiver_rtcp_event_log);
+ }
+ if (!receiver_rtcp_dscp.empty()) {
+ root["receiverRtcpDscp"] = PrimitiveVectorToJson(receiver_rtcp_dscp);
+ }
+ if (!rtp_extensions.empty()) {
+ root["rtpExtensions"] = PrimitiveVectorToJson(rtp_extensions);
+ }
+ return root;
+}
+
+Json::Value Answer::ToAnswerMessage() const {
+ auto json_or_error = ToJson();
+ if (json_or_error.is_error()) {
+ return CreateInvalidAnswer(json_or_error.error());
+ }
+
+ Json::Value message_root;
+ message_root[kMessageKeyType] = kMessageTypeAnswer;
+ message_root[kAnswerMessageBody] = std::move(json_or_error.value());
+ message_root[kResult] = kResultOk;
+ return message_root;
+}
+
+Json::Value CreateInvalidAnswer(Error error) {
+ Json::Value message_root;
+ message_root[kMessageKeyType] = kMessageTypeAnswer;
+ message_root[kResult] = kResultError;
+ message_root[kErrorMessageBody][kErrorCode] = static_cast<int>(error.code());
+ message_root[kErrorMessageBody][kErrorDescription] = error.message();
+
+ return message_root;
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/answer_messages.h b/chromium/third_party/openscreen/src/cast/streaming/answer_messages.h
new file mode 100644
index 00000000000..60b9a49479a
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/answer_messages.h
@@ -0,0 +1,119 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_STREAMING_ANSWER_MESSAGES_H_
+#define CAST_STREAMING_ANSWER_MESSAGES_H_
+
+#include <array>
+#include <chrono> // NOLINT
+#include <cstdint>
+#include <initializer_list>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "cast/streaming/offer_messages.h"
+#include "cast/streaming/ssrc.h"
+#include "json/value.h"
+#include "platform/base/error.h"
+#include "util/simple_fraction.h"
+
+namespace openscreen {
+namespace cast {
+
+struct AudioConstraints {
+ int max_sample_rate = 0;
+ int max_channels = 0;
+ // Technically optional, sender will assume 32kbps if omitted.
+ int min_bit_rate = 0;
+ int max_bit_rate = 0;
+ std::chrono::milliseconds max_delay = {};
+
+ ErrorOr<Json::Value> ToJson() const;
+};
+
+struct Dimensions {
+ int width = 0;
+ int height = 0;
+ SimpleFraction frame_rate;
+
+ ErrorOr<Json::Value> ToJson() const;
+};
+
+struct VideoConstraints {
+ double max_pixels_per_second = {};
+ Dimensions min_dimensions = {};
+ Dimensions max_dimensions = {};
+ // Technically optional, sender will assume 300kbps if omitted.
+ int min_bit_rate = 0;
+ int max_bit_rate = 0;
+ std::chrono::milliseconds max_delay = {};
+
+ ErrorOr<Json::Value> ToJson() const;
+};
+
+struct Constraints {
+ AudioConstraints audio;
+ VideoConstraints video;
+
+ ErrorOr<Json::Value> ToJson() const;
+};
+
+// Decides whether the Sender scales and letterboxes content to 16:9, or if
+// it may send video frames of any arbitrary size and the Receiver must
+// handle the presentation details.
+enum class AspectRatioConstraint : uint8_t { kVariable = 0, kFixed };
+
+struct AspectRatio {
+ int width = 0;
+ int height = 0;
+};
+
+struct DisplayDescription {
+ // May exceed, be the same, or less than those mentioned in the
+ // video constraints.
+ Dimensions dimensions;
+ AspectRatio aspect_ratio = {};
+ AspectRatioConstraint aspect_ratio_constraint = {};
+
+ ErrorOr<Json::Value> ToJson() const;
+};
+
+struct Answer {
+ CastMode cast_mode = {};
+ int udp_port = 0;
+ std::vector<int> send_indexes;
+ std::vector<Ssrc> ssrcs;
+
+ // Constraints and display descriptions are optional fields, and maybe null in
+ // the valid case.
+ absl::optional<Constraints> constraints;
+ absl::optional<DisplayDescription> display;
+ std::vector<int> receiver_rtcp_event_log;
+ std::vector<int> receiver_rtcp_dscp;
+ bool supports_wifi_status_reporting = false;
+
+ // RTP extensions should be empty, but not null.
+ std::vector<std::string> rtp_extensions = {};
+
+ // ToJson performs a standard serialization, returning an error if this
+ // instance failed to serialize properly.
+ ErrorOr<Json::Value> ToJson() const;
+
+ // In constrast to ToJson, ToAnswerMessage performs a successful serialization
+ // even if the answer object is malformed, by complying to the spec's
+ // error answer message format in this case.
+ Json::Value ToAnswerMessage() const;
+};
+
+// Helper method that creates an invalid Answer response. Exposed publicly
+// here as it is called in ToAnswerMessage(), but can also be called by
+// the receiver session.
+Json::Value CreateInvalidAnswer(Error error);
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_STREAMING_ANSWER_MESSAGES_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/answer_messages_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/answer_messages_unittest.cc
new file mode 100644
index 00000000000..d1c708281d9
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/answer_messages_unittest.cc
@@ -0,0 +1,180 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/streaming/answer_messages.h"
+
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "util/json/json_serialization.h"
+
+namespace openscreen {
+namespace cast {
+
+namespace {
+
+const Answer kValidAnswer{
+ CastMode{CastMode::Type::kMirroring},
+ 1234, // udp_port
+ std::vector<int>{1, 2, 3}, // send_indexes
+ std::vector<Ssrc>{123, 456}, // ssrcs
+ Constraints{
+ AudioConstraints{
+ 96000, // max_sample_rate
+ 7, // max_channels
+ 32000, // min_bit_rate
+ 96000, // max_bit_rate
+ std::chrono::milliseconds(2000) // max_delay
+ }, // audio
+ VideoConstraints{
+ 40000.0, // max_pixels_per_second
+ Dimensions{
+ 320, // width
+ 480, // height
+ SimpleFraction{15000, 101} // frame_rate
+ }, // min_dimensions
+ Dimensions{
+ 1920, // width
+ 1080, // height
+ SimpleFraction{288, 2} // frame_rate
+ },
+ 300000, // min_bit_rate
+ 144000000, // max_bit_rate
+ std::chrono::milliseconds(3000) // max_delay
+ } // video
+ }, // constraints
+ DisplayDescription{
+ Dimensions{
+ 640, // width
+ 480, // height
+ SimpleFraction{30, 1} // frame_rate
+ },
+ AspectRatio{16, 9}, // aspect_ratio
+ AspectRatioConstraint::kFixed, // scaling
+ },
+ std::vector<int>{7, 8, 9}, // receiver_rtcp_event_log
+ std::vector<int>{11, 12, 13}, // receiver_rtcp_dscp
+ true, // receiver_get_status
+ std::vector<std::string>{"foo", "bar"} // rtp_extensions
+};
+
+} // anonymous namespace
+
+TEST(AnswerMessagesTest, ProperlyPopulatedAnswerSerializesProperly) {
+ auto value_or_error = kValidAnswer.ToJson();
+ EXPECT_TRUE(value_or_error.is_value());
+
+ Json::Value root = std::move(value_or_error.value());
+ EXPECT_EQ(root["castMode"], "mirroring");
+ EXPECT_EQ(root["udpPort"], 1234);
+
+ Json::Value sendIndexes = std::move(root["sendIndexes"]);
+ EXPECT_EQ(sendIndexes.type(), Json::ValueType::arrayValue);
+ EXPECT_EQ(sendIndexes[0], 1);
+ EXPECT_EQ(sendIndexes[1], 2);
+ EXPECT_EQ(sendIndexes[2], 3);
+
+ Json::Value ssrcs = std::move(root["ssrcs"]);
+ EXPECT_EQ(ssrcs.type(), Json::ValueType::arrayValue);
+ EXPECT_EQ(ssrcs[0], 123u);
+ EXPECT_EQ(ssrcs[1], 456u);
+
+ Json::Value constraints = std::move(root["constraints"]);
+ Json::Value audio = std::move(constraints["audio"]);
+ EXPECT_EQ(audio.type(), Json::ValueType::objectValue);
+ EXPECT_EQ(audio["maxSampleRate"], 96000);
+ EXPECT_EQ(audio["maxChannels"], 7);
+ EXPECT_EQ(audio["minBitRate"], 32000);
+ EXPECT_EQ(audio["maxBitRate"], 96000);
+ EXPECT_EQ(audio["maxDelay"], 2000);
+
+ Json::Value video = std::move(constraints["video"]);
+ EXPECT_EQ(video.type(), Json::ValueType::objectValue);
+ EXPECT_EQ(video["maxPixelsPerSecond"], 40000.0);
+ EXPECT_EQ(video["minBitRate"], 300000);
+ EXPECT_EQ(video["maxBitRate"], 144000000);
+ EXPECT_EQ(video["maxDelay"], 3000);
+
+ Json::Value min_dimensions = std::move(video["minDimensions"]);
+ EXPECT_EQ(min_dimensions.type(), Json::ValueType::objectValue);
+ EXPECT_EQ(min_dimensions["width"], 320);
+ EXPECT_EQ(min_dimensions["height"], 480);
+ EXPECT_EQ(min_dimensions["frameRate"], "15000/101");
+
+ Json::Value max_dimensions = std::move(video["maxDimensions"]);
+ EXPECT_EQ(max_dimensions.type(), Json::ValueType::objectValue);
+ EXPECT_EQ(max_dimensions["width"], 1920);
+ EXPECT_EQ(max_dimensions["height"], 1080);
+ EXPECT_EQ(max_dimensions["frameRate"], "288/2");
+
+ Json::Value display = std::move(root["display"]);
+ EXPECT_EQ(display.type(), Json::ValueType::objectValue);
+ EXPECT_EQ(display["aspectRatio"], "16:9");
+ EXPECT_EQ(display["scaling"], "sender");
+
+ Json::Value dimensions = std::move(display["dimensions"]);
+ EXPECT_EQ(dimensions.type(), Json::ValueType::objectValue);
+ EXPECT_EQ(dimensions["width"], 640);
+ EXPECT_EQ(dimensions["height"], 480);
+ EXPECT_EQ(dimensions["frameRate"], "30");
+
+ Json::Value receiver_rtcp_event_log = std::move(root["receiverRtcpEventLog"]);
+ EXPECT_EQ(receiver_rtcp_event_log.type(), Json::ValueType::arrayValue);
+ EXPECT_EQ(receiver_rtcp_event_log[0], 7);
+ EXPECT_EQ(receiver_rtcp_event_log[1], 8);
+ EXPECT_EQ(receiver_rtcp_event_log[2], 9);
+
+ Json::Value receiver_rtcp_dscp = std::move(root["receiverRtcpDscp"]);
+ EXPECT_EQ(receiver_rtcp_dscp.type(), Json::ValueType::arrayValue);
+ EXPECT_EQ(receiver_rtcp_dscp[0], 11);
+ EXPECT_EQ(receiver_rtcp_dscp[1], 12);
+ EXPECT_EQ(receiver_rtcp_dscp[2], 13);
+
+ EXPECT_EQ(root["receiverGetStatus"], true);
+
+ Json::Value rtp_extensions = std::move(root["rtpExtensions"]);
+ EXPECT_EQ(rtp_extensions.type(), Json::ValueType::arrayValue);
+ EXPECT_EQ(rtp_extensions[0], "foo");
+ EXPECT_EQ(rtp_extensions[1], "bar");
+}
+
+TEST(AnswerMessagesTest, InvalidDimensionsCauseError) {
+ Answer invalid_dimensions = kValidAnswer;
+ invalid_dimensions.display.value().dimensions.width = -1;
+ auto value_or_error = invalid_dimensions.ToJson();
+ EXPECT_TRUE(value_or_error.is_error());
+}
+
+TEST(AnswerMessagesTest, InvalidAudioConstraintsCauseError) {
+ Answer invalid_audio = kValidAnswer;
+ invalid_audio.constraints.value().audio.max_bit_rate =
+ invalid_audio.constraints.value().audio.min_bit_rate - 1;
+ auto value_or_error = invalid_audio.ToJson();
+ EXPECT_TRUE(value_or_error.is_error());
+}
+
+TEST(AnswerMessagesTest, InvalidVideoConstraintsCauseError) {
+ Answer invalid_video = kValidAnswer;
+ invalid_video.constraints.value().video.max_pixels_per_second = -1.0;
+ auto value_or_error = invalid_video.ToJson();
+ EXPECT_TRUE(value_or_error.is_error());
+}
+
+TEST(AnswerMessagesTest, InvalidDisplayDescriptionsCauseError) {
+ Answer invalid_display = kValidAnswer;
+ invalid_display.display.value().aspect_ratio = {0, 0};
+ auto value_or_error = invalid_display.ToJson();
+ EXPECT_TRUE(value_or_error.is_error());
+}
+
+TEST(AnswerMessagesTest, InvalidUdpPortsCauseError) {
+ Answer invalid_port = kValidAnswer;
+ invalid_port.udp_port = 65536;
+ auto value_or_error = invalid_port.ToJson();
+ EXPECT_TRUE(value_or_error.is_error());
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.cc b/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.cc
new file mode 100644
index 00000000000..e3386b0071a
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.cc
@@ -0,0 +1,157 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/streaming/bandwidth_estimator.h"
+
+#include <algorithm>
+
+#include "util/logging.h"
+#include "util/saturate_cast.h"
+
+namespace openscreen {
+namespace cast {
+
+using openscreen::operator<<; // For std::chrono::duration logging.
+
+namespace {
+
+// Converts units from |bytes| per |time_window| number of Clock ticks into
+// bits-per-second.
+int ToClampedBitsPerSecond(int32_t bytes, Clock::duration time_window) {
+ OSP_DCHECK_GT(time_window, Clock::duration::zero());
+
+ // Divide |bytes| by |time_window| and scale the units to bits per second.
+ constexpr int64_t kBitsPerByte = 8;
+ constexpr int64_t kClockTicksPerSecond =
+ std::chrono::duration_cast<Clock::duration>(std::chrono::seconds(1))
+ .count();
+ const int64_t bits = bytes * kBitsPerByte;
+ const int64_t bits_per_second =
+ (bits * kClockTicksPerSecond) / time_window.count();
+ return saturate_cast<int>(bits_per_second);
+}
+
+} // namespace
+
+BandwidthEstimator::BandwidthEstimator(int max_packets_per_timeslice,
+ Clock::duration timeslice_duration,
+ Clock::time_point start_time)
+ : max_packets_per_history_window_(max_packets_per_timeslice *
+ kNumTimeslices),
+ history_window_(timeslice_duration * kNumTimeslices),
+ burst_history_(timeslice_duration, start_time),
+ feedback_history_(timeslice_duration, start_time) {
+ OSP_DCHECK_GT(max_packets_per_timeslice, 0);
+ OSP_DCHECK_GT(timeslice_duration, Clock::duration::zero());
+}
+
+BandwidthEstimator::~BandwidthEstimator() = default;
+
+void BandwidthEstimator::OnBurstComplete(int num_packets_sent,
+ Clock::time_point when) {
+ OSP_DCHECK_GE(num_packets_sent, 0);
+ burst_history_.Accumulate(num_packets_sent, when);
+}
+
+void BandwidthEstimator::OnRtcpReceived(
+ Clock::time_point arrival_time,
+ Clock::duration estimated_round_trip_time) {
+ OSP_DCHECK_GE(estimated_round_trip_time, Clock::duration::zero());
+ // Move forward the feedback history tracking timeline to include the latest
+ // moment a packet could have left the Sender.
+ feedback_history_.AdvanceToIncludeTime(arrival_time -
+ estimated_round_trip_time);
+}
+
+void BandwidthEstimator::OnPayloadReceived(
+ int payload_bytes_acknowledged,
+ Clock::time_point ack_arrival_time,
+ Clock::duration estimated_round_trip_time) {
+ OSP_DCHECK_GE(payload_bytes_acknowledged, 0);
+ OSP_DCHECK_GE(estimated_round_trip_time, Clock::duration::zero());
+ // Track the bytes in terms of when the last packet was sent.
+ feedback_history_.Accumulate(payload_bytes_acknowledged,
+ ack_arrival_time - estimated_round_trip_time);
+}
+
+int BandwidthEstimator::ComputeNetworkBandwidth() const {
+ // Determine whether the |burst_history_| time window overlaps with the
+ // |feedback_history_| time window by at least half. The time windows don't
+ // have to overlap entirely because the calculations are averaging all the
+ // measurements (i.e., recent typical behavior). Though, they should overlap
+ // by "enough" so that the measurements correlate "enough."
+ const Clock::time_point overlap_begin =
+ std::max(burst_history_.begin_time(), feedback_history_.begin_time());
+ const Clock::time_point overlap_end =
+ std::min(burst_history_.end_time(), feedback_history_.end_time());
+ if ((overlap_end - overlap_begin) < (history_window_ / 2)) {
+ return 0;
+ }
+
+ const int32_t num_packets_transmitted = burst_history_.Sum();
+ if (num_packets_transmitted <= 0) {
+ // Cannot estimate because there have been no transmissions recently.
+ return 0;
+ }
+ const Clock::duration transmit_duration = history_window_ *
+ num_packets_transmitted /
+ max_packets_per_history_window_;
+ const int32_t num_bytes_received = feedback_history_.Sum();
+ return ToClampedBitsPerSecond(num_bytes_received, transmit_duration);
+}
+
+// static
+constexpr int BandwidthEstimator::kNumTimeslices;
+
+BandwidthEstimator::FlowTracker::FlowTracker(Clock::duration timeslice_duration,
+ Clock::time_point begin_time)
+ : timeslice_duration_(timeslice_duration), begin_time_(begin_time) {}
+
+BandwidthEstimator::FlowTracker::~FlowTracker() = default;
+
+void BandwidthEstimator::FlowTracker::AdvanceToIncludeTime(
+ Clock::time_point until) {
+ if (until < end_time()) {
+ return; // Not advancing.
+ }
+
+ // Step forward in time, at timeslice granularity.
+ const int64_t num_periods = 1 + (until - end_time()) / timeslice_duration_;
+ begin_time_ += num_periods * timeslice_duration_;
+
+ // Shift the ring elements, discarding N oldest timeslices, and creating N new
+ // ones initialized to zero.
+ const int shift_count = std::min<int64_t>(num_periods, kNumTimeslices);
+ for (int i = 0; i < shift_count; ++i) {
+ history_ring_[tail_++] = 0;
+ }
+}
+
+void BandwidthEstimator::FlowTracker::Accumulate(int32_t amount,
+ Clock::time_point when) {
+ if (when < begin_time_) {
+ return; // Ignore a data point that is already too old.
+ }
+
+ AdvanceToIncludeTime(when);
+
+ // Because of the AdvanceToIncludeTime() call just made, the offset/index
+ // calculations here are guaranteed to point to a valid element in the
+ // |history_ring_|.
+ const int64_t offset_from_first = (when - begin_time_) / timeslice_duration_;
+ const index_mod_256_t ring_index = tail_ + offset_from_first;
+ int32_t& timeslice = history_ring_[ring_index];
+ timeslice = saturate_cast<int32_t>(int64_t{timeslice} + amount);
+}
+
+int32_t BandwidthEstimator::FlowTracker::Sum() const {
+ int64_t result = 0;
+ for (int32_t amount : history_ring_) {
+ result += amount;
+ }
+ return saturate_cast<int32_t>(result);
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.h b/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.h
new file mode 100644
index 00000000000..f8ce7af2088
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.h
@@ -0,0 +1,170 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_STREAMING_BANDWIDTH_ESTIMATOR_H_
+#define CAST_STREAMING_BANDWIDTH_ESTIMATOR_H_
+
+#include <stdint.h>
+
+#include <limits>
+
+#include "platform/api/time.h"
+
+namespace openscreen {
+namespace cast {
+
+// Tracks send attempts and successful receives, and then computes a total
+// network bandwith estimate.
+//
+// Two metrics are tracked by the BandwidthEstimator, over a "recent history"
+// time window:
+//
+// 1. The number of packets sent during bursts (see SenderPacketRouter for
+// explanation of what a "burst" is). These track when the network was
+// actually in-use for transmission and the magnitude of each burst. When
+// computing bandwidth, the estimator assumes the timeslices where the
+// network was not in-use could have been used to send even more bytes at
+// the same rate.
+//
+// 2. Successful receipt of payload bytes over time, or a lack thereof.
+// Packets that include acknowledgements from the Receivers are providing
+// proof of the successful receipt of payload bytes. All other packets
+// provide proof of network connectivity over time, and are used to
+// identify periods of time where nothing was received.
+//
+// The BandwidthEstimator assumes a simplified model for streaming over the
+// network. The model does not include any detailed knowledge about things like
+// protocol overhead, packet re-transmits, parasitic bufferring, network
+// reliability, etc. Instead, it automatically accounts for all such things by
+// looking at what's actually leaving the Senders and what's actually making it
+// to the Receivers.
+//
+// This simplified model does produce some known inaccuracies in the resulting
+// estimations. If no data has recently been transmitted (or been received),
+// estimations cannot be provided. If the transmission rate is near (or
+// exceeding) the network's capacity, the estimations will be very accurate. In
+// between those two extremes, the logic will tend to under-estimate the
+// network's capacity. However, those under-estimates will still be far larger
+// than the current transmission rate.
+//
+// Thus, these estimates can be used effectively as a control signal for
+// congestion control in upstream code modules. The logic computing the media's
+// encoding target bitrate should be adjusted in realtime using a TCP-like
+// congestion control algorithm:
+//
+// 1. When the estimated bitrate is less than the current encoding target
+// bitrate, aggressively and immediately decrease the encoding bitrate.
+//
+// 2. When the estimated bitrate is more than the current encoding target
+// bitrate, gradually increase the encoding bitrate (up to the maximum
+// that is reasonable for the application).
+class BandwidthEstimator {
+ public:
+ // |max_packets_per_timeslice| and |timeslice_duration| should match the burst
+ // configuration in SenderPacketRouter. |start_time| should be a recent
+ // point-in-time before the first packet is sent.
+ BandwidthEstimator(int max_packets_per_timeslice,
+ Clock::duration timeslice_duration,
+ Clock::time_point start_time);
+
+ ~BandwidthEstimator();
+
+ // Returns the duration of the fixed, recent-history time window over which
+ // data flows are being tracked.
+ Clock::duration history_window() const { return history_window_; }
+
+ // Records |when| burst-sending was active or inactive. For the active case,
+ // |num_packets_sent| should include all network packets sent, including
+ // non-payload packets (since both affect the modeled utilization/capacity).
+ // For the inactive case, this method should be called with zero for
+ // |num_packets_sent|.
+ void OnBurstComplete(int num_packets_sent, Clock::time_point when);
+
+ // Records when a RTCP packet was received. It's important for Senders to call
+ // this any time a packet comes in from the Receivers, even if no payload is
+ // being acknowledged, since the time windows of "nothing successfully
+ // received" is also important information to track.
+ void OnRtcpReceived(Clock::time_point arrival_time,
+ Clock::duration estimated_round_trip_time);
+
+ // Records that some number of payload bytes has been acknowledged (i.e.,
+ // successfully received).
+ void OnPayloadReceived(int payload_bytes_acknowledged,
+ Clock::time_point ack_arrival_time,
+ Clock::duration estimated_round_trip_time);
+
+ // Computes the current network bandwith estimate. Returns 0 if this cannot be
+ // determined due to a lack of sufficiently-recent data.
+ int ComputeNetworkBandwidth() const;
+
+ private:
+ // FlowTracker (below) manages a ring buffer of size 256. It simplifies the
+ // index calculations to use an integer data type where all arithmetic is mod
+ // 256.
+ using index_mod_256_t = uint8_t;
+ static constexpr int kNumTimeslices =
+ static_cast<int>(std::numeric_limits<index_mod_256_t>::max()) + 1;
+
+ // Tracks volume (e.g., the total number of payload bytes) over a fixed
+ // recent-history time window. The time window is divided up into a number of
+ // identical timeslices, each of which represents the total number of bytes
+ // that flowed during a certain period of time. The data is accumulated in
+ // ring buffer elements so that old data points drop-off as newer ones (that
+ // move the history window forward) are added.
+ class FlowTracker {
+ public:
+ FlowTracker(Clock::duration timeslice_duration,
+ Clock::time_point begin_time);
+ ~FlowTracker();
+
+ Clock::time_point begin_time() const { return begin_time_; }
+ Clock::time_point end_time() const {
+ return begin_time_ + timeslice_duration_ * kNumTimeslices;
+ }
+
+ // Advance the end of the time window being tracked such that the
+ // most-recent timeslice includes |until|. Too-old timeslices are dropped
+ // and new ones are initialized to a zero amount.
+ void AdvanceToIncludeTime(Clock::time_point until);
+
+ // Accumulate the given |amount| into the timeslice that includes |when|.
+ void Accumulate(int32_t amount, Clock::time_point when);
+
+ // Return the sum of all the amounts in recent history. This clamps to the
+ // valid range of int32_t, if necessary.
+ int32_t Sum() const;
+
+ private:
+ const Clock::duration timeslice_duration_;
+
+ // The beginning of the oldest timeslice in the recent-history time window,
+ // the one pointed to by |tail_|.
+ Clock::time_point begin_time_;
+
+ // A ring buffer tracking the accumulated amount for each timeslice.
+ int32_t history_ring_[kNumTimeslices]{};
+
+ // The index of the oldest timeslice in the |history_ring_|. This can also
+ // be thought of, equivalently, as the index just after the most-recent
+ // timeslice.
+ index_mod_256_t tail_ = 0;
+ };
+
+ // The maximum number of packet sends that could possibly be attempted during
+ // the recent-history time window.
+ const int max_packets_per_history_window_;
+
+ // The range of time being tracked.
+ const Clock::duration history_window_;
+
+ // History tracking for send attempts, and success feeback. These timeseries
+ // are in terms of when packets have left the Senders.
+ FlowTracker burst_history_;
+ FlowTracker feedback_history_;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_STREAMING_BANDWIDTH_ESTIMATOR_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator_unittest.cc
new file mode 100644
index 00000000000..d6d8570a0a5
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator_unittest.cc
@@ -0,0 +1,232 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/streaming/bandwidth_estimator.h"
+
+#include <limits>
+#include <random>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "platform/api/time.h"
+
+namespace openscreen {
+namespace cast {
+namespace {
+
+using std::chrono::duration_cast;
+using std::chrono::milliseconds;
+using std::chrono::seconds;
+
+using openscreen::operator<<; // For std::chrono::duration gtest pretty-print.
+
+// BandwidthEstimator configuration common to all tests.
+constexpr int kMaxPacketsPerTimeslice = 10;
+constexpr Clock::duration kTimesliceDuration = milliseconds(10);
+constexpr int kTimeslicesPerSecond = seconds(1) / kTimesliceDuration;
+
+// Use a fake, fixed start time.
+constexpr Clock::time_point kStartTime =
+ Clock::time_point() + Clock::duration(1234567890);
+
+// Range of "fuzz" to add to timestamps in BandwidthEstimatorTest::AddFuzz().
+constexpr Clock::duration kMaxFuzzOffset = milliseconds(15);
+constexpr int kFuzzLowerBoundClockTicks = (-kMaxFuzzOffset).count();
+constexpr int kFuzzUpperBoundClockTicks = kMaxFuzzOffset.count();
+
+class BandwidthEstimatorTest : public testing::Test {
+ public:
+ BandwidthEstimatorTest()
+ : estimator_(kMaxPacketsPerTimeslice, kTimesliceDuration, kStartTime) {}
+
+ BandwidthEstimator* estimator() { return &estimator_; }
+
+ // Returns |t| plus or minus |kMaxFuzzOffset|.
+ Clock::time_point AddFuzz(Clock::time_point t) {
+ return t + Clock::duration(distribution_(rand_));
+ }
+
+ private:
+ BandwidthEstimator estimator_;
+
+ // These are used to generate random values for AddFuzz().
+ static constexpr std::minstd_rand::result_type kRandSeed =
+ kStartTime.time_since_epoch().count();
+ std::minstd_rand rand_{kRandSeed};
+ std::uniform_int_distribution<int> distribution_{kFuzzLowerBoundClockTicks,
+ kFuzzUpperBoundClockTicks};
+};
+
+// Tests that, without any data, there won't be any estimates.
+TEST_F(BandwidthEstimatorTest, DoesNotEstimateWithoutAnyInput) {
+ EXPECT_EQ(0, estimator()->ComputeNetworkBandwidth());
+}
+
+// Tests the case where packets are being sent, but the Receiver hasn't provided
+// feedback (e.g., due to a network blackout).
+TEST_F(BandwidthEstimatorTest, DoesNotEstimateWithoutFeedback) {
+ Clock::time_point now = kStartTime;
+ for (int i = 0; i < 3; ++i) {
+ const Clock::time_point end = now + estimator()->history_window();
+ for (; now < end; now += kTimesliceDuration) {
+ estimator()->OnBurstComplete(i, now);
+ EXPECT_EQ(0, estimator()->ComputeNetworkBandwidth());
+ }
+ now = end;
+ }
+}
+
+// Tests the case where packets are being sent, and a connection to the Receiver
+// has been confirmed (because RTCP packets are coming in), but the Receiver has
+// not successfully received anything.
+TEST_F(BandwidthEstimatorTest, DoesNotEstimateIfNothingSuccessfullyReceived) {
+ const Clock::duration kRoundTripTime = milliseconds(1);
+
+ Clock::time_point now = kStartTime;
+ for (int i = 0; i < 3; ++i) {
+ const Clock::time_point end = now + estimator()->history_window();
+ for (; now < end; now += kTimesliceDuration) {
+ estimator()->OnBurstComplete(i, now);
+ estimator()->OnRtcpReceived(now + kRoundTripTime, kRoundTripTime);
+ EXPECT_EQ(0, estimator()->ComputeNetworkBandwidth());
+ }
+ now = end;
+ }
+}
+
+// Tests that, when the Receiver successfully receives the payload bytes at a
+// fixed rate, the network bandwidth estimates are based on the amount of time
+// the Sender spent transmitting.
+TEST_F(BandwidthEstimatorTest, EstimatesAtVariousBurstSaturations) {
+ // These determine how many packets to burst in the simulation below.
+ constexpr int kDivisors[] = {
+ 1, // Burst 100% of max possible packets.
+ 2, // Burst 50% of max possible packets.
+ 5, // Burst 20% of max possible packets.
+ };
+
+ const Clock::duration kRoundTripTime = milliseconds(1);
+
+ constexpr int kReceivedBytesPerSecond = 256000;
+ constexpr int kReceivedBytesPerTimeslice =
+ kReceivedBytesPerSecond / kTimeslicesPerSecond;
+ static_assert(kReceivedBytesPerSecond % kTimeslicesPerSecond == 0,
+ "Test expectations won't account for rounding errors.");
+
+ ASSERT_EQ(0, estimator()->ComputeNetworkBandwidth());
+
+ // Simulate bursting at various rates, and confirm the bandwidth estimate is
+ // increasing for each burst rate. The estimate should be increasing because
+ // the total time spent transmitting is decreasing (while the same number of
+ // bytes are being received).
+ Clock::time_point now = kStartTime;
+ for (int divisor : kDivisors) {
+ SCOPED_TRACE(testing::Message() << "divisor=" << divisor);
+
+ const Clock::time_point end = now + estimator()->history_window();
+ for (; now < end; now += kTimesliceDuration) {
+ estimator()->OnBurstComplete(kMaxPacketsPerTimeslice / divisor, now);
+ const Clock::time_point rtcp_arrival_time = now + kRoundTripTime;
+ estimator()->OnPayloadReceived(kReceivedBytesPerTimeslice,
+ rtcp_arrival_time, kRoundTripTime);
+ estimator()->OnRtcpReceived(rtcp_arrival_time, kRoundTripTime);
+ }
+ now = end;
+
+ const int estimate = estimator()->ComputeNetworkBandwidth();
+ EXPECT_EQ(divisor * kReceivedBytesPerSecond * CHAR_BIT, estimate);
+ }
+}
+
+// Tests that magnitude of the network round trip times, as well as random
+// variance in packet arrival times, do not have a significant effect on the
+// bandwidth estimates.
+TEST_F(BandwidthEstimatorTest, EstimatesIndependentOfFeedbackDelays) {
+ constexpr int kFactor = 2;
+ constexpr int kPacketsPerBurst = kMaxPacketsPerTimeslice / kFactor;
+ static_assert(kMaxPacketsPerTimeslice % kFactor == 0, "wanted exactly half");
+
+ constexpr milliseconds kRoundTripTimes[3] = {milliseconds(1), milliseconds(9),
+ milliseconds(42)};
+
+ constexpr int kReceivedBytesPerSecond = 2000000;
+ constexpr int kReceivedBytesPerTimeslice =
+ kReceivedBytesPerSecond / kTimeslicesPerSecond;
+
+ // An arbitrary threshold. Sources of error include anything that would place
+ // byte flows outside the history window (e.g., AddFuzz(), or the 42ms round
+ // trip time).
+ constexpr int kMaxErrorPercent = 3;
+
+ Clock::time_point now = kStartTime;
+ for (Clock::duration round_trip_time : kRoundTripTimes) {
+ SCOPED_TRACE(testing::Message()
+ << "round_trip_time=" << round_trip_time.count());
+
+ const Clock::time_point end = now + estimator()->history_window();
+ for (; now < end; now += kTimesliceDuration) {
+ estimator()->OnBurstComplete(kPacketsPerBurst, now);
+ const Clock::time_point rtcp_arrival_time =
+ AddFuzz(now + round_trip_time);
+ estimator()->OnPayloadReceived(kReceivedBytesPerTimeslice,
+ rtcp_arrival_time, round_trip_time);
+ estimator()->OnRtcpReceived(rtcp_arrival_time, round_trip_time);
+ }
+ now = end;
+
+ constexpr int kExpectedEstimate =
+ kFactor * kReceivedBytesPerSecond * CHAR_BIT;
+ constexpr int kMaxError = kExpectedEstimate * kMaxErrorPercent / 100;
+ EXPECT_NEAR(kExpectedEstimate, estimator()->ComputeNetworkBandwidth(),
+ kMaxError);
+ }
+}
+
+// Tests that both the history tracking internal to BandwidthEstimator, as well
+// as its computation of the bandwidth estimate, are both resistant to possible
+// integer overflow cases. The internal implementation always clamps to the
+// valid range of int.
+TEST_F(BandwidthEstimatorTest, ClampsEstimateToMaxInt) {
+ constexpr int kPacketsPerBurst = kMaxPacketsPerTimeslice / 5;
+ static_assert(kMaxPacketsPerTimeslice % 5 == 0, "wanted exactly 20%");
+ const Clock::duration kRoundTripTime = milliseconds(1);
+
+ int last_estimate = estimator()->ComputeNetworkBandwidth();
+ ASSERT_EQ(last_estimate, 0);
+
+ // Simulate increasing numbers of bytes received per timeslice until it
+ // reaches values near INT_MAX. Along the way, the bandwidth estimates
+ // themselves should start clamping and, because of the fuzz added to RTCP
+ // arrival times, individual buckets in BandwidthEstimator::FlowTracker will
+ // occassionally be clamped too.
+ Clock::time_point now = kStartTime;
+ for (int bytes_received_per_timeslice = 1;
+ bytes_received_per_timeslice > 0 /* not overflowed past INT_MAX */;
+ bytes_received_per_timeslice *= 2) {
+ SCOPED_TRACE(testing::Message() << "bytes_received_per_timeslice="
+ << bytes_received_per_timeslice);
+
+ const Clock::time_point end = now + estimator()->history_window() / 4;
+ for (; now < end; now += kTimesliceDuration) {
+ estimator()->OnBurstComplete(kPacketsPerBurst, now);
+ const Clock::time_point rtcp_arrival_time = AddFuzz(now + kRoundTripTime);
+ estimator()->OnPayloadReceived(bytes_received_per_timeslice,
+ rtcp_arrival_time, kRoundTripTime);
+ estimator()->OnRtcpReceived(rtcp_arrival_time, kRoundTripTime);
+ }
+ now = end;
+
+ const int estimate = estimator()->ComputeNetworkBandwidth();
+ EXPECT_LE(last_estimate, estimate);
+ last_estimate = estimate;
+ }
+
+ // Confirm there was a loop iteration at which the estimate reached INT_MAX
+ // and then stayed there for successive loop iterations.
+ EXPECT_EQ(std::numeric_limits<int>::max(), last_estimate);
+}
+
+} // namespace
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.cc b/chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.cc
index bb829578c53..92072733657 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.cc
@@ -7,13 +7,13 @@
#include <cmath>
#include "util/logging.h"
+#include "util/saturate_cast.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
-constexpr ClockDriftSmoother::Clock::time_point kNullTime =
- ClockDriftSmoother::Clock::time_point::min();
+constexpr Clock::time_point kNullTime = Clock::time_point::min();
}
ClockDriftSmoother::ClockDriftSmoother(Clock::duration time_constant)
@@ -25,15 +25,10 @@ ClockDriftSmoother::ClockDriftSmoother(Clock::duration time_constant)
ClockDriftSmoother::~ClockDriftSmoother() = default;
-ClockDriftSmoother::Clock::duration ClockDriftSmoother::Current() const {
+Clock::duration ClockDriftSmoother::Current() const {
OSP_DCHECK(last_update_time_ != kNullTime);
- const double rounded_estimate = std::round(estimated_tick_offset_);
- if (rounded_estimate < Clock::duration::min().count()) {
- return Clock::duration::min();
- } else if (rounded_estimate > Clock::duration::max().count()) {
- return Clock::duration::max();
- }
- return Clock::duration(static_cast<Clock::duration::rep>(rounded_estimate));
+ return Clock::duration(
+ rounded_saturate_cast<Clock::duration::rep>(estimated_tick_offset_));
}
void ClockDriftSmoother::Reset(Clock::time_point now,
@@ -68,5 +63,5 @@ void ClockDriftSmoother::Update(Clock::time_point now,
// static
constexpr std::chrono::seconds ClockDriftSmoother::kDefaultTimeConstant;
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.h b/chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.h
index d48d166c6ce..97fe1ec601b 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/clock_drift_smoother.h
@@ -9,16 +9,14 @@
#include "platform/api/time.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Tracks the jitter and drift between clocks, providing a smoothed offset.
// Internally, a Simple IIR filter is used to maintain a running average that
// moves at a rate based on the passage of time.
class ClockDriftSmoother {
public:
- using Clock = openscreen::platform::Clock;
-
// |time_constant| is the amount of time an impulse signal takes to decay by
// ~62.6%. Interpretation: If the value passed to several Update() calls is
// held constant for T seconds, then the running average will have moved
@@ -54,7 +52,7 @@ class ClockDriftSmoother {
double estimated_tick_offset_;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_CLOCK_DRIFT_SMOOTHER_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.cc b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.cc
index b3cba68b109..37625d051af 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.cc
@@ -14,11 +14,8 @@
#include "util/logging.h"
#include "util/std_util.h"
-using openscreen::AreElementsSortedAndUnique;
-using openscreen::platform::Clock;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
CompoundRtcpBuilder::CompoundRtcpBuilder(RtcpSession* session)
: session_(session) {
@@ -293,7 +290,7 @@ void CompoundRtcpBuilder::AppendCastFeedbackAckFields(
// Compute how many additional octets are needed.
constexpr int kIncrement = sizeof(uint32_t);
const int num_additional =
- openscreen::DividePositivesRoundingUp(
+ DividePositivesRoundingUp(
(octet_index + 1) - num_ack_bitvector_octets, kIncrement) *
kIncrement;
@@ -328,5 +325,5 @@ void CompoundRtcpBuilder::AppendCastFeedbackAckFields(
acks_for_next_packet_.clear();
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.h b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.h
index 58bc62fba5b..787a6e5231f 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder.h
@@ -16,8 +16,8 @@
#include "cast/streaming/rtcp_common.h"
#include "cast/streaming/rtp_defines.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
class RtcpSession;
@@ -97,9 +97,8 @@ class CompoundRtcpBuilder {
// should be monotonically increasing so the consuming side (the Sender) can
// determine the chronological ordering of RTCP packets. The Sender might also
// use this to estimate round-trip times over the network.
- absl::Span<uint8_t> BuildPacket(
- openscreen::platform::Clock::time_point send_time,
- absl::Span<uint8_t> buffer);
+ absl::Span<uint8_t> BuildPacket(Clock::time_point send_time,
+ absl::Span<uint8_t> buffer);
// The required buffer size to be provided to BuildPacket(). This accounts for
// all the possible headers and report structures that might be included,
@@ -111,9 +110,8 @@ class CompoundRtcpBuilder {
// Helper methods called by BuildPacket() to append one RTCP packet to the
// |buffer| that will ultimately contain a "compound RTCP packet."
void AppendReceiverReportPacket(absl::Span<uint8_t>* buffer);
- void AppendReceiverReferenceTimeReportPacket(
- openscreen::platform::Clock::time_point send_time,
- absl::Span<uint8_t>* buffer);
+ void AppendReceiverReferenceTimeReportPacket(Clock::time_point send_time,
+ absl::Span<uint8_t>* buffer);
void AppendPictureLossIndicatorPacket(absl::Span<uint8_t>* buffer);
void AppendCastFeedbackPacket(absl::Span<uint8_t>* buffer);
int AppendCastFeedbackLossFields(absl::Span<uint8_t>* buffer);
@@ -122,7 +120,7 @@ class CompoundRtcpBuilder {
RtcpSession* const session_;
// Data to include in the next built RTCP packet.
- FrameId checkpoint_frame_id_ = FrameId::first() - 1;
+ FrameId checkpoint_frame_id_ = FrameId::leader();
std::chrono::milliseconds playout_delay_ = kDefaultTargetPlayoutDelay;
absl::optional<RtcpReportBlock> receiver_report_for_next_packet_;
std::vector<PacketNack> nacks_for_next_packet_;
@@ -134,7 +132,7 @@ class CompoundRtcpBuilder {
uint8_t feedback_count_ = 0;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_COMPOUND_RTCP_BUILDER_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder_unittest.cc
index ab6f8c00eab..969056cfa40 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder_unittest.cc
@@ -15,16 +15,14 @@
#include "gtest/gtest.h"
#include "platform/api/time.h"
-using openscreen::platform::Clock;
-
using testing::_;
using testing::Invoke;
using testing::Mock;
using testing::SaveArg;
using testing::StrictMock;
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
constexpr Ssrc kSenderSsrc{1};
@@ -369,5 +367,5 @@ TEST_F(CompoundRtcpBuilderTest, WithEverythingThatCanFit) {
}
} // namespace
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.cc b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.cc
index af5336ea71c..260989b0cb1 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.cc
@@ -11,10 +11,8 @@
#include "util/logging.h"
#include "util/std_util.h"
-using openscreen::platform::Clock;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
@@ -185,7 +183,7 @@ bool CompoundRtcpParser::Parse(absl::Span<const uint8_t> buffer,
client_->OnReceiverCheckpoint(checkpoint_frame_id, target_playout_delay);
}
if (!received_frames.empty()) {
- OSP_DCHECK(openscreen::AreElementsSortedAndUnique(received_frames));
+ OSP_DCHECK(AreElementsSortedAndUnique(received_frames));
client_->OnReceiverHasFrames(std::move(received_frames));
}
CanonicalizePacketNackVector(&packet_nacks);
@@ -338,7 +336,7 @@ bool CompoundRtcpParser::ParseExtendedReports(
return false; // Length field must always be 2 words.
}
*receiver_reference_time = session_->ntp_converter().ToLocalTime(
- openscreen::ReadBigEndian<uint64_t>(in.data()));
+ ReadBigEndian<uint64_t>(in.data()));
} else {
// Ignore any other type of extended report.
}
@@ -377,5 +375,5 @@ void CompoundRtcpParser::Client::OnReceiverHasFrames(
void CompoundRtcpParser::Client::OnReceiverIsMissingPackets(
std::vector<PacketNack> nacks) {}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.h b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.h
index c9b71594a7a..c74bb3ec010 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser.h
@@ -14,8 +14,8 @@
#include "cast/streaming/rtcp_common.h"
#include "cast/streaming/rtp_defines.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
class RtcpSession;
@@ -41,7 +41,7 @@ class CompoundRtcpParser {
// Called when a Receiver Reference Time Report has been parsed.
virtual void OnReceiverReferenceTimeAdvanced(
- openscreen::platform::Clock::time_point reference_time);
+ Clock::time_point reference_time);
// Called when a Receiver Report with a Report Block has been parsed.
virtual void OnReceiverReport(const RtcpReportBlock& receiver_report);
@@ -101,9 +101,8 @@ class CompoundRtcpParser {
std::chrono::milliseconds* target_playout_delay,
std::vector<FrameId>* received_frames,
std::vector<PacketNack>* packet_nacks);
- bool ParseExtendedReports(
- absl::Span<const uint8_t> in,
- openscreen::platform::Clock::time_point* receiver_reference_time);
+ bool ParseExtendedReports(absl::Span<const uint8_t> in,
+ Clock::time_point* receiver_reference_time);
bool ParsePictureLossIndicator(absl::Span<const uint8_t> in,
bool* picture_loss_indicator);
@@ -113,10 +112,10 @@ class CompoundRtcpParser {
// Tracks the latest timestamp seen from any Receiver Reference Time Report,
// and uses this to ignore stale RTCP packets that arrived out-of-order and/or
// late from the network.
- openscreen::platform::Clock::time_point latest_receiver_timestamp_;
+ Clock::time_point latest_receiver_timestamp_;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_COMPOUND_RTCP_PARSER_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_fuzzer.cc b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_fuzzer.cc
index af5823df945..bb3dd179b4b 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_fuzzer.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_fuzzer.cc
@@ -9,10 +9,10 @@
#include "cast/streaming/rtcp_session.h"
extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
- using cast::streaming::CompoundRtcpParser;
- using cast::streaming::FrameId;
- using cast::streaming::RtcpSession;
- using cast::streaming::Ssrc;
+ using openscreen::cast::CompoundRtcpParser;
+ using openscreen::cast::FrameId;
+ using openscreen::cast::RtcpSession;
+ using openscreen::cast::Ssrc;
constexpr Ssrc kSenderSsrcInSeedCorpus = 1;
constexpr Ssrc kReceiverSsrcInSeedCorpus = 2;
@@ -25,7 +25,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wexit-time-destructors"
static RtcpSession session(kSenderSsrcInSeedCorpus, kReceiverSsrcInSeedCorpus,
- openscreen::platform::Clock::time_point{});
+ openscreen::Clock::time_point{});
static CompoundRtcpParser::Client client_that_ignores_everything;
static CompoundRtcpParser parser(&session, &client_that_ignores_everything);
#pragma clang diagnostic pop
diff --git a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_unittest.cc
index 863953afad5..9f8e3c50988 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_unittest.cc
@@ -18,8 +18,8 @@ using testing::Mock;
using testing::SaveArg;
using testing::StrictMock;
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
constexpr Ssrc kSenderSsrc{1};
@@ -32,8 +32,7 @@ class CompoundRtcpParserTest : public testing::Test {
CompoundRtcpParser* parser() { return &parser_; }
private:
- RtcpSession session_{kSenderSsrc, kReceiverSsrc,
- openscreen::platform::Clock::now()};
+ RtcpSession session_{kSenderSsrc, kReceiverSsrc, Clock::now()};
StrictMock<MockCompoundRtcpParserClient> client_;
CompoundRtcpParser parser_{&session_, &client_};
};
@@ -404,5 +403,5 @@ TEST_F(CompoundRtcpParserTest, ParsesFeedbackWithAcks) {
}
} // namespace
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/constants.h b/chromium/third_party/openscreen/src/cast/streaming/constants.h
index d67ba3e9164..4a8d526e072 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/constants.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/constants.h
@@ -14,13 +14,23 @@
#include <chrono>
#include <ratio>
+namespace openscreen {
namespace cast {
-namespace streaming {
// Default target playout delay. The playout delay is the window of time between
// capture from the source until presentation at the receiver.
constexpr std::chrono::milliseconds kDefaultTargetPlayoutDelay(400);
+// Default UDP port, bound at the Receiver, for Cast Streaming. An
+// implementation is required to use the port specified by the Receiver in its
+// ANSWER control message, which may or may not match this port number here.
+constexpr int kDefaultCastStreamingPort = 2344;
+
+// Default TCP port, bound at the TLS server socket level, for Cast Streaming.
+// An implementation must use the port specified in the DNS-SD published record
+// for connecting over TLS, which may or may not match this port number here.
+constexpr int kDefaultCastPort = 8010;
+
// Target number of milliseconds between the sending of RTCP reports. Both
// senders and receivers regularly send RTCP reports to their peer.
constexpr std::chrono::milliseconds kRtcpReportInterval(500);
@@ -34,11 +44,14 @@ constexpr std::chrono::milliseconds kRtcpReportInterval(500);
// logic can handle wrap around and compare two frame IDs meaningfully.
constexpr int kMaxUnackedFrames = 120;
+// The network must support a packet size of at least this many bytes.
+constexpr int kRequiredNetworkPacketSize = 256;
+
// The spec declares RTP timestamps must always have a timebase of 90000 ticks
// per second for video.
using kVideoTimebase = std::ratio<1, 90000>;
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_CONSTANTS_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/encoded_frame.cc b/chromium/third_party/openscreen/src/cast/streaming/encoded_frame.cc
index ebd6cc753f9..4fb832bd1aa 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/encoded_frame.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/encoded_frame.cc
@@ -4,8 +4,8 @@
#include "cast/streaming/encoded_frame.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
EncodedFrame::EncodedFrame() = default;
EncodedFrame::~EncodedFrame() = default;
@@ -22,5 +22,5 @@ void EncodedFrame::CopyMetadataTo(EncodedFrame* dest) const {
dest->new_playout_delay = this->new_playout_delay;
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/encoded_frame.h b/chromium/third_party/openscreen/src/cast/streaming/encoded_frame.h
index b42b6340d88..40fbda81ca1 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/encoded_frame.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/encoded_frame.h
@@ -16,8 +16,8 @@
#include "platform/api/time.h"
#include "platform/base/macros.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// A combination of metadata and data for one encoded frame. This can contain
// audio data or video data or other.
@@ -78,7 +78,7 @@ struct EncodedFrame {
// (see |rtp_timestamp|, above). It is also meant to be used to synchronize
// the presentation of multiple streams (e.g., audio and video), commonly
// known as "lip-sync." It is NOT meant to be a mandatory/exact playout time.
- openscreen::platform::Clock::time_point reference_time;
+ Clock::time_point reference_time;
// Playout delay for this and all future frames. Used by the Adaptive
// Playout delay extension. Non-positive values means no change.
@@ -93,7 +93,7 @@ struct EncodedFrame {
OSP_DISALLOW_COPY_AND_ASSIGN(EncodedFrame);
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_ENCODED_FRAME_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/environment.cc b/chromium/third_party/openscreen/src/cast/streaming/environment.cc
index 954a2e0c535..b3a6c7b1673 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/environment.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/environment.cc
@@ -8,18 +8,8 @@
#include "platform/api/task_runner.h"
#include "util/logging.h"
-using openscreen::Error;
-using openscreen::ErrorOr;
-using openscreen::IPAddress;
-using openscreen::IPEndpoint;
-using openscreen::platform::Clock;
-using openscreen::platform::ClockNowFunctionPtr;
-using openscreen::platform::TaskRunner;
-using openscreen::platform::UdpPacket;
-using openscreen::platform::UdpSocket;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
Environment::Environment(ClockNowFunctionPtr now_function,
TaskRunner* task_runner)
@@ -124,10 +114,10 @@ void Environment::OnRead(UdpSocket* socket,
// Ideally, the arrival time would come from the operating system's network
// stack (e.g., by using the SO_TIMESTAMP sockopt on POSIX systems). However,
// there would still be the problem of mapping the timestamp to a value in
- // terms of platform::Clock. So, just sample the Clock here and call that the
- // "arrival time." While this can add variance within the system, it should be
- // minimal, assuming not too much time has elapsed between the actual packet
- // receive event and the when this code here is executing.
+ // terms of Clock::time_point. So, just sample the Clock here and call that
+ // the "arrival time." While this can add variance within the system, it
+ // should be minimal, assuming not too much time has elapsed between the
+ // actual packet receive event and the when this code here is executing.
const Clock::time_point arrival_time = now_function_();
UdpPacket packet = std::move(packet_or_error.value());
@@ -136,5 +126,5 @@ void Environment::OnRead(UdpSocket* socket,
std::move(static_cast<std::vector<uint8_t>&>(packet)));
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/environment.h b/chromium/third_party/openscreen/src/cast/streaming/environment.h
index 0bc99f2db58..278604b0d0f 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/environment.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/environment.h
@@ -16,24 +16,17 @@
#include "platform/base/ip_address.h"
namespace openscreen {
-namespace platform {
-class TaskRunner;
-} // namespace platform
-} // namespace openscreen
-
namespace cast {
-namespace streaming {
// Provides the common environment for operating system resources shared by
// multiple components.
-class Environment : public openscreen::platform::UdpSocket::Client {
+class Environment : public UdpSocket::Client {
public:
class PacketConsumer {
public:
- virtual void OnReceivedPacket(
- const openscreen::IPEndpoint& source,
- openscreen::platform::Clock::time_point arrival_time,
- std::vector<uint8_t> packet) = 0;
+ virtual void OnReceivedPacket(const IPEndpoint& source,
+ Clock::time_point arrival_time,
+ std::vector<uint8_t> packet) = 0;
protected:
virtual ~PacketConsumer();
@@ -42,35 +35,34 @@ class Environment : public openscreen::platform::UdpSocket::Client {
// Construct with the given clock source and TaskRunner. Creates and
// internally-owns a UdpSocket, and immediately binds it to the given
// |local_endpoint|.
- Environment(openscreen::platform::ClockNowFunctionPtr now_function,
- openscreen::platform::TaskRunner* task_runner,
- const openscreen::IPEndpoint& local_endpoint);
+ Environment(ClockNowFunctionPtr now_function,
+ TaskRunner* task_runner,
+ const IPEndpoint& local_endpoint);
~Environment() override;
- openscreen::platform::ClockNowFunctionPtr now_function() const {
- return now_function_;
- }
- openscreen::platform::TaskRunner* task_runner() const { return task_runner_; }
+ ClockNowFunctionPtr now_function() const { return now_function_; }
+ Clock::time_point now() const { return now_function_(); }
+ TaskRunner* task_runner() const { return task_runner_; }
// Returns the local endpoint the socket is bound to, or the zero IPEndpoint
// if socket creation/binding failed.
- openscreen::IPEndpoint GetBoundLocalEndpoint() const;
+ //
+ // Note: This method is virtual to allow unit tests to fake that there really
+ // is a bound socket.
+ virtual IPEndpoint GetBoundLocalEndpoint() const;
// Set a handler function to run whenever non-recoverable socket errors occur.
// If never set, the default is to emit log messages at error priority.
- void set_socket_error_handler(
- std::function<void(openscreen::Error)> handler) {
+ void set_socket_error_handler(std::function<void(Error)> handler) {
socket_error_handler_ = handler;
}
// Get/Set the remote endpoint. This is separate from the constructor because
// the remote endpoint is, in some cases, discovered only after receiving a
// packet.
- const openscreen::IPEndpoint& remote_endpoint() const {
- return remote_endpoint_;
- }
- void set_remote_endpoint(const openscreen::IPEndpoint& endpoint) {
+ const IPEndpoint& remote_endpoint() const { return remote_endpoint_; }
+ void set_remote_endpoint(const IPEndpoint& endpoint) {
remote_endpoint_ = endpoint;
}
@@ -98,34 +90,29 @@ class Environment : public openscreen::platform::UdpSocket::Client {
// Common constructor that just stores the injected dependencies and does not
// create a socket. Subclasses use this to provide an alternative packet
// receive/send mechanism (e.g., for testing).
- Environment(openscreen::platform::ClockNowFunctionPtr now_function,
- openscreen::platform::TaskRunner* task_runner);
+ Environment(ClockNowFunctionPtr now_function, TaskRunner* task_runner);
private:
- // openscreen::platform::UdpSocket::Client implementation.
- void OnError(openscreen::platform::UdpSocket* socket,
- openscreen::Error error) final;
- void OnSendError(openscreen::platform::UdpSocket* socket,
- openscreen::Error error) final;
- void OnRead(openscreen::platform::UdpSocket* socket,
- openscreen::ErrorOr<openscreen::platform::UdpPacket>
- packet_or_error) final;
-
- const openscreen::platform::ClockNowFunctionPtr now_function_;
- openscreen::platform::TaskRunner* const task_runner_;
+ // UdpSocket::Client implementation.
+ void OnError(UdpSocket* socket, Error error) final;
+ void OnSendError(UdpSocket* socket, Error error) final;
+ void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet_or_error) final;
+
+ const ClockNowFunctionPtr now_function_;
+ TaskRunner* const task_runner_;
// The UDP socket bound to the local endpoint that was passed into the
// constructor, or null if socket creation failed.
- const std::unique_ptr<openscreen::platform::UdpSocket> socket_;
+ const std::unique_ptr<UdpSocket> socket_;
// These are externally set/cleared. Behaviors are described in getter/setter
// method comments above.
- std::function<void(openscreen::Error)> socket_error_handler_;
- openscreen::IPEndpoint remote_endpoint_{};
+ std::function<void(Error)> socket_error_handler_;
+ IPEndpoint remote_endpoint_{};
PacketConsumer* packet_consumer_ = nullptr;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_ENVIRONMENT_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/expanded_value_base.h b/chromium/third_party/openscreen/src/cast/streaming/expanded_value_base.h
index 0ac0dbebc8d..d418df647df 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/expanded_value_base.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/expanded_value_base.h
@@ -11,8 +11,8 @@
#include "util/logging.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Abstract base template class for common "sequence value" data types such as
// RtpTimeTicks, FrameId, or PacketId which generally increment/decrement in
@@ -158,7 +158,7 @@ class ExpandedValueBase {
FullWidthInteger value_;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_EXPANDED_VALUE_BASE_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/expanded_value_base_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/expanded_value_base_unittest.cc
index 220af2215a5..bcbf0d9f8f3 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/expanded_value_base_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/expanded_value_base_unittest.cc
@@ -6,8 +6,8 @@
#include "gtest/gtest.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
@@ -104,5 +104,5 @@ TEST(ExpandedValueBaseTest, TruncationAndExpansion) {
}
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/frame_collector.cc b/chromium/third_party/openscreen/src/cast/streaming/frame_collector.cc
index aabbdd06738..dbd6fb4db9c 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/frame_collector.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/frame_collector.cc
@@ -12,8 +12,8 @@
#include "cast/streaming/rtp_defines.h"
#include "util/logging.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
@@ -156,5 +156,5 @@ void FrameCollector::Reset() {
FrameCollector::PayloadChunk::PayloadChunk() = default;
FrameCollector::PayloadChunk::~PayloadChunk() = default;
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/frame_collector.h b/chromium/third_party/openscreen/src/cast/streaming/frame_collector.h
index 7504d690958..ca68ab31b0b 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/frame_collector.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/frame_collector.h
@@ -13,8 +13,8 @@
#include "cast/streaming/rtcp_common.h"
#include "cast/streaming/rtp_packet_parser.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Used by a Receiver to collect the parts of a frame, track what is
// missing/complete, and assemble a complete frame.
@@ -84,7 +84,7 @@ class FrameCollector {
std::vector<PayloadChunk> chunks_;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_FRAME_COLLECTOR_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/frame_collector_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/frame_collector_unittest.cc
index 53e13811464..34448c24af4 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/frame_collector_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/frame_collector_unittest.cc
@@ -15,8 +15,8 @@
#include "cast/streaming/rtp_time.h"
#include "gtest/gtest.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
const FrameId kSomeFrameId = FrameId::first() + 39;
@@ -207,5 +207,5 @@ TEST(FrameCollectorTest, RejectsInvalidParts) {
}
} // namespace
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/frame_crypto.cc b/chromium/third_party/openscreen/src/cast/streaming/frame_crypto.cc
index 416c2df1235..4d506d5eb7e 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/frame_crypto.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/frame_crypto.cc
@@ -12,8 +12,8 @@
#include "util/big_endian.h"
#include "util/crypto/openssl_util.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
EncryptedFrame::EncryptedFrame() {
data = absl::Span<uint8_t>(owned_data_);
@@ -91,8 +91,7 @@ void FrameCrypto::EncryptCommon(FrameId frame_id,
std::array<uint8_t, 16> aes_nonce{/* zero initialized */};
static_assert(AES_BLOCK_SIZE == sizeof(aes_nonce),
"AES_BLOCK_SIZE is not 16 bytes.");
- openscreen::WriteBigEndian<uint32_t>(frame_id.lower_32_bits(),
- aes_nonce.data() + 8);
+ WriteBigEndian<uint32_t>(frame_id.lower_32_bits(), aes_nonce.data() + 8);
for (size_t i = 0; i < aes_nonce.size(); ++i) {
aes_nonce[i] ^= cast_iv_mask_[i];
}
@@ -116,5 +115,5 @@ std::array<uint8_t, 16> FrameCrypto::GenerateRandomBytes() {
return result;
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/frame_crypto.h b/chromium/third_party/openscreen/src/cast/streaming/frame_crypto.h
index 693385ad311..35ee9787cb1 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/frame_crypto.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/frame_crypto.h
@@ -16,8 +16,8 @@
#include "openssl/aes.h"
#include "platform/base/macros.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
class FrameCollector;
class FrameCrypto;
@@ -91,7 +91,7 @@ class FrameCrypto {
absl::Span<uint8_t> out) const;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_FRAME_CRYPTO_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/frame_crypto_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/frame_crypto_unittest.cc
index c5fcd400f25..7ac8d5cf2e5 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/frame_crypto_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/frame_crypto_unittest.cc
@@ -10,8 +10,8 @@
#include "gtest/gtest.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
TEST(FrameCryptoTest, EncryptsAndDecryptsFrames) {
@@ -76,5 +76,5 @@ TEST(FrameCryptoTest, EncryptsAndDecryptsFrames) {
}
} // namespace
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/frame_id.cc b/chromium/third_party/openscreen/src/cast/streaming/frame_id.cc
index bf061fe8e12..3dd4951852b 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/frame_id.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/frame_id.cc
@@ -4,8 +4,8 @@
#include "cast/streaming/frame_id.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
std::ostream& operator<<(std::ostream& out, const FrameId rhs) {
out << "F";
@@ -14,5 +14,5 @@ std::ostream& operator<<(std::ostream& out, const FrameId rhs) {
return out << rhs.value();
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/frame_id.h b/chromium/third_party/openscreen/src/cast/streaming/frame_id.h
index 3625242257a..d741a289b6d 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/frame_id.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/frame_id.h
@@ -11,8 +11,8 @@
#include "cast/streaming/expanded_value_base.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Forward declaration (see below).
class FrameId;
@@ -94,6 +94,16 @@ class FrameId : public ExpandedValueBase<int64_t, FrameId> {
// The identifier for the first frame in a stream.
static constexpr FrameId first() { return FrameId(0); }
+ // A virtual identifier, representing the frame before the first. There should
+ // never actually be a frame streamed with this identifier. Instead, this is
+ // used in various components to represent a "not yet seen/processed the first
+ // frame" state.
+ //
+ // The name "leader" comes from the terminology used in tape reels, which
+ // refers to the non-data-carrying segment of tape before the recording
+ // begins.
+ static constexpr FrameId leader() { return FrameId(-1); }
+
private:
friend class ExpandedValueBase<int64_t, FrameId>;
friend std::ostream& operator<<(std::ostream& out, const FrameId rhs);
@@ -104,7 +114,7 @@ class FrameId : public ExpandedValueBase<int64_t, FrameId> {
constexpr int64_t value() const { return value_; }
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_FRAME_ID_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/message_port.h b/chromium/third_party/openscreen/src/cast/streaming/message_port.h
new file mode 100644
index 00000000000..f44a808dbf1
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/message_port.h
@@ -0,0 +1,38 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_STREAMING_MESSAGE_PORT_H_
+#define CAST_STREAMING_MESSAGE_PORT_H_
+
+#include "absl/strings/string_view.h"
+#include "platform/base/error.h"
+
+namespace openscreen {
+namespace cast {
+
+// This interface is intended to provide an abstraction for communicating
+// cast messages across a pipe with guaranteed delivery. This is used to
+// decouple the cast receiver session (and potentially other classes) from any
+// network implementation.
+class MessagePort {
+ public:
+ class Client {
+ public:
+ virtual void OnMessage(absl::string_view sender_id,
+ absl::string_view message_namespace,
+ absl::string_view message) = 0;
+ virtual void OnError(Error error) = 0;
+ };
+
+ virtual ~MessagePort() = default;
+ virtual void SetClient(Client* client) = 0;
+ virtual void PostMessage(absl::string_view sender_id,
+ absl::string_view message_namespace,
+ absl::string_view message) = 0;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_STREAMING_MESSAGE_PORT_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/message_util.h b/chromium/third_party/openscreen/src/cast/streaming/message_util.h
new file mode 100644
index 00000000000..fbb3b21199f
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/message_util.h
@@ -0,0 +1,77 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_STREAMING_MESSAGE_UTIL_H_
+#define CAST_STREAMING_MESSAGE_UTIL_H_
+
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "json/value.h"
+#include "platform/base/error.h"
+
+// This file contains helper methods that are used by both answer and offer
+// messages, but should not be publicly exposed/consumed.
+namespace openscreen {
+namespace cast {
+
+inline Error CreateParseError(const std::string& type) {
+ return Error(Error::Code::kJsonParseError, "Failed to parse " + type);
+}
+
+inline Error CreateParameterError(const std::string& type) {
+ return Error(Error::Code::kParameterInvalid, "Invalid parameter: " + type);
+}
+
+inline ErrorOr<bool> ParseBool(const Json::Value& parent,
+ const std::string& field) {
+ const Json::Value& value = parent[field];
+ if (!value.isBool()) {
+ return CreateParseError("bool field " + field);
+ }
+ return value.asBool();
+}
+
+inline ErrorOr<int> ParseInt(const Json::Value& parent,
+ const std::string& field) {
+ const Json::Value& value = parent[field];
+ if (!value.isInt()) {
+ return CreateParseError("integer field: " + field);
+ }
+ return value.asInt();
+}
+
+inline ErrorOr<uint32_t> ParseUint(const Json::Value& parent,
+ const std::string& field) {
+ const Json::Value& value = parent[field];
+ if (!value.isUInt()) {
+ return CreateParseError("unsigned integer field: " + field);
+ }
+ return value.asUInt();
+}
+
+inline ErrorOr<std::string> ParseString(const Json::Value& parent,
+ const std::string& field) {
+ const Json::Value& value = parent[field];
+ if (!value.isString()) {
+ return CreateParseError("string field: " + field);
+ }
+ return value.asString();
+}
+
+// TODO(jophba): refactor to be on ErrorOr itself.
+// Use this template for parsing only when there is a reasonable default
+// for the type you are using, e.g. int or std::string.
+template <typename T>
+T ValueOrDefault(const ErrorOr<T>& value, T fallback = T{}) {
+ if (value.is_value()) {
+ return value.value();
+ }
+ return fallback;
+}
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_STREAMING_MESSAGE_UTIL_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/mock_compound_rtcp_parser_client.h b/chromium/third_party/openscreen/src/cast/streaming/mock_compound_rtcp_parser_client.h
index 0b16e599628..93d93f6762e 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/mock_compound_rtcp_parser_client.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/mock_compound_rtcp_parser_client.h
@@ -8,13 +8,13 @@
#include "cast/streaming/compound_rtcp_parser.h"
#include "gmock/gmock.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
class MockCompoundRtcpParserClient : public CompoundRtcpParser::Client {
public:
MOCK_METHOD1(OnReceiverReferenceTimeAdvanced,
- void(openscreen::platform::Clock::time_point reference_time));
+ void(Clock::time_point reference_time));
MOCK_METHOD1(OnReceiverReport, void(const RtcpReportBlock& receiver_report));
MOCK_METHOD0(OnReceiverIndicatesPictureLoss, void());
MOCK_METHOD2(OnReceiverCheckpoint,
@@ -23,7 +23,7 @@ class MockCompoundRtcpParserClient : public CompoundRtcpParser::Client {
MOCK_METHOD1(OnReceiverIsMissingPackets, void(std::vector<PacketNack> nacks));
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_MOCK_COMPOUND_RTCP_PARSER_CLIENT_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/mock_environment.cc b/chromium/third_party/openscreen/src/cast/streaming/mock_environment.cc
new file mode 100644
index 00000000000..9d712537ccd
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/mock_environment.cc
@@ -0,0 +1,17 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/streaming/mock_environment.h"
+
+namespace openscreen {
+namespace cast {
+
+MockEnvironment::MockEnvironment(ClockNowFunctionPtr now_function,
+ TaskRunner* task_runner)
+ : Environment(now_function, task_runner) {}
+
+MockEnvironment::~MockEnvironment() = default;
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/mock_environment.h b/chromium/third_party/openscreen/src/cast/streaming/mock_environment.h
new file mode 100644
index 00000000000..db66b141e31
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/mock_environment.h
@@ -0,0 +1,30 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_STREAMING_MOCK_ENVIRONMENT_H_
+#define CAST_STREAMING_MOCK_ENVIRONMENT_H_
+
+#include "cast/streaming/environment.h"
+#include "gmock/gmock.h"
+
+namespace openscreen {
+namespace cast {
+
+// An Environment that can intercept all packet sends, for unit testing.
+class MockEnvironment : public Environment {
+ public:
+ MockEnvironment(ClockNowFunctionPtr now_function, TaskRunner* task_runner);
+ ~MockEnvironment() override;
+
+ // Used to return fake values, to simulate a bound socket for testing.
+ MOCK_METHOD(IPEndpoint, GetBoundLocalEndpoint, (), (const, override));
+
+ // Used for intercepting packet sends from the implementation under test.
+ MOCK_METHOD(void, SendPacket, (absl::Span<const uint8_t> packet), (override));
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_STREAMING_MOCK_ENVIRONMENT_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/ntp_time.cc b/chromium/third_party/openscreen/src/cast/streaming/ntp_time.cc
index d7072baf03e..0d59c0a8d4b 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/ntp_time.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/ntp_time.cc
@@ -6,11 +6,10 @@
#include "util/logging.h"
-using openscreen::platform::Clock;
using std::chrono::duration_cast;
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
@@ -54,5 +53,5 @@ Clock::time_point NtpTimeConverter::ToLocalTime(NtpTimestamp timestamp) const {
return seconds_since_start + remainder;
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/ntp_time.h b/chromium/third_party/openscreen/src/cast/streaming/ntp_time.h
index f76dae4abe9..6b13132f8a9 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/ntp_time.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/ntp_time.h
@@ -9,8 +9,8 @@
#include "platform/api/time.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// NTP timestamps are 64-bit timestamps that consist of two 32-bit parts: 1) The
// number of seconds since 1 January 1900; and 2) The fraction of the second,
@@ -41,24 +41,21 @@ constexpr NtpTimestamp AssembleNtpTimestamp(NtpSeconds seconds,
static_cast<uint32_t>(fraction.count());
}
-// Converts between openscreen::platform::Clock::time_points and NtpTimestamps.
-// The class is instantiated with the current openscreen::platform::Clock time
-// and the current wall clock time, and these are used to determine a fixed
-// origin reference point for all conversions. Thus, to avoid introducing
-// unintended timing-related behaviors, only one NtpTimeConverter instance
-// should be used for converting all the NTP timestamps in the same streaming
-// session.
+// Converts between Clock::time_points and NtpTimestamps. The class is
+// instantiated with the current Clock time and the current wall clock time, and
+// these are used to determine a fixed origin reference point for all
+// conversions. Thus, to avoid introducing unintended timing-related behaviors,
+// only one NtpTimeConverter instance should be used for converting all the NTP
+// timestamps in the same streaming session.
class NtpTimeConverter {
public:
- NtpTimeConverter(openscreen::platform::Clock::time_point now,
- std::chrono::seconds since_unix_epoch =
- openscreen::platform::GetWallTimeSinceUnixEpoch());
+ NtpTimeConverter(
+ Clock::time_point now,
+ std::chrono::seconds since_unix_epoch = GetWallTimeSinceUnixEpoch());
~NtpTimeConverter();
- NtpTimestamp ToNtpTimestamp(
- openscreen::platform::Clock::time_point time_point) const;
- openscreen::platform::Clock::time_point ToLocalTime(
- NtpTimestamp timestamp) const;
+ NtpTimestamp ToNtpTimestamp(Clock::time_point time_point) const;
+ Clock::time_point ToLocalTime(NtpTimestamp timestamp) const;
private:
// The time point on the platform clock's timeline that corresponds to
@@ -68,11 +65,11 @@ class NtpTimeConverter {
// can be off (with respect to each other) by even a large amount; and all
// that matters is that time ticks forward at a reasonable pace from some
// initial point.
- const openscreen::platform::Clock::time_point start_time_;
+ const Clock::time_point start_time_;
const NtpSeconds since_ntp_epoch_;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_NTP_TIME_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/ntp_time_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/ntp_time_unittest.cc
index 552673c03c8..1caa81ad475 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/ntp_time_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/ntp_time_unittest.cc
@@ -6,13 +6,12 @@
#include "gtest/gtest.h"
-using openscreen::platform::Clock;
using std::chrono::duration_cast;
using std::chrono::microseconds;
using std::chrono::milliseconds;
+namespace openscreen {
namespace cast {
-namespace streaming {
TEST(NtpTimestampTest, SplitsIntoParts) {
// 1 Jan 1900.
@@ -73,8 +72,7 @@ TEST(NtpTimeConverterTest, ConvertsToNtpTimeAndBack) {
// our core assumptions (or the design) about the time math are wrong and
// should be looked into!
const Clock::time_point steady_clock_start = Clock::now();
- const std::chrono::seconds wall_clock_start =
- openscreen::platform::GetWallTimeSinceUnixEpoch();
+ const std::chrono::seconds wall_clock_start = GetWallTimeSinceUnixEpoch();
SCOPED_TRACE(::testing::Message()
<< "steady_clock_start.time_since_epoch().count() is "
<< steady_clock_start.time_since_epoch().count()
@@ -106,5 +104,5 @@ TEST(NtpTimeConverterTest, ConvertsToNtpTimeAndBack) {
}
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/offer_messages.cc b/chromium/third_party/openscreen/src/cast/streaming/offer_messages.cc
new file mode 100644
index 00000000000..f1da9c60eb3
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/offer_messages.cc
@@ -0,0 +1,437 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/streaming/offer_messages.h"
+
+#include <inttypes.h>
+
+#include <string>
+#include <utility>
+
+#include "absl/strings/match.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
+#include "cast/streaming/constants.h"
+#include "cast/streaming/message_util.h"
+#include "cast/streaming/receiver_session.h"
+#include "platform/base/error.h"
+#include "util/big_endian.h"
+#include "util/json/json_serialization.h"
+#include "util/logging.h"
+#include "util/stringprintf.h"
+
+namespace openscreen {
+namespace cast {
+
+namespace {
+
+constexpr char kSupportedStreams[] = "supportedStreams";
+constexpr char kAudioSourceType[] = "audio_source";
+constexpr char kVideoSourceType[] = "video_source";
+constexpr char kStreamType[] = "type";
+
+ErrorOr<RtpPayloadType> ParseRtpPayloadType(const Json::Value& parent,
+ const std::string& field) {
+ auto t = ParseInt(parent, field);
+ if (!t) {
+ return t.error();
+ }
+
+ uint8_t t_small = t.value();
+ if (t_small != t.value() || !IsRtpPayloadType(t_small)) {
+ return Error(Error::Code::kParameterInvalid,
+ "Received invalid RTP Payload Type.");
+ }
+
+ return static_cast<RtpPayloadType>(t_small);
+}
+
+ErrorOr<int> ParseRtpTimebase(const Json::Value& parent,
+ const std::string& field) {
+ auto error_or_raw = ParseString(parent, field);
+ if (!error_or_raw) {
+ return error_or_raw.error();
+ }
+
+ const auto fraction = SimpleFraction::FromString(error_or_raw.value());
+ if (fraction.is_error() || !fraction.value().is_positive()) {
+ return CreateParseError("RTP timebase");
+ }
+ // The spec demands a leading 1, so this isn't really a fraction.
+ OSP_DCHECK(fraction.value().numerator == 1);
+ return fraction.value().denominator;
+}
+
+// For a hex byte, the conversion is 4 bits to 1 character, e.g.
+// 0b11110001 becomes F1, so 1 byte is two characters.
+constexpr int kHexDigitsPerByte = 2;
+constexpr int kAesBytesSize = 16;
+constexpr int kAesStringLength = kAesBytesSize * kHexDigitsPerByte;
+ErrorOr<std::array<uint8_t, kAesBytesSize>> ParseAesHexBytes(
+ const Json::Value& parent,
+ const std::string& field) {
+ auto hex_string = ParseString(parent, field);
+ if (!hex_string) {
+ return hex_string.error();
+ }
+
+ constexpr int kHexDigitsPerScanField = 16;
+ constexpr int kNumScanFields = kAesStringLength / kHexDigitsPerScanField;
+ uint64_t quads[kNumScanFields];
+ int chars_scanned;
+ if (hex_string.value().size() == kAesStringLength &&
+ sscanf(hex_string.value().c_str(), "%16" SCNx64 "%16" SCNx64 "%n",
+ &quads[0], &quads[1], &chars_scanned) == kNumScanFields &&
+ chars_scanned == kAesStringLength &&
+ std::none_of(hex_string.value().begin(), hex_string.value().end(),
+ [](char c) { return std::isspace(c); })) {
+ std::array<uint8_t, kAesBytesSize> bytes;
+ WriteBigEndian(quads[0], bytes.data());
+ WriteBigEndian(quads[1], bytes.data() + 8);
+ return bytes;
+ }
+ return CreateParseError("AES hex string bytes");
+}
+
+ErrorOr<Stream> ParseStream(const Json::Value& value, Stream::Type type) {
+ auto index = ParseInt(value, "index");
+ if (!index) {
+ return index.error();
+ }
+ // If channel is omitted, the default value is used later.
+ auto channels = ParseInt(value, "channels");
+ if (channels.is_value() && channels.value() <= 0) {
+ return CreateParameterError("channel");
+ }
+ auto codec_name = ParseString(value, "codecName");
+ if (!codec_name) {
+ return codec_name.error();
+ }
+ auto rtp_profile = ParseString(value, "rtpProfile");
+ if (!rtp_profile) {
+ return rtp_profile.error();
+ }
+ auto rtp_payload_type = ParseRtpPayloadType(value, "rtpPayloadType");
+ if (!rtp_payload_type) {
+ return rtp_payload_type.error();
+ }
+ auto ssrc = ParseUint(value, "ssrc");
+ if (!ssrc) {
+ return ssrc.error();
+ }
+ auto aes_key = ParseAesHexBytes(value, "aesKey");
+ if (!aes_key) {
+ return aes_key.error();
+ }
+ auto aes_iv_mask = ParseAesHexBytes(value, "aesIvMask");
+ if (!aes_iv_mask) {
+ return aes_iv_mask.error();
+ }
+ auto rtp_timebase = ParseRtpTimebase(value, "timeBase");
+ if (!rtp_timebase) {
+ return rtp_timebase.error();
+ }
+
+ auto target_delay = ParseInt(value, "targetDelay");
+ std::chrono::milliseconds target_delay_ms = kDefaultTargetPlayoutDelay;
+ if (target_delay) {
+ auto d = std::chrono::milliseconds(target_delay.value());
+ if (d >= kMinTargetPlayoutDelay && d <= kMaxTargetPlayoutDelay) {
+ target_delay_ms = d;
+ } else {
+ return CreateParameterError("target delay");
+ }
+ }
+
+ auto receiver_rtcp_event_log = ParseBool(value, "receiverRtcpEventLog");
+ auto receiver_rtcp_dscp = ParseString(value, "receiverRtcpDscp");
+ return Stream{index.value(),
+ type,
+ ValueOrDefault(channels, type == Stream::Type::kAudioSource
+ ? kDefaultNumAudioChannels
+ : kDefaultNumVideoChannels),
+ codec_name.value(),
+ rtp_payload_type.value(),
+ ssrc.value(),
+ target_delay_ms,
+ aes_key.value(),
+ aes_iv_mask.value(),
+ ValueOrDefault(receiver_rtcp_event_log),
+ ValueOrDefault(receiver_rtcp_dscp),
+ rtp_timebase.value()};
+}
+
+ErrorOr<AudioStream> ParseAudioStream(const Json::Value& value) {
+ auto stream = ParseStream(value, Stream::Type::kAudioSource);
+ if (!stream) {
+ return stream.error();
+ }
+ auto bit_rate = ParseInt(value, "bitRate");
+ if (!bit_rate) {
+ return bit_rate.error();
+ }
+ // A bit rate of 0 is valid for some codec types, so we don't enforce here.
+ if (bit_rate.value() < 0) {
+ return CreateParameterError("bit rate");
+ }
+ return AudioStream{stream.value(), bit_rate.value()};
+}
+
+ErrorOr<Resolution> ParseResolution(const Json::Value& value) {
+ auto width = ParseInt(value, "width");
+ if (!width) {
+ return width.error();
+ }
+ auto height = ParseInt(value, "height");
+ if (!height) {
+ return height.error();
+ }
+ if (width.value() <= 0 || height.value() <= 0) {
+ return CreateParameterError("resolution");
+ }
+ return Resolution{width.value(), height.value()};
+}
+
+ErrorOr<std::vector<Resolution>> ParseResolutions(const Json::Value& parent,
+ const std::string& field) {
+ std::vector<Resolution> resolutions;
+ // Some legacy senders don't provide resolutions, so just return empty.
+ const Json::Value& value = parent[field];
+ if (!value.isArray() || value.empty()) {
+ return resolutions;
+ }
+
+ for (Json::ArrayIndex i = 0; i < value.size(); ++i) {
+ auto r = ParseResolution(value[i]);
+ if (!r) {
+ return r.error();
+ }
+ resolutions.push_back(r.value());
+ }
+
+ return resolutions;
+}
+
+ErrorOr<VideoStream> ParseVideoStream(const Json::Value& value) {
+ auto stream = ParseStream(value, Stream::Type::kVideoSource);
+ if (!stream) {
+ return stream.error();
+ }
+ auto resolutions = ParseResolutions(value, "resolutions");
+ if (!resolutions) {
+ return resolutions.error();
+ }
+
+ auto raw_max_frame_rate = ParseString(value, "maxFrameRate");
+ SimpleFraction max_frame_rate{kDefaultMaxFrameRate, 1};
+ if (raw_max_frame_rate.is_value()) {
+ auto parsed = SimpleFraction::FromString(raw_max_frame_rate.value());
+ if (parsed.is_value() && parsed.value().is_positive()) {
+ max_frame_rate = parsed.value();
+ }
+ }
+
+ auto profile = ParseString(value, "profile");
+ auto protection = ParseString(value, "protection");
+ auto max_bit_rate = ParseInt(value, "maxBitRate");
+ auto level = ParseString(value, "level");
+ auto error_recovery_mode = ParseString(value, "errorRecoveryMode");
+ return VideoStream{stream.value(),
+ max_frame_rate,
+ ValueOrDefault(max_bit_rate, 4 << 20),
+ ValueOrDefault(protection),
+ ValueOrDefault(profile),
+ ValueOrDefault(level),
+ resolutions.value(),
+ ValueOrDefault(error_recovery_mode)};
+}
+
+absl::string_view ToString(Stream::Type type) {
+ switch (type) {
+ case Stream::Type::kAudioSource:
+ return kAudioSourceType;
+ case Stream::Type::kVideoSource:
+ return kVideoSourceType;
+ default: {
+ OSP_NOTREACHED();
+ return "";
+ }
+ }
+}
+
+} // namespace
+
+constexpr char kCastMirroring[] = "mirroring";
+constexpr char kCastRemoting[] = "remoting";
+
+// static
+CastMode CastMode::Parse(absl::string_view value) {
+ return (value == kCastRemoting) ? CastMode{CastMode::Type::kRemoting}
+ : CastMode{CastMode::Type::kMirroring};
+}
+
+ErrorOr<Json::Value> Stream::ToJson() const {
+ if (channels < 1 || index < 0 || codec_name.empty() ||
+ target_delay.count() <= 0 ||
+ target_delay.count() > std::numeric_limits<int>::max() ||
+ rtp_timebase < 1) {
+ return CreateParameterError("Stream");
+ }
+
+ Json::Value root;
+ root["index"] = index;
+ root["type"] = std::string(ToString(type));
+ root["channels"] = channels;
+ root["codecName"] = codec_name;
+ root["rtpPayloadType"] = static_cast<int>(rtp_payload_type);
+ // rtpProfile is technically required by the spec, although it is always set
+ // to cast. We set it here to be compliant with all spec implementers.
+ root["rtpProfile"] = "cast";
+ static_assert(sizeof(ssrc) <= sizeof(Json::UInt),
+ "this code assumes Ssrc fits in a Json::UInt");
+ root["ssrc"] = static_cast<Json::UInt>(ssrc);
+ root["targetDelay"] = static_cast<int>(target_delay.count());
+ root["aesKey"] = HexEncode(aes_key);
+ root["aesIvMask"] = HexEncode(aes_iv_mask);
+ root["ReceiverRtcpEventLog"] = receiver_rtcp_event_log;
+ root["receiverRtcpDscp"] = receiver_rtcp_dscp;
+ root["timeBase"] = "1/" + std::to_string(rtp_timebase);
+ return root;
+}
+
+std::string CastMode::ToString() const {
+ switch (type) {
+ case Type::kMirroring:
+ return kCastMirroring;
+ case Type::kRemoting:
+ return kCastRemoting;
+ default:
+ OSP_NOTREACHED();
+ return "";
+ }
+}
+
+ErrorOr<Json::Value> AudioStream::ToJson() const {
+ // A bit rate of 0 is valid for some codec types, so we don't enforce here.
+ if (bit_rate < 0) {
+ return CreateParameterError("AudioStream");
+ }
+
+ auto error_or_stream = stream.ToJson();
+ if (error_or_stream.is_error()) {
+ return error_or_stream;
+ }
+
+ error_or_stream.value()["bitRate"] = bit_rate;
+ return error_or_stream;
+}
+
+ErrorOr<Json::Value> Resolution::ToJson() const {
+ if (width <= 0 || height <= 0) {
+ return CreateParameterError("Resolution");
+ }
+
+ Json::Value root;
+ root["width"] = width;
+ root["height"] = height;
+ return root;
+}
+
+ErrorOr<Json::Value> VideoStream::ToJson() const {
+ if (max_bit_rate <= 0 || !max_frame_rate.is_positive()) {
+ return CreateParameterError("VideoStream");
+ }
+
+ auto error_or_stream = stream.ToJson();
+ if (error_or_stream.is_error()) {
+ return error_or_stream;
+ }
+
+ auto& stream = error_or_stream.value();
+ stream["maxFrameRate"] = max_frame_rate.ToString();
+ stream["maxBitRate"] = max_bit_rate;
+ stream["protection"] = protection;
+ stream["profile"] = profile;
+ stream["level"] = level;
+ stream["errorRecoveryMode"] = error_recovery_mode;
+
+ Json::Value rs;
+ for (auto resolution : resolutions) {
+ auto eoj = resolution.ToJson();
+ if (eoj.is_error()) {
+ return eoj;
+ }
+ rs.append(eoj.value());
+ }
+ stream["resolutions"] = std::move(rs);
+ return error_or_stream;
+}
+
+// static
+ErrorOr<Offer> Offer::Parse(const Json::Value& root) {
+ CastMode cast_mode = CastMode::Parse(root["castMode"].asString());
+
+ const ErrorOr<bool> get_status = ParseBool(root, "receiverGetStatus");
+
+ Json::Value supported_streams = root[kSupportedStreams];
+ if (!supported_streams.isArray()) {
+ return CreateParseError("supported streams in offer");
+ }
+
+ std::vector<AudioStream> audio_streams;
+ std::vector<VideoStream> video_streams;
+ for (Json::ArrayIndex i = 0; i < supported_streams.size(); ++i) {
+ const Json::Value& fields = supported_streams[i];
+ auto type = ParseString(fields, kStreamType);
+ if (!type) {
+ return type.error();
+ }
+
+ if (type.value() == kAudioSourceType) {
+ auto stream = ParseAudioStream(fields);
+ if (!stream) {
+ return stream.error();
+ }
+ audio_streams.push_back(std::move(stream.value()));
+ } else if (type.value() == kVideoSourceType) {
+ auto stream = ParseVideoStream(fields);
+ if (!stream) {
+ return stream.error();
+ }
+ video_streams.push_back(std::move(stream.value()));
+ }
+ }
+
+ return Offer{cast_mode, ValueOrDefault(get_status), std::move(audio_streams),
+ std::move(video_streams)};
+}
+
+ErrorOr<Json::Value> Offer::ToJson() const {
+ Json::Value root;
+ root["castMode"] = cast_mode.ToString();
+ root["receiverGetStatus"] = supports_wifi_status_reporting;
+
+ Json::Value streams;
+ for (auto& as : audio_streams) {
+ auto eoj = as.ToJson();
+ if (eoj.is_error()) {
+ return eoj;
+ }
+ streams.append(eoj.value());
+ }
+
+ for (auto& vs : video_streams) {
+ auto eoj = vs.ToJson();
+ if (eoj.is_error()) {
+ return eoj;
+ }
+ streams.append(eoj.value());
+ }
+
+ root[kSupportedStreams] = std::move(streams);
+ return root;
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/offer_messages.h b/chromium/third_party/openscreen/src/cast/streaming/offer_messages.h
new file mode 100644
index 00000000000..319145bc3cf
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/offer_messages.h
@@ -0,0 +1,121 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_STREAMING_OFFER_MESSAGES_H_
+#define CAST_STREAMING_OFFER_MESSAGES_H_
+
+#include <chrono> // NOLINT
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "cast/streaming/rtp_defines.h"
+#include "cast/streaming/session_config.h"
+#include "json/value.h"
+#include "platform/base/error.h"
+#include "util/simple_fraction.h"
+
+// This file contains the implementation of the Cast V2 Mirroring Control
+// Protocol offer object definition.
+namespace openscreen {
+namespace cast {
+
+// If the target delay provided by the sender is not bounded by
+// [kMinTargetDelay, kMaxTargetDelay], it will be set to
+// kDefaultTargetPlayoutDelay.
+constexpr auto kMinTargetPlayoutDelay = std::chrono::milliseconds(0);
+constexpr auto kMaxTargetPlayoutDelay = std::chrono::milliseconds(2000);
+
+// If the sender provides an invalid maximum frame rate, it will
+// be set to kDefaultMaxFrameRate.
+constexpr int kDefaultMaxFrameRate = 30;
+
+constexpr int kDefaultNumVideoChannels = 1;
+constexpr int kDefaultNumAudioChannels = 2;
+
+// A stream, as detailed by the CastV2 protocol spec, is a segment of an
+// offer message specifically representing a configuration object for
+// a codec and its related fields, such as maximum bit rate, time base,
+// and other fields.
+// Composed classes include AudioStream and VideoStream, which contain
+// fields specific to audio and video respectively.
+struct Stream {
+ enum class Type : uint8_t { kAudioSource, kVideoSource };
+
+ ErrorOr<Json::Value> ToJson() const;
+
+ int index = 0;
+ Type type = {};
+
+ // Default channel count is 1, e.g. for video.
+ int channels = 0;
+ std::string codec_name = {};
+ RtpPayloadType rtp_payload_type = {};
+ Ssrc ssrc = {};
+ std::chrono::milliseconds target_delay = {};
+
+ // AES Key and IV mask format is very strict: a 32 digit hex string that
+ // must be converted to a 16 digit byte array.
+ std::array<uint8_t, 16> aes_key = {};
+ std::array<uint8_t, 16> aes_iv_mask = {};
+ bool receiver_rtcp_event_log = {};
+ std::string receiver_rtcp_dscp = {};
+ int rtp_timebase = 0;
+};
+
+struct AudioStream {
+ ErrorOr<Json::Value> ToJson() const;
+
+ Stream stream = {};
+ int bit_rate = 0;
+};
+
+struct Resolution {
+ ErrorOr<Json::Value> ToJson() const;
+
+ int width = 0;
+ int height = 0;
+};
+
+struct VideoStream {
+ ErrorOr<Json::Value> ToJson() const;
+
+ Stream stream = {};
+ SimpleFraction max_frame_rate;
+ int max_bit_rate = 0;
+ std::string protection = {};
+ std::string profile = {};
+ std::string level = {};
+ std::vector<Resolution> resolutions = {};
+ std::string error_recovery_mode = {};
+};
+
+struct CastMode {
+ public:
+ enum class Type : uint8_t { kMirroring, kRemoting };
+
+ static CastMode Parse(absl::string_view value);
+ std::string ToString() const;
+
+ // Default cast mode is mirroring.
+ Type type = Type::kMirroring;
+};
+
+struct Offer {
+ static ErrorOr<Offer> Parse(const Json::Value& root);
+ ErrorOr<Json::Value> ToJson() const;
+
+ CastMode cast_mode = {};
+ // This field is poorly named in the spec (receiverGetStatus), so we use
+ // a more descriptive name here.
+ bool supports_wifi_status_reporting = {};
+ std::vector<AudioStream> audio_streams = {};
+ std::vector<VideoStream> video_streams = {};
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_STREAMING_OFFER_MESSAGES_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/offer_messages_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/offer_messages_unittest.cc
new file mode 100644
index 00000000000..4931927a5ae
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/offer_messages_unittest.cc
@@ -0,0 +1,429 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/streaming/offer_messages.h"
+
+#include <utility>
+
+#include "cast/streaming/rtp_defines.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "util/json/json_serialization.h"
+
+using ::testing::ElementsAre;
+
+namespace openscreen {
+namespace cast {
+
+namespace {
+
+constexpr char kValidOffer[] = R"({
+ "castMode": "mirroring",
+ "receiverGetStatus": true,
+ "supportedStreams": [
+ {
+ "index": 0,
+ "type": "video_source",
+ "codecName": "h264",
+ "rtpProfile": "cast",
+ "rtpPayloadType": 101,
+ "ssrc": 19088743,
+ "maxFrameRate": "60000/1000",
+ "timeBase": "1/90000",
+ "maxBitRate": 5000000,
+ "profile": "main",
+ "level": "4",
+ "targetDelay": 200,
+ "aesKey": "040d756791711fd3adb939066e6d8690",
+ "aesIvMask": "9ff0f022a959150e70a2d05a6c184aed",
+ "resolutions": [
+ {
+ "width": 1280,
+ "height": 720
+ },
+ {
+ "width": 640,
+ "height": 360
+ },
+ {
+ "width": 640,
+ "height": 480
+ }
+ ]
+ },
+ {
+ "index": 1,
+ "type": "video_source",
+ "codecName": "vp8",
+ "rtpProfile": "cast",
+ "rtpPayloadType": 100,
+ "ssrc": 19088744,
+ "maxFrameRate": "30000/1001",
+ "targetDelay": 1000,
+ "timeBase": "1/90000",
+ "maxBitRate": 5000000,
+ "profile": "main",
+ "level": "5",
+ "aesKey": "bbf109bf84513b456b13a184453b66ce",
+ "aesIvMask": "edaf9e4536e2b66191f560d9c04b2a69"
+ },
+ {
+ "index": 2,
+ "type": "audio_source",
+ "codecName": "opus",
+ "targetDelay": 300,
+ "rtpProfile": "cast",
+ "rtpPayloadType": 96,
+ "ssrc": 4294967295,
+ "bitRate": 124000,
+ "timeBase": "1/48000",
+ "channels": 2,
+ "aesKey": "51027e4e2347cbcb49d57ef10177aebc",
+ "aesIvMask": "7f12a19be62a36c04ae4116caaeff6d1"
+ }
+ ]
+})";
+
+void ExpectFailureOnParse(absl::string_view body) {
+ ErrorOr<Json::Value> root = json::Parse(body);
+ ASSERT_TRUE(root.is_value());
+ EXPECT_TRUE(Offer::Parse(std::move(root.value())).is_error());
+}
+
+void ExpectEqualsValidOffer(const Offer& offer) {
+ EXPECT_EQ(CastMode::Type::kMirroring, offer.cast_mode.type);
+ EXPECT_EQ(true, offer.supports_wifi_status_reporting);
+
+ // Verify list of video streams.
+ EXPECT_EQ(2u, offer.video_streams.size());
+ const auto& video_streams = offer.video_streams;
+
+ const bool flipped = video_streams[0].stream.index != 0;
+ const VideoStream& vs_one = flipped ? video_streams[1] : video_streams[0];
+ const VideoStream& vs_two = flipped ? video_streams[0] : video_streams[1];
+
+ EXPECT_EQ(0, vs_one.stream.index);
+ EXPECT_EQ(1, vs_one.stream.channels);
+ EXPECT_EQ(Stream::Type::kVideoSource, vs_one.stream.type);
+ EXPECT_EQ("h264", vs_one.stream.codec_name);
+ EXPECT_EQ(RtpPayloadType::kVideoH264, vs_one.stream.rtp_payload_type);
+ EXPECT_EQ(19088743u, vs_one.stream.ssrc);
+ EXPECT_EQ((SimpleFraction{60000, 1000}), vs_one.max_frame_rate);
+ EXPECT_EQ(90000, vs_one.stream.rtp_timebase);
+ EXPECT_EQ(5000000, vs_one.max_bit_rate);
+ EXPECT_EQ("main", vs_one.profile);
+ EXPECT_EQ("4", vs_one.level);
+ EXPECT_THAT(vs_one.stream.aes_key,
+ ElementsAre(0x04, 0x0d, 0x75, 0x67, 0x91, 0x71, 0x1f, 0xd3, 0xad,
+ 0xb9, 0x39, 0x06, 0x6e, 0x6d, 0x86, 0x90));
+ EXPECT_THAT(vs_one.stream.aes_iv_mask,
+ ElementsAre(0x9f, 0xf0, 0xf0, 0x22, 0xa9, 0x59, 0x15, 0x0e, 0x70,
+ 0xa2, 0xd0, 0x5a, 0x6c, 0x18, 0x4a, 0xed));
+
+ const auto& resolutions = vs_one.resolutions;
+ EXPECT_EQ(3u, resolutions.size());
+ const Resolution& r_one = resolutions[0];
+ EXPECT_EQ(1280, r_one.width);
+ EXPECT_EQ(720, r_one.height);
+
+ const Resolution& r_two = resolutions[1];
+ EXPECT_EQ(640, r_two.width);
+ EXPECT_EQ(360, r_two.height);
+
+ const Resolution& r_three = resolutions[2];
+ EXPECT_EQ(640, r_three.width);
+ EXPECT_EQ(480, r_three.height);
+
+ EXPECT_EQ(1, vs_two.stream.index);
+ EXPECT_EQ(1, vs_two.stream.channels);
+ EXPECT_EQ(Stream::Type::kVideoSource, vs_two.stream.type);
+ EXPECT_EQ("vp8", vs_two.stream.codec_name);
+ EXPECT_EQ(RtpPayloadType::kVideoVp8, vs_two.stream.rtp_payload_type);
+ EXPECT_EQ(19088744u, vs_two.stream.ssrc);
+ EXPECT_EQ((SimpleFraction{30000, 1001}), vs_two.max_frame_rate);
+ EXPECT_EQ(90000, vs_two.stream.rtp_timebase);
+ EXPECT_EQ(5000000, vs_two.max_bit_rate);
+ EXPECT_EQ("main", vs_two.profile);
+ EXPECT_EQ("5", vs_two.level);
+ EXPECT_THAT(vs_two.stream.aes_key,
+ ElementsAre(0xbb, 0xf1, 0x09, 0xbf, 0x84, 0x51, 0x3b, 0x45, 0x6b,
+ 0x13, 0xa1, 0x84, 0x45, 0x3b, 0x66, 0xce));
+ EXPECT_THAT(vs_two.stream.aes_iv_mask,
+ ElementsAre(0xed, 0xaf, 0x9e, 0x45, 0x36, 0xe2, 0xb6, 0x61, 0x91,
+ 0xf5, 0x60, 0xd9, 0xc0, 0x4b, 0x2a, 0x69));
+
+ const auto& resolutions_two = vs_two.resolutions;
+ EXPECT_EQ(0u, resolutions_two.size());
+
+ // Verify list of audio streams.
+ EXPECT_EQ(1u, offer.audio_streams.size());
+ const AudioStream& as = offer.audio_streams[0];
+ EXPECT_EQ(2, as.stream.index);
+ EXPECT_EQ(Stream::Type::kAudioSource, as.stream.type);
+ EXPECT_EQ("opus", as.stream.codec_name);
+ EXPECT_EQ(RtpPayloadType::kAudioOpus, as.stream.rtp_payload_type);
+ EXPECT_EQ(std::numeric_limits<Ssrc>::max(), as.stream.ssrc);
+ EXPECT_EQ(124000, as.bit_rate);
+ EXPECT_EQ(2, as.stream.channels);
+
+ EXPECT_THAT(as.stream.aes_key,
+ ElementsAre(0x51, 0x02, 0x7e, 0x4e, 0x23, 0x47, 0xcb, 0xcb, 0x49,
+ 0xd5, 0x7e, 0xf1, 0x01, 0x77, 0xae, 0xbc));
+ EXPECT_THAT(as.stream.aes_iv_mask,
+ ElementsAre(0x7f, 0x12, 0xa1, 0x9b, 0xe6, 0x2a, 0x36, 0xc0, 0x4a,
+ 0xe4, 0x11, 0x6c, 0xaa, 0xef, 0xf6, 0xd1));
+}
+
+} // namespace
+
+TEST(OfferTest, ErrorOnEmptyOffer) {
+ ExpectFailureOnParse("{}");
+}
+
+TEST(OfferTest, ErrorOnMissingMandatoryFields) {
+ // It's okay if castMode is omitted, but if supportedStreams isanne //
+ // omitted we should fail here.
+ ExpectFailureOnParse(R"({
+ "castMode": "mirroring"
+ })");
+}
+
+TEST(OfferTest, CanParseValidButStreamlessOffer) {
+ ErrorOr<Json::Value> root = json::Parse(R"({
+ "castMode": "mirroring",
+ "supportedStreams": []
+ })");
+ ASSERT_TRUE(root.is_value());
+ EXPECT_TRUE(Offer::Parse(std::move(root.value())).is_value());
+}
+
+TEST(OfferTest, ErrorOnMissingAudioStreamMandatoryField) {
+ ExpectFailureOnParse(R"({
+ "castMode": "mirroring",
+ "supportedStreams": [{
+ "index": 2,
+ "codecName": "opus",
+ "rtpProfile": "cast",
+ "rtpPayloadType": 96,
+ "ssrc": 19088743,
+ "bitRate": 124000,
+ "timeBase": "1/48000",
+ "channels": 2
+ }]})");
+
+ ExpectFailureOnParse(R"({
+ "castMode": "mirroring",
+ "supportedStreams": [{
+ "index": 2,
+ "type": "audio_source",
+ "codecName": "opus",
+ "rtpProfile": "cast",
+ "rtpPayloadType": 96,
+ "bitRate": 124000,
+ "timeBase": "1/48000",
+ "channels": 2
+ }]})");
+}
+
+TEST(OfferTest, CanParseValidButMinimalAudioOffer) {
+ ErrorOr<Json::Value> root = json::Parse(R"({
+ "castMode": "mirroring",
+ "supportedStreams": [{
+ "index": 2,
+ "type": "audio_source",
+ "codecName": "opus",
+ "rtpProfile": "cast",
+ "rtpPayloadType": 96,
+ "ssrc": 19088743,
+ "bitRate": 124000,
+ "timeBase": "1/48000",
+ "channels": 2,
+ "aesKey": "51027e4e2347cbcb49d57ef10177aebc",
+ "aesIvMask": "7f12a19be62a36c04ae4116caaeff6d1"
+ }]
+ })");
+ ASSERT_TRUE(root.is_value());
+ EXPECT_TRUE(Offer::Parse(std::move(root.value())).is_value());
+}
+
+TEST(OfferTest, CanParseValidZeroBitRateAudioOffer) {
+ ErrorOr<Json::Value> root = json::Parse(R"({
+ "castMode": "mirroring",
+ "supportedStreams": [{
+ "index": 2,
+ "type": "audio_source",
+ "codecName": "opus",
+ "rtpProfile": "cast",
+ "rtpPayloadType": 96,
+ "ssrc": 19088743,
+ "bitRate": 0,
+ "timeBase": "1/96000",
+ "channels": 5,
+ "aesKey": "51029e4e2347cbcb49d57ef10177aebd",
+ "aesIvMask": "7f12a19be62a36c04ae4116caaeff5d2"
+ }]
+ })");
+ ASSERT_TRUE(root.is_value());
+ EXPECT_TRUE(Offer::Parse(std::move(root.value())).is_value());
+}
+
+TEST(OfferTest, ErrorOnMissingVideoStreamMandatoryField) {
+ ExpectFailureOnParse(R"({
+ "castMode": "mirroring",
+ "supportedStreams": [{
+ "index": 2,
+ "codecName": "video_source",
+ "rtpProfile": "h264",
+ "rtpPayloadType": 101,
+ "ssrc": 19088743,
+ "bitRate": 124000,
+ "timeBase": "1/48000"
+ }]
+ })");
+
+ ExpectFailureOnParse(R"({
+ "castMode": "mirroring",
+ "supportedStreams": [{
+ "index": 2,
+ "type": "video_source",
+ "codecName": "h264",
+ "rtpProfile": "cast",
+ "rtpPayloadType": 101,
+ "bitRate": 124000,
+ "timeBase": "1/48000",
+ "maxBitRate": 10000
+ }]
+ })");
+
+ ExpectFailureOnParse(R"({
+ "castMode": "mirroring",
+ "supportedStreams": [{
+ "index": 2,
+ "type": "video_source",
+ "codecName": "vp8",
+ "rtpProfile": "cast",
+ "rtpPayloadType": 100,
+ "ssrc": 19088743,
+ "timeBase": "1/48000",
+ "resolutions": [],
+ "maxBitRate": 10000
+ }]
+ })");
+
+ ExpectFailureOnParse(R"({
+ "castMode": "mirroring",
+ "supportedStreams": [{
+ "index": 2,
+ "type": "video_source",
+ "codecName": "vp8",
+ "rtpProfile": "cast",
+ "rtpPayloadType": 100,
+ "ssrc": 19088743,
+ "timeBase": "1/48000",
+ "resolutions": [],
+ "maxBitRate": 10000,
+ "aesKey": "51027e4e2347cbcb49d57ef10177aebc"
+ }]
+ })");
+
+ ExpectFailureOnParse(R"({
+ "castMode": "mirroring",
+ "supportedStreams": [{
+ "index": 2,
+ "type": "video_source",
+ "codecName": "vp8",
+ "rtpProfile": "cast",
+ "rtpPayloadType": 100,
+ "ssrc": 19088743,
+ "timeBase": "1/48000",
+ "resolutions": [],
+ "maxBitRate": 10000,
+ "aesIvMask": "7f12a19be62a36c04ae4116caaeff6d1"
+ }]
+ })");
+}
+
+TEST(OfferTest, CanParseValidButMinimalVideoOffer) {
+ ErrorOr<Json::Value> root = json::Parse(R"({
+ "castMode": "mirroring",
+ "supportedStreams": [{
+ "index": 2,
+ "type": "video_source",
+ "codecName": "vp8",
+ "rtpProfile": "cast",
+ "rtpPayloadType": 100,
+ "ssrc": 19088743,
+ "timeBase": "1/48000",
+ "resolutions": [],
+ "maxBitRate": 10000,
+ "aesKey": "51027e4e2347cbcb49d57ef10177aebc",
+ "aesIvMask": "7f12a19be62a36c04ae4116caaeff6d1"
+ }]
+ })");
+
+ ASSERT_TRUE(root.is_value());
+ EXPECT_TRUE(Offer::Parse(std::move(root.value())).is_value());
+}
+
+TEST(OfferTest, CanParseValidOffer) {
+ ErrorOr<Json::Value> root = json::Parse(kValidOffer);
+ ASSERT_TRUE(root.is_value());
+ ErrorOr<Offer> offer = Offer::Parse(std::move(root.value()));
+
+ ExpectEqualsValidOffer(offer.value());
+}
+
+TEST(OfferTest, ParseAndToJsonResultsInSameOffer) {
+ ErrorOr<Json::Value> root = json::Parse(kValidOffer);
+ ASSERT_TRUE(root.is_value());
+ ErrorOr<Offer> offer = Offer::Parse(std::move(root.value()));
+
+ ExpectEqualsValidOffer(offer.value());
+
+ auto eoj = offer.value().ToJson();
+ EXPECT_TRUE(eoj.is_value());
+ ErrorOr<Offer> reparsed_offer = Offer::Parse(std::move(eoj.value()));
+ ExpectEqualsValidOffer(reparsed_offer.value());
+}
+
+// We don't want to enforce that a given offer must have both audio and
+// video, so we don't assert on either.
+TEST(OfferTest, ToJsonSucceedsWithMissingStreams) {
+ ErrorOr<Json::Value> root = json::Parse(kValidOffer);
+ ASSERT_TRUE(root.is_value());
+ ErrorOr<Offer> offer = Offer::Parse(std::move(root.value()));
+ ExpectEqualsValidOffer(offer.value());
+ const Offer valid_offer = std::move(offer.value());
+
+ Offer missing_audio_streams = valid_offer;
+ missing_audio_streams.audio_streams.clear();
+ EXPECT_TRUE(missing_audio_streams.ToJson().is_value());
+
+ Offer missing_video_streams = valid_offer;
+ missing_video_streams.audio_streams.clear();
+ EXPECT_TRUE(missing_video_streams.ToJson().is_value());
+}
+
+TEST(OfferTest, ToJsonFailsWithInvalidStreams) {
+ ErrorOr<Json::Value> root = json::Parse(kValidOffer);
+ ASSERT_TRUE(root.is_value());
+ ErrorOr<Offer> offer = Offer::Parse(std::move(root.value()));
+ ExpectEqualsValidOffer(offer.value());
+ const Offer valid_offer = std::move(offer.value());
+
+ Offer video_stream_invalid = valid_offer;
+ video_stream_invalid.video_streams[0].max_frame_rate.denominator = 0;
+ EXPECT_TRUE(video_stream_invalid.ToJson().is_error());
+
+ Offer audio_stream_invalid = valid_offer;
+ video_stream_invalid.audio_streams[0].bit_rate = 0;
+ EXPECT_TRUE(video_stream_invalid.ToJson().is_error());
+
+ Offer stream_invalid = valid_offer;
+ stream_invalid.video_streams[0].stream.codec_name = "";
+ EXPECT_TRUE(stream_invalid.ToJson().is_error());
+}
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.cc b/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.cc
index 0d7f268e724..23a5e715841 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.cc
@@ -6,10 +6,8 @@
#include <algorithm>
-using openscreen::platform::Clock;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
PacketReceiveStatsTracker::PacketReceiveStatsTracker(int rtp_timebase)
: rtp_timebase_(rtp_timebase) {}
@@ -77,5 +75,5 @@ void PacketReceiveStatsTracker::PopulateNextReport(RtcpReportBlock* report) {
report->jitter = RtpTimeDelta::FromDuration(jitter_, rtp_timebase_);
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.h b/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.h
index e579f76feeb..8d8e23c9fe2 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker.h
@@ -12,8 +12,8 @@
#include "cast/streaming/rtp_time.h"
#include "platform/api/time.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Maintains statistics for RTP packet arrival timing, jitter, and loss rates;
// and then uses these to compute and set the related fields in a RTCP Receiver
@@ -29,10 +29,9 @@ class PacketReceiveStatsTracker {
// RtpPacketParser::ParseResult. |arrival_time| is when the packet was
// received (i.e., right-off the network socket, before any
// processing/parsing).
- void OnReceivedValidRtpPacket(
- uint16_t sequence_number,
- RtpTimeTicks rtp_timestamp,
- openscreen::platform::Clock::time_point arrival_time);
+ void OnReceivedValidRtpPacket(uint16_t sequence_number,
+ RtpTimeTicks rtp_timestamp,
+ Clock::time_point arrival_time);
// Populates *only* those fields in the given |report| that pertain to packet
// loss, jitter, and the latest-known RTP packet sequence number.
@@ -89,7 +88,7 @@ class PacketReceiveStatsTracker {
// The time the last RTP packet was received. This is used in the computation
// that updates |jitter_|.
- openscreen::platform::Clock::time_point last_rtp_packet_arrival_time_;
+ Clock::time_point last_rtp_packet_arrival_time_;
// The RTP timestamp of the last RTP packet received. This is used in the
// computation that updates |jitter_|.
@@ -98,10 +97,10 @@ class PacketReceiveStatsTracker {
// The interarrival jitter. See RFC 3550 spec, section 6.4.1. The Cast
// Streaming spec diverges from the algorithm in the RFC spec in that it uses
// different pieces of timing data to calculate this metric.
- openscreen::platform::Clock::duration jitter_;
+ Clock::duration jitter_;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_PACKET_RECEIVE_STATS_TRACKER_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker_unittest.cc
index 1532741f953..5146cc8d302 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker_unittest.cc
@@ -9,10 +9,8 @@
#include "cast/streaming/constants.h"
#include "gtest/gtest.h"
-using openscreen::platform::Clock;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
constexpr int kSomeRtpTimebase = static_cast<int>(kVideoTimebase::den);
@@ -206,5 +204,5 @@ TEST(PacketReceiveStatsTrackerTest, ComputesJitterCorrectly) {
}
} // namespace
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/packet_util.cc b/chromium/third_party/openscreen/src/cast/streaming/packet_util.cc
index cf43ac14142..7e789808e0f 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/packet_util.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/packet_util.cc
@@ -7,10 +7,8 @@
#include "cast/streaming/rtcp_common.h"
#include "cast/streaming/rtp_defines.h"
-using openscreen::ReadBigEndian;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
std::pair<ApparentPacketType, Ssrc> InspectPacketForRouting(
absl::Span<const uint8_t> packet) {
@@ -40,5 +38,5 @@ std::pair<ApparentPacketType, Ssrc> InspectPacketForRouting(
return std::make_pair(ApparentPacketType::UNKNOWN, Ssrc{0});
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/packet_util.h b/chromium/third_party/openscreen/src/cast/streaming/packet_util.h
index 869b917eebf..ba84ff719e0 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/packet_util.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/packet_util.h
@@ -11,14 +11,14 @@
#include "cast/streaming/ssrc.h"
#include "util/big_endian.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Reads a field from the start of the given span and advances the span to point
// just after the field.
template <typename Integer>
inline Integer ConsumeField(absl::Span<const uint8_t>* in) {
- const Integer result = openscreen::ReadBigEndian<Integer>(in->data());
+ const Integer result = ReadBigEndian<Integer>(in->data());
in->remove_prefix(sizeof(Integer));
return result;
}
@@ -27,7 +27,7 @@ inline Integer ConsumeField(absl::Span<const uint8_t>* in) {
// just after the field.
template <typename Integer>
inline void AppendField(Integer value, absl::Span<uint8_t>* out) {
- openscreen::WriteBigEndian<Integer>(value, out->data());
+ WriteBigEndian<Integer>(value, out->data());
out->remove_prefix(sizeof(Integer));
}
@@ -56,7 +56,7 @@ enum class ApparentPacketType { UNKNOWN, RTP, RTCP };
std::pair<ApparentPacketType, Ssrc> InspectPacketForRouting(
absl::Span<const uint8_t> packet);
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_PACKET_UTIL_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/packet_util_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/packet_util_unittest.cc
index 1f1f095d76c..82e2a40871b 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/packet_util_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/packet_util_unittest.cc
@@ -7,8 +7,8 @@
#include "absl/types/span.h"
#include "gtest/gtest.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
// Tests that a simple RTCP packet containing only a Sender Report can be
@@ -181,5 +181,5 @@ TEST(PacketUtilTest, InspectsGarbagePacket) {
}
} // namespace
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver.cc b/chromium/third_party/openscreen/src/cast/streaming/receiver.cc
index 73a37eb24e3..b24b61eec3d 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/receiver.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/receiver.cc
@@ -5,22 +5,20 @@
#include "cast/streaming/receiver.h"
#include <algorithm>
-#include <functional>
#include "absl/types/span.h"
#include "cast/streaming/constants.h"
#include "cast/streaming/receiver_packet_router.h"
+#include "cast/streaming/session_config.h"
#include "util/logging.h"
#include "util/std_util.h"
-using openscreen::platform::Clock;
-
using std::chrono::duration_cast;
using std::chrono::microseconds;
using std::chrono::milliseconds;
+namespace openscreen {
namespace cast {
-namespace streaming {
// Conveniences for ensuring logging output includes the SSRC of the Receiver,
// to help distinguish one out of multiple instances in a Cast Streaming
@@ -33,8 +31,7 @@ namespace streaming {
Receiver::Receiver(Environment* environment,
ReceiverPacketRouter* packet_router,
- const cast::streaming::SessionConfig& config,
- std::chrono::milliseconds initial_target_playout_delay)
+ const SessionConfig& config)
: now_(environment->now_function()),
packet_router_(packet_router),
rtcp_session_(config.sender_ssrc, config.receiver_ssrc, now_()),
@@ -51,13 +48,13 @@ Receiver::Receiver(Environment* environment,
consumption_alarm_(environment->now_function(),
environment->task_runner()) {
OSP_DCHECK(packet_router_);
- OSP_DCHECK_EQ(checkpoint_frame(), FrameId::first() - 1);
+ OSP_DCHECK_EQ(checkpoint_frame(), FrameId::leader());
OSP_CHECK_GT(rtcp_buffer_capacity_, 0);
OSP_CHECK(rtcp_buffer_);
- rtcp_builder_.SetPlayoutDelay(initial_target_playout_delay);
- playout_delay_changes_.emplace_back(FrameId::first() - 1,
- initial_target_playout_delay);
+ rtcp_builder_.SetPlayoutDelay(config.target_playout_delay);
+ playout_delay_changes_.emplace_back(FrameId::leader(),
+ config.target_playout_delay);
packet_router_->OnReceiverCreated(rtcp_session_.sender_ssrc(), this);
}
@@ -355,8 +352,7 @@ void Receiver::SendRtcp() {
// When there are no incomplete frames, use a longer "keepalive" interval.
const Clock::duration interval =
(no_nacks ? kRtcpReportInterval : kNackFeedbackInterval);
- rtcp_alarm_.Schedule(std::bind(&Receiver::SendRtcp, this),
- last_rtcp_send_time_ + interval);
+ rtcp_alarm_.Schedule([this] { SendRtcp(); }, last_rtcp_send_time_ + interval);
}
const Receiver::PendingFrame& Receiver::GetQueueEntry(FrameId frame_id) const {
@@ -389,7 +385,7 @@ void Receiver::RecordNewTargetPlayoutDelay(FrameId as_of_frame,
[&](const auto& entry) { return entry.first > as_of_frame; });
playout_delay_changes_.emplace(insert_it, as_of_frame, delay);
- OSP_DCHECK(openscreen::AreElementsSortedAndUnique(playout_delay_changes_));
+ OSP_DCHECK(AreElementsSortedAndUnique(playout_delay_changes_));
}
milliseconds Receiver::ResolveTargetPlayoutDelay(FrameId frame_id) const {
@@ -481,5 +477,5 @@ constexpr milliseconds Receiver::kDefaultPlayerProcessingTime;
constexpr int Receiver::kNoFramesReady;
constexpr milliseconds Receiver::kNackFeedbackInterval;
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver.h b/chromium/third_party/openscreen/src/cast/streaming/receiver.h
index d1ea8b0c08a..e63a9e4ee1f 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/receiver.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/receiver.h
@@ -24,16 +24,16 @@
#include "cast/streaming/rtcp_session.h"
#include "cast/streaming/rtp_packet_parser.h"
#include "cast/streaming/sender_report_parser.h"
-#include "cast/streaming/session_config.h"
#include "cast/streaming/ssrc.h"
#include "platform/api/time.h"
#include "util/alarm.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
struct EncodedFrame;
class ReceiverPacketRouter;
+struct SessionConfig;
// The Cast Streaming Receiver, a peer corresponding to some Cast Streaming
// Sender at the other end of a network link.
@@ -56,7 +56,7 @@ class ReceiverPacketRouter;
// decoder, and the resulting decoded media is played out. Also, here is a
// general usage example:
//
-// class MyPlayer : public cast::streaming::Receiver::Consumer {
+// class MyPlayer : public openscreen::cast::Receiver::Consumer {
// public:
// explicit MyPlayer(Receiver* receiver) : receiver_(receiver) {
// recevier_->SetPlayerProcessingTime(std::chrono::milliseconds(10));
@@ -72,7 +72,7 @@ class ReceiverPacketRouter;
// void OnFramesReady(int next_frame_buffer_size) override {
// std::vector<uint8_t> buffer;
// buffer.resize(next_frame_buffer_size);
-// cast::streaming::EncodedFrame encoded_frame =
+// openscreen::cast::EncodedFrame encoded_frame =
// receiver_->ConsumeNextFrame(absl::Span<uint8_t>(buffer));
//
// display_.RenderFrame(decoder_.DecodeFrame(encoded_frame.data));
@@ -124,8 +124,7 @@ class Receiver {
// is started).
Receiver(Environment* environment,
ReceiverPacketRouter* packet_router,
- const cast::streaming::SessionConfig& config,
- std::chrono::milliseconds initial_target_playout_delay);
+ const SessionConfig& config);
~Receiver();
Ssrc ssrc() const { return rtcp_session_.receiver_ssrc(); }
@@ -143,8 +142,7 @@ class Receiver {
// based on changing environmental conditions.
//
// Default setting: kDefaultPlayerProcessingTime
- void SetPlayerProcessingTime(
- openscreen::platform::Clock::duration needed_time);
+ void SetPlayerProcessingTime(Clock::duration needed_time);
// Propagates a "picture loss indicator" notification to the Sender,
// requesting a key frame so that decode/playout can recover. It is safe to
@@ -185,11 +183,10 @@ class Receiver {
// Called by ReceiverPacketRouter to provide this Receiver with what looks
// like a RTP/RTCP packet meant for it specifically (among other Receivers).
- void OnReceivedRtpPacket(openscreen::platform::Clock::time_point arrival_time,
+ void OnReceivedRtpPacket(Clock::time_point arrival_time,
std::vector<uint8_t> packet);
- void OnReceivedRtcpPacket(
- openscreen::platform::Clock::time_point arrival_time,
- std::vector<uint8_t> packet);
+ void OnReceivedRtcpPacket(Clock::time_point arrival_time,
+ std::vector<uint8_t> packet);
private:
// An entry in the circular queue (see |pending_frames_|).
@@ -200,8 +197,7 @@ class Receiver {
// at the Sender. This is computed and assigned when the RTP packet with ID
// 0 is processed. Add the target playout delay to this to get the target
// playout time.
- absl::optional<openscreen::platform::Clock::time_point>
- estimated_capture_time;
+ absl::optional<Clock::time_point> estimated_capture_time;
PendingFrame();
~PendingFrame();
@@ -256,10 +252,9 @@ class Receiver {
// Sets the |consumption_alarm_| to check whether any frames are ready,
// including possibly skipping over late frames in order to make not-yet-late
// frames become ready. The default argument value means "without delay."
- void ScheduleFrameReadyCheck(
- openscreen::platform::Clock::time_point when = {});
+ void ScheduleFrameReadyCheck(Clock::time_point when = Alarm::kImmediately);
- const openscreen::platform::ClockNowFunctionPtr now_;
+ const ClockNowFunctionPtr now_;
ReceiverPacketRouter* const packet_router_;
RtcpSession rtcp_session_;
SenderReportParser rtcp_parser_;
@@ -276,9 +271,8 @@ class Receiver {
// Schedules tasks to ensure RTCP reports are sent within a bounded interval.
// Not scheduled until after this Receiver has processed the first packet from
// the Sender.
- openscreen::Alarm rtcp_alarm_;
- openscreen::platform::Clock::time_point last_rtcp_send_time_ =
- openscreen::platform::Clock::time_point::min();
+ Alarm rtcp_alarm_;
+ Clock::time_point last_rtcp_send_time_ = Clock::time_point::min();
// The last Sender Report received and when the packet containing it had
// arrived. This contains lip-sync timestamps used as part of the calculation
@@ -286,7 +280,7 @@ class Receiver {
// back to the Sender in the Receiver Reports. It is nullopt until the first
// parseable Sender Report is received.
absl::optional<SenderReportParser::SenderReportWithId> last_sender_report_;
- openscreen::platform::Clock::time_point last_sender_report_arrival_time_;
+ Clock::time_point last_sender_report_arrival_time_;
// Tracks the offset between the Receiver's [local] clock and the Sender's
// clock. This is invalid until the first Sender Report has been successfully
@@ -295,12 +289,12 @@ class Receiver {
// The ID of the latest frame whose existence is known to this Receiver. This
// value must always be greater than or equal to |checkpoint_frame()|.
- FrameId latest_frame_expected_ = FrameId::first() - 1;
+ FrameId latest_frame_expected_ = FrameId::leader();
// The ID of the last frame consumed. This value must always be less than or
// equal to |checkpoint_frame()|, since it's impossible to consume incomplete
// frames!
- FrameId last_frame_consumed_ = FrameId::first() - 1;
+ FrameId last_frame_consumed_ = FrameId::leader();
// The ID of the latest key frame known to be in-flight. This is used by
// RequestKeyFrame() to ensure the PLI condition doesn't get set again until
@@ -333,19 +327,21 @@ class Receiver {
// The additional time needed to decode/play-out each frame after being
// consumed from this Receiver.
- openscreen::platform::Clock::duration player_processing_time_ =
- kDefaultPlayerProcessingTime;
+ Clock::duration player_processing_time_ = kDefaultPlayerProcessingTime;
// Scheduled to check whether there are frames ready and, if there are, to
// notify the Consumer via OnFramesReady().
- openscreen::Alarm consumption_alarm_;
+ Alarm consumption_alarm_;
// The interval between sending ACK/NACK feedback RTCP messages while
// incomplete frames exist in the queue.
+ //
+ // TODO(miu): This should be a function of the current target playout delay,
+ // similar to the Sender's kickstart interval logic.
static constexpr std::chrono::milliseconds kNackFeedbackInterval{30};
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_RECEIVER_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.cc b/chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.cc
index 5e3d17bd229..7a5c41aacc3 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.cc
@@ -5,17 +5,14 @@
#include "cast/streaming/receiver_packet_router.h"
#include <algorithm>
-#include <iomanip>
#include "cast/streaming/packet_util.h"
#include "cast/streaming/receiver.h"
#include "util/logging.h"
+#include "util/stringprintf.h"
-using openscreen::IPEndpoint;
-using openscreen::platform::Clock;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
ReceiverPacketRouter::ReceiverPacketRouter(Environment* environment)
: environment_(environment) {
@@ -78,16 +75,11 @@ void ReceiverPacketRouter::OnReceivedPacket(const IPEndpoint& source,
const std::pair<ApparentPacketType, Ssrc> seems_like =
InspectPacketForRouting(packet);
if (seems_like.first == ApparentPacketType::UNKNOWN) {
- // If the packet type is unknown, log a warning containing a hex dump.
- constexpr int kMaxDumpSize = 96;
- std::ostringstream hex_dump;
- hex_dump << std::setfill('0') << std::hex;
- for (int i = 0, len = std::min<int>(packet.size(), kMaxDumpSize); i < len;
- ++i) {
- hex_dump << std::setw(2) << static_cast<int>(packet[i]);
- }
+ constexpr int kMaxPartiaHexDumpSize = 96;
OSP_LOG_WARN << "UNKNOWN packet of " << packet.size()
- << " bytes. Partial hex dump: " << hex_dump.str();
+ << " bytes. Partial hex dump: "
+ << HexEncode(absl::Span<const uint8_t>(packet).subspan(
+ 0, kMaxPartiaHexDumpSize));
return;
}
const auto it = FindEntry(seems_like.second);
@@ -115,5 +107,5 @@ ReceiverPacketRouter::ReceiverEntries::iterator ReceiverPacketRouter::FindEntry(
});
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.h b/chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.h
index 88db08a817a..d90e533596b 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/receiver_packet_router.h
@@ -14,8 +14,8 @@
#include "cast/streaming/environment.h"
#include "cast/streaming/ssrc.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
class Receiver;
@@ -47,8 +47,8 @@ class ReceiverPacketRouter final : public Environment::PacketConsumer {
using ReceiverEntries = std::vector<std::pair<Ssrc, Receiver*>>;
// Environment::PacketConsumer implementation.
- void OnReceivedPacket(const openscreen::IPEndpoint& source,
- openscreen::platform::Clock::time_point arrival_time,
+ void OnReceivedPacket(const IPEndpoint& source,
+ Clock::time_point arrival_time,
std::vector<uint8_t> packet) final;
// Helper to return an iterator pointing to the entry corresponding to the
@@ -60,7 +60,7 @@ class ReceiverPacketRouter final : public Environment::PacketConsumer {
ReceiverEntries receivers_;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_RECEIVER_PACKET_ROUTER_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver_session.cc b/chromium/third_party/openscreen/src/cast/streaming/receiver_session.cc
index 4ad3a3b675c..325a44509a0 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/receiver_session.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/receiver_session.cc
@@ -4,43 +4,299 @@
#include "cast/streaming/receiver_session.h"
+#include <chrono> // NOLINT
+#include <string>
#include <utility>
+#include "absl/strings/match.h"
+#include "absl/strings/numbers.h"
+#include "cast/streaming/environment.h"
+#include "cast/streaming/message_port.h"
+#include "cast/streaming/message_util.h"
+#include "cast/streaming/offer_messages.h"
+#include "cast/streaming/receiver.h"
#include "util/logging.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
-
-ReceiverSession::ConfiguredReceivers::ConfiguredReceivers(
- std::unique_ptr<Receiver> audio_receiver,
- absl::optional<SessionConfig> audio_receiver_config,
- std::unique_ptr<Receiver> video_receiver,
- absl::optional<SessionConfig> video_receiver_config)
- : audio_receiver_(std::move(audio_receiver)),
- audio_receiver_config_(std::move(audio_receiver_config)),
- video_receiver_(std::move(video_receiver)),
- video_receiver_config_(std::move(video_receiver_config)) {}
-
-ReceiverSession::ConfiguredReceivers::ConfiguredReceivers(
- ConfiguredReceivers&&) noexcept = default;
-ReceiverSession::ConfiguredReceivers& ReceiverSession::ConfiguredReceivers::
-operator=(ConfiguredReceivers&&) noexcept = default;
-ReceiverSession::ConfiguredReceivers::~ConfiguredReceivers() = default;
-
-ReceiverSession::ReceiverSession(Client* client, ReceiverPacketRouter* router)
- : client_(client), router_(router) {
+
+// JSON message field values specific to the Receiver Session.
+static constexpr char kMessageTypeOffer[] = "OFFER";
+
+// List of OFFER message fields.
+static constexpr char kOfferMessageBody[] = "offer";
+static constexpr char kKeyType[] = "type";
+static constexpr char kSequenceNumber[] = "seqNum";
+
+// Using statements for constructor readability.
+using Preferences = ReceiverSession::Preferences;
+using ConfiguredReceivers = ReceiverSession::ConfiguredReceivers;
+
+namespace {
+
+std::string GetCodecName(ReceiverSession::AudioCodec codec) {
+ switch (codec) {
+ case ReceiverSession::AudioCodec::kAac:
+ return "aac";
+ case ReceiverSession::AudioCodec::kOpus:
+ return "opus";
+ }
+
+ OSP_NOTREACHED() << "Codec not accounted for in switch statement.";
+ return {};
+}
+
+std::string GetCodecName(ReceiverSession::VideoCodec codec) {
+ switch (codec) {
+ case ReceiverSession::VideoCodec::kH264:
+ return "h264";
+ case ReceiverSession::VideoCodec::kVp8:
+ return "vp8";
+ case ReceiverSession::VideoCodec::kHevc:
+ return "hevc";
+ case ReceiverSession::VideoCodec::kVp9:
+ return "vp9";
+ }
+
+ OSP_NOTREACHED() << "Codec not accounted for in switch statement.";
+ return {};
+}
+
+template <typename Stream, typename Codec>
+const Stream* SelectStream(const std::vector<Codec>& preferred_codecs,
+ const std::vector<Stream>& offered_streams) {
+ for (Codec codec : preferred_codecs) {
+ const std::string codec_name = GetCodecName(codec);
+ for (const Stream& offered_stream : offered_streams) {
+ if (offered_stream.stream.codec_name == codec_name) {
+ OSP_VLOG << "Selected " << codec_name << " as codec for streaming.";
+ return &offered_stream;
+ }
+ }
+ }
+ return nullptr;
+}
+
+} // namespace
+
+Preferences::Preferences() = default;
+Preferences::Preferences(std::vector<VideoCodec> video_codecs,
+ std::vector<AudioCodec> audio_codecs)
+ : Preferences(video_codecs, audio_codecs, nullptr, nullptr) {}
+
+Preferences::Preferences(std::vector<VideoCodec> video_codecs,
+ std::vector<AudioCodec> audio_codecs,
+ std::unique_ptr<Constraints> constraints,
+ std::unique_ptr<DisplayDescription> description)
+ : video_codecs(std::move(video_codecs)),
+ audio_codecs(std::move(audio_codecs)),
+ constraints(std::move(constraints)),
+ display_description(std::move(description)) {}
+
+Preferences::Preferences(Preferences&&) noexcept = default;
+Preferences& Preferences::operator=(Preferences&&) noexcept = default;
+
+ReceiverSession::ReceiverSession(Client* const client,
+ Environment* environment,
+ MessagePort* message_port,
+ Preferences preferences)
+ : client_(client),
+ environment_(environment),
+ message_port_(message_port),
+ preferences_(std::move(preferences)),
+ packet_router_(environment_) {
OSP_DCHECK(client_);
- OSP_DCHECK(router_);
+ OSP_DCHECK(message_port_);
+ OSP_DCHECK(environment_);
+
+ message_port_->SetClient(this);
+}
+
+ReceiverSession::~ReceiverSession() {
+ ResetReceivers();
+ message_port_->SetClient(nullptr);
+}
+
+void ReceiverSession::OnMessage(absl::string_view sender_id,
+ absl::string_view message_namespace,
+ absl::string_view message) {
+ ErrorOr<Json::Value> message_json = json::Parse(message);
+
+ if (!message_json) {
+ client_->OnError(this, Error::Code::kJsonParseError);
+ OSP_LOG_WARN << "Received an invalid message: " << message;
+ return;
+ }
+
+ // TODO(jophba): add sender connected/disconnected messaging.
+ auto sequence_number = ParseInt(message_json.value(), kSequenceNumber);
+ if (!sequence_number) {
+ OSP_LOG_WARN << "Invalid message sequence number";
+ return;
+ }
+
+ auto key_or_error = ParseString(message_json.value(), kKeyType);
+ if (!key_or_error) {
+ OSP_LOG_WARN << "Invalid message key";
+ return;
+ }
+
+ Message parsed_message{sender_id.data(), message_namespace.data(),
+ sequence_number.value()};
+ if (key_or_error.value() == kMessageTypeOffer) {
+ parsed_message.body = std::move(message_json.value()[kOfferMessageBody]);
+ if (parsed_message.body.isNull()) {
+ OSP_LOG_WARN << "Invalid message offer body";
+ return;
+ }
+ OnOffer(&parsed_message);
+ }
+}
+
+void ReceiverSession::OnError(Error error) {
+ OSP_LOG_WARN << "ReceiverSession's MessagePump encountered an error:"
+ << error;
+}
+
+void ReceiverSession::OnOffer(Message* message) {
+ ErrorOr<Offer> offer = Offer::Parse(std::move(message->body));
+ if (!offer) {
+ client_->OnError(this, offer.error());
+ OSP_LOG_WARN << "Could not parse offer" << offer.error();
+ return;
+ }
+
+ const AudioStream* selected_audio_stream = nullptr;
+ if (!offer.value().audio_streams.empty() &&
+ !preferences_.audio_codecs.empty()) {
+ selected_audio_stream =
+ SelectStream(preferences_.audio_codecs, offer.value().audio_streams);
+ }
+
+ const VideoStream* selected_video_stream = nullptr;
+ if (!offer.value().video_streams.empty() &&
+ !preferences_.video_codecs.empty()) {
+ selected_video_stream =
+ SelectStream(preferences_.video_codecs, offer.value().video_streams);
+ }
+
+ cast_mode_ = offer.value().cast_mode;
+ auto receivers =
+ TrySpawningReceivers(selected_audio_stream, selected_video_stream);
+ if (receivers) {
+ const Answer answer =
+ ConstructAnswer(message, selected_audio_stream, selected_video_stream);
+ client_->OnNegotiated(this, std::move(receivers.value()));
+
+ message->body = answer.ToAnswerMessage();
+ } else {
+ message->body = CreateInvalidAnswer(receivers.error());
+ }
+
+ SendMessage(message);
+}
+
+std::pair<SessionConfig, std::unique_ptr<Receiver>>
+ReceiverSession::ConstructReceiver(const Stream& stream) {
+ SessionConfig config = {stream.ssrc, stream.ssrc + 1,
+ stream.rtp_timebase, stream.channels,
+ stream.target_delay, stream.aes_key,
+ stream.aes_iv_mask};
+ auto receiver =
+ std::make_unique<Receiver>(environment_, &packet_router_, config);
+
+ return std::make_pair(std::move(config), std::move(receiver));
+}
+
+ErrorOr<ConfiguredReceivers> ReceiverSession::TrySpawningReceivers(
+ const AudioStream* audio,
+ const VideoStream* video) {
+ if (!audio && !video) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ ResetReceivers();
+
+ absl::optional<ConfiguredReceiver<AudioStream>> audio_receiver;
+ absl::optional<ConfiguredReceiver<VideoStream>> video_receiver;
+
+ if (audio) {
+ auto audio_pair = ConstructReceiver(audio->stream);
+ current_audio_receiver_ = std::move(audio_pair.second);
+ audio_receiver.emplace(ConfiguredReceiver<AudioStream>{
+ current_audio_receiver_.get(), std::move(audio_pair.first), *audio});
+ }
+
+ if (video) {
+ auto video_pair = ConstructReceiver(video->stream);
+ current_video_receiver_ = std::move(video_pair.second);
+ video_receiver.emplace(ConfiguredReceiver<VideoStream>{
+ current_video_receiver_.get(), std::move(video_pair.first), *video});
+ }
+
+ return ConfiguredReceivers{std::move(audio_receiver),
+ std::move(video_receiver)};
+}
+
+void ReceiverSession::ResetReceivers() {
+ if (current_video_receiver_ || current_audio_receiver_) {
+ client_->OnConfiguredReceiversDestroyed(this);
+ current_audio_receiver_.reset();
+ current_video_receiver_.reset();
+ }
+}
+
+Answer ReceiverSession::ConstructAnswer(
+ Message* message,
+ const AudioStream* selected_audio_stream,
+ const VideoStream* selected_video_stream) {
+ OSP_DCHECK(selected_audio_stream || selected_video_stream);
+
+ std::vector<int> stream_indexes;
+ std::vector<Ssrc> stream_ssrcs;
+ if (selected_audio_stream) {
+ stream_indexes.push_back(selected_audio_stream->stream.index);
+ stream_ssrcs.push_back(selected_audio_stream->stream.ssrc + 1);
+ }
+
+ if (selected_video_stream) {
+ stream_indexes.push_back(selected_video_stream->stream.index);
+ stream_ssrcs.push_back(selected_video_stream->stream.ssrc + 1);
+ }
+
+ absl::optional<Constraints> constraints;
+ if (preferences_.constraints) {
+ constraints = *preferences_.constraints;
+ }
+
+ absl::optional<DisplayDescription> display;
+ if (preferences_.display_description) {
+ display = *preferences_.display_description;
+ }
+
+ return Answer{cast_mode_,
+ environment_->GetBoundLocalEndpoint().port,
+ std::move(stream_indexes),
+ std::move(stream_ssrcs),
+ constraints,
+ display,
+ std::vector<int>{}, // receiver_rtcp_event_log
+ std::vector<int>{}, // receiver_rtcp_dscp
+ supports_wifi_status_reporting_};
}
-ReceiverSession::ReceiverSession(ReceiverSession&&) noexcept = default;
-ReceiverSession& ReceiverSession::operator=(ReceiverSession&&) = default;
-ReceiverSession::~ReceiverSession() = default;
+void ReceiverSession::SendMessage(Message* message) {
+ // All messages have the sequence number embedded.
+ message->body[kSequenceNumber] = message->sequence_number;
-void ReceiverSession::SelectOffer(const SessionConfig& selected_offer) {
- // TODO(jophba): implement receiver session methods.
- OSP_UNIMPLEMENTED();
+ auto body_or_error = json::Stringify(message->body);
+ if (body_or_error.is_value()) {
+ message_port_->PostMessage(message->sender_id, message->message_namespace,
+ body_or_error.value());
+ } else {
+ client_->OnError(this, body_or_error.error());
+ }
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver_session.h b/chromium/third_party/openscreen/src/cast/streaming/receiver_session.h
index 248bf8ad178..be59f4ec2c7 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/receiver_session.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/receiver_session.h
@@ -6,73 +6,161 @@
#define CAST_STREAMING_RECEIVER_SESSION_H_
#include <memory>
+#include <string>
#include <vector>
-#include "cast/streaming/receiver.h"
+#include "cast/streaming/answer_messages.h"
+#include "cast/streaming/message_port.h"
+#include "cast/streaming/offer_messages.h"
#include "cast/streaming/receiver_packet_router.h"
#include "cast/streaming/session_config.h"
+#include "util/json/json_serialization.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
-class ReceiverSession {
+class CastSocket;
+class Environment;
+class Receiver;
+class VirtualConnectionRouter;
+class VirtualConnection;
+
+class ReceiverSession final : public MessagePort::Client {
public:
- class ConfiguredReceivers {
- public:
+ // A small helper struct that contains all of the information necessary for
+ // a configured receiver, including a receiver, its session config, and the
+ // stream selected from the OFFER message to instantiate the receiver.
+ template <typename T>
+ struct ConfiguredReceiver {
+ Receiver* receiver;
+ const SessionConfig receiver_config;
+ const T selected_stream;
+ };
+
+ // Upon successful negotiation, a set of configured receivers is constructed
+ // for handling audio and video. Note that either receiver may be null.
+ struct ConfiguredReceivers {
// In practice, we may have 0, 1, or 2 receivers configured, depending
// on if the device supports audio and video, and if we were able to
// successfully negotiate a receiver configuration.
- ConfiguredReceivers(
- std::unique_ptr<Receiver> audio_receiver,
- const absl::optional<SessionConfig> audio_receiver_config,
- std::unique_ptr<Receiver> video_receiver,
- const absl::optional<SessionConfig> video_receiver_config);
- ConfiguredReceivers(const ConfiguredReceivers&) = delete;
- ConfiguredReceivers(ConfiguredReceivers&&) noexcept;
- ConfiguredReceivers& operator=(const ConfiguredReceivers&) = delete;
- ConfiguredReceivers& operator=(ConfiguredReceivers&&) noexcept;
- ~ConfiguredReceivers();
+
+ // NOTES ON LIFETIMES: The audio and video receiver pointers are expected
+ // to be valid until the OnConfiguredReceiversDestroyed event is fired, at
+ // which point they become invalid and need to replaced by the results of
+ // the ensuing OnNegotiated call.
// If the receiver is audio- or video-only, either of the receivers
// may be nullptr. However, in the majority of cases they will be populated.
- Receiver* audio_receiver() const { return audio_receiver_.get(); }
- const absl::optional<SessionConfig>& audio_session_config() const {
- return audio_receiver_config_;
- }
- Receiver* video_receiver() const { return video_receiver_.get(); }
- const absl::optional<SessionConfig>& video_session_config() const {
- return video_receiver_config_;
- }
-
- private:
- std::unique_ptr<Receiver> audio_receiver_;
- absl::optional<SessionConfig> audio_receiver_config_;
- std::unique_ptr<Receiver> video_receiver_;
- absl::optional<SessionConfig> video_receiver_config_;
+ absl::optional<ConfiguredReceiver<AudioStream>> audio;
+ absl::optional<ConfiguredReceiver<VideoStream>> video;
};
+ // The embedder should provide a client for handling connections.
+ // When a connection is established, the OnNegotiated callback is called.
class Client {
public:
- virtual ~Client() = 0;
- virtual void OnOffer(std::vector<SessionConfig> offers) = 0;
- virtual void OnNegotiated(ConfiguredReceivers receivers) = 0;
+ // This method is called when a new set of receivers has been negotiated.
+ virtual void OnNegotiated(const ReceiverSession* session,
+ ConfiguredReceivers receivers) = 0;
+
+ // This method is called immediately preceding the invalidation of
+ // this session's receivers.
+ virtual void OnConfiguredReceiversDestroyed(
+ const ReceiverSession* session) = 0;
+
+ virtual void OnError(const ReceiverSession* session, Error error) = 0;
};
- ReceiverSession(Client* client, ReceiverPacketRouter* router);
+ // The embedder has the option of providing a list of prioritized
+ // preferences for selecting from the offer.
+ enum class AudioCodec : int { kAac, kOpus };
+ enum class VideoCodec : int { kH264, kVp8, kHevc, kVp9 };
+
+ // Note: embedders are required to implement the following
+ // codecs to be Cast V2 compliant: H264, VP8, AAC, Opus.
+ struct Preferences {
+ Preferences();
+ Preferences(std::vector<VideoCodec> video_codecs,
+ std::vector<AudioCodec> audio_codecs);
+ Preferences(std::vector<VideoCodec> video_codecs,
+ std::vector<AudioCodec> audio_codecs,
+ std::unique_ptr<Constraints> constraints,
+ std::unique_ptr<DisplayDescription> description);
+
+ Preferences(Preferences&&) noexcept;
+ Preferences(const Preferences&) = delete;
+ Preferences& operator=(Preferences&&) noexcept;
+ Preferences& operator=(const Preferences&) = delete;
+
+ std::vector<VideoCodec> video_codecs{VideoCodec::kVp8, VideoCodec::kH264};
+ std::vector<AudioCodec> audio_codecs{AudioCodec::kOpus, AudioCodec::kAac};
+
+ // The embedder has the option of directly specifying the display
+ // information and video/audio constraints that will be passed along to
+ // senders during the offer/answer exchange. If nullptr, these are ignored.
+ std::unique_ptr<Constraints> constraints;
+ std::unique_ptr<DisplayDescription> display_description;
+ };
+
+ ReceiverSession(Client* const client,
+ Environment* environment,
+ MessagePort* message_port,
+ Preferences preferences);
ReceiverSession(const ReceiverSession&) = delete;
- ReceiverSession(ReceiverSession&&) noexcept;
+ ReceiverSession(ReceiverSession&&) = delete;
ReceiverSession& operator=(const ReceiverSession&) = delete;
- ReceiverSession& operator=(ReceiverSession&&) noexcept;
+ ReceiverSession& operator=(ReceiverSession&&) = delete;
~ReceiverSession();
- void SelectOffer(const SessionConfig& selected_offer);
+ // MessagePort::Client overrides
+ void OnMessage(absl::string_view sender_id,
+ absl::string_view message_namespace,
+ absl::string_view message) override;
+ void OnError(Error error) override;
private:
- Client* client_;
- ReceiverPacketRouter* router_;
+ struct Message {
+ const std::string sender_id = {};
+ const std::string message_namespace = {};
+ const int sequence_number = 0;
+ Json::Value body;
+ };
+
+ // Message handlers
+ void OnOffer(Message* message);
+
+ std::pair<SessionConfig, std::unique_ptr<Receiver>> ConstructReceiver(
+ const Stream& stream);
+
+ // Either stream input to this method may be null, however if both
+ // are null this method returns error.
+ ErrorOr<ConfiguredReceivers> TrySpawningReceivers(const AudioStream* audio,
+ const VideoStream* video);
+
+ // Callers of this method should ensure at least one stream is non-null.
+ Answer ConstructAnswer(Message* message,
+ const AudioStream* audio,
+ const VideoStream* video);
+
+ void SendMessage(Message* message);
+
+ // Handles resetting receivers and notifying the client.
+ void ResetReceivers();
+
+ Client* const client_;
+ Environment* const environment_;
+ MessagePort* const message_port_;
+ const Preferences preferences_;
+
+ CastMode cast_mode_;
+ bool supports_wifi_status_reporting_ = false;
+ ReceiverPacketRouter packet_router_;
+
+ std::unique_ptr<Receiver> current_audio_receiver_;
+ std::unique_ptr<Receiver> current_video_receiver_;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_RECEIVER_SESSION_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver_session_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/receiver_session_unittest.cc
new file mode 100644
index 00000000000..ce9a0079cd7
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/receiver_session_unittest.cc
@@ -0,0 +1,536 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/streaming/receiver_session.h"
+
+#include <utility>
+
+#include "cast/streaming/mock_environment.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "platform/base/ip_address.h"
+#include "platform/test/fake_clock.h"
+#include "platform/test/fake_task_runner.h"
+
+using ::testing::_;
+using ::testing::Invoke;
+using ::testing::NiceMock;
+using ::testing::Return;
+using ::testing::StrictMock;
+
+namespace openscreen {
+namespace cast {
+
+namespace {
+
+constexpr char kValidOfferMessage[] = R"({
+ "type": "OFFER",
+ "seqNum": 1337,
+ "offer": {
+ "castMode": "mirroring",
+ "receiverGetStatus": true,
+ "supportedStreams": [
+ {
+ "index": 31337,
+ "type": "video_source",
+ "codecName": "vp9",
+ "rtpProfile": "cast",
+ "rtpPayloadType": 127,
+ "ssrc": 19088743,
+ "maxFrameRate": "60000/1000",
+ "timeBase": "1/90000",
+ "maxBitRate": 5000000,
+ "profile": "main",
+ "level": "4",
+ "aesKey": "bbf109bf84513b456b13a184453b66ce",
+ "aesIvMask": "edaf9e4536e2b66191f560d9c04b2a69",
+ "resolutions": [
+ {
+ "width": 1280,
+ "height": 720
+ }
+ ]
+ },
+ {
+ "index": 31338,
+ "type": "video_source",
+ "codecName": "vp8",
+ "rtpProfile": "cast",
+ "rtpPayloadType": 127,
+ "ssrc": 19088745,
+ "maxFrameRate": "60000/1000",
+ "timeBase": "1/90000",
+ "maxBitRate": 5000000,
+ "profile": "main",
+ "level": "4",
+ "aesKey": "040d756791711fd3adb939066e6d8690",
+ "aesIvMask": "9ff0f022a959150e70a2d05a6c184aed",
+ "resolutions": [
+ {
+ "width": 1280,
+ "height": 720
+ }
+ ]
+ },
+ {
+ "index": 1337,
+ "type": "audio_source",
+ "codecName": "opus",
+ "rtpProfile": "cast",
+ "rtpPayloadType": 97,
+ "ssrc": 19088747,
+ "bitRate": 124000,
+ "timeBase": "1/48000",
+ "channels": 2,
+ "aesKey": "51027e4e2347cbcb49d57ef10177aebc",
+ "aesIvMask": "7f12a19be62a36c04ae4116caaeff6d1"
+ }
+ ]
+ }
+})";
+
+constexpr char kNoAudioOfferMessage[] = R"({
+ "type": "OFFER",
+ "seqNum": 1337,
+ "offer": {
+ "castMode": "mirroring",
+ "receiverGetStatus": true,
+ "supportedStreams": [
+ {
+ "index": 31338,
+ "type": "video_source",
+ "codecName": "vp8",
+ "rtpProfile": "cast",
+ "rtpPayloadType": 127,
+ "ssrc": 19088745,
+ "maxFrameRate": "60000/1000",
+ "timeBase": "1/90000",
+ "maxBitRate": 5000000,
+ "profile": "main",
+ "level": "4",
+ "aesKey": "040d756791711fd3adb939066e6d8690",
+ "aesIvMask": "9ff0f022a959150e70a2d05a6c184aed",
+ "resolutions": [
+ {
+ "width": 1280,
+ "height": 720
+ }
+ ]
+ }
+ ]
+ }
+})";
+
+constexpr char kNoVideoOfferMessage[] = R"({
+ "type": "OFFER",
+ "seqNum": 1337,
+ "offer": {
+ "castMode": "mirroring",
+ "receiverGetStatus": true,
+ "supportedStreams": [
+ {
+ "index": 1337,
+ "type": "audio_source",
+ "codecName": "opus",
+ "rtpProfile": "cast",
+ "rtpPayloadType": 97,
+ "ssrc": 19088747,
+ "bitRate": 124000,
+ "timeBase": "1/48000",
+ "channels": 2,
+ "aesKey": "51027e4e2347cbcb49d57ef10177aebc",
+ "aesIvMask": "7f12a19be62a36c04ae4116caaeff6d1"
+ }
+ ]
+ }
+})";
+
+constexpr char kNoAudioOrVideoOfferMessage[] = R"({
+ "type": "OFFER",
+ "seqNum": 1337,
+ "offer": {
+ "castMode": "mirroring",
+ "receiverGetStatus": true,
+ "supportedStreams": []
+ }
+})";
+
+constexpr char kInvalidJsonOfferMessage[] = R"({
+ "type": "OFFER",
+ "seqNum": 1337,,,
+ "offer":
+ "castMode": "mirroring",
+ "receiverGetStatus": true,
+ "supportedStreams": [
+ }
+})";
+
+class SimpleMessagePort : public MessagePort {
+ public:
+ ~SimpleMessagePort() override {}
+ void SetClient(MessagePort::Client* client) override { client_ = client; }
+
+ void ReceiveMessage(absl::string_view message) {
+ ASSERT_NE(client_, nullptr);
+ client_->OnMessage("sender-id", "namespace", message);
+ }
+
+ void ReceiveError(Error error) {
+ ASSERT_NE(client_, nullptr);
+ client_->OnError(error);
+ }
+
+ void PostMessage(absl::string_view sender_id,
+ absl::string_view message_namespace,
+ absl::string_view message) override {
+ posted_messages_.emplace_back(std::move(message));
+ }
+
+ MessagePort::Client* client() const { return client_; }
+ const std::vector<std::string> posted_messages() const {
+ return posted_messages_;
+ }
+
+ private:
+ MessagePort::Client* client_ = nullptr;
+ std::vector<std::string> posted_messages_;
+};
+
+class FakeClient : public ReceiverSession::Client {
+ public:
+ MOCK_METHOD(void,
+ OnNegotiated,
+ (const ReceiverSession*, ReceiverSession::ConfiguredReceivers),
+ (override));
+ MOCK_METHOD(void,
+ OnConfiguredReceiversDestroyed,
+ (const ReceiverSession*),
+ (override));
+ MOCK_METHOD(void, OnError, (const ReceiverSession*, Error error), (override));
+};
+
+void ExpectIsErrorAnswerMessage(const ErrorOr<Json::Value>& message_or_error) {
+ EXPECT_TRUE(message_or_error.is_value());
+ const Json::Value message = std::move(message_or_error.value());
+ EXPECT_TRUE(message["answer"].isNull());
+ EXPECT_EQ("error", message["result"].asString());
+ EXPECT_EQ(1337, message["seqNum"].asInt());
+ EXPECT_EQ("ANSWER", message["type"].asString());
+
+ const Json::Value& error = message["error"];
+ EXPECT_TRUE(error.isObject());
+ EXPECT_GT(error["code"].asInt(), 0);
+ EXPECT_EQ("", error["description"].asString());
+}
+
+} // namespace
+
+class ReceiverSessionTest : public ::testing::Test {
+ public:
+ ReceiverSessionTest() : clock_(Clock::time_point{}), task_runner_(&clock_) {}
+
+ std::unique_ptr<MockEnvironment> MakeEnvironment() {
+ auto environment = std::make_unique<NiceMock<MockEnvironment>>(
+ &FakeClock::now, &task_runner_);
+ ON_CALL(*environment, GetBoundLocalEndpoint())
+ .WillByDefault(
+ Return(IPEndpoint{IPAddress::Parse("127.0.0.1").value(), 12345}));
+ return environment;
+ }
+
+ private:
+ FakeClock clock_;
+ FakeTaskRunner task_runner_;
+};
+
+TEST_F(ReceiverSessionTest, RegistersSelfOnMessagePump) {
+ auto message_port = std::make_unique<SimpleMessagePort>();
+ // This should be safe, since the message_port location should not move
+ // just because of being moved into the ReceiverSession.
+ StrictMock<FakeClient> client;
+
+ auto environment = MakeEnvironment();
+ auto session = std::make_unique<ReceiverSession>(
+ &client, environment.get(), message_port.get(),
+ ReceiverSession::Preferences{});
+ EXPECT_EQ(message_port->client(), session.get());
+}
+
+TEST_F(ReceiverSessionTest, CanNegotiateWithDefaultPreferences) {
+ auto message_port = std::make_unique<SimpleMessagePort>();
+ StrictMock<FakeClient> client;
+ auto environment = MakeEnvironment();
+ ReceiverSession session(&client, environment.get(), message_port.get(),
+ ReceiverSession::Preferences{});
+
+ EXPECT_CALL(client, OnNegotiated(&session, _))
+ .WillOnce([](const ReceiverSession* session,
+ ReceiverSession::ConfiguredReceivers cr) {
+ EXPECT_TRUE(cr.audio);
+ EXPECT_EQ(cr.audio.value().receiver_config.sender_ssrc, 19088747u);
+ EXPECT_EQ(cr.audio.value().receiver_config.receiver_ssrc, 19088748u);
+ EXPECT_EQ(cr.audio.value().receiver_config.channels, 2);
+ EXPECT_EQ(cr.audio.value().receiver_config.rtp_timebase, 48000);
+
+ // We should have chosen opus
+ EXPECT_EQ(cr.audio.value().selected_stream.stream.index, 1337);
+ EXPECT_EQ(cr.audio.value().selected_stream.stream.type,
+ Stream::Type::kAudioSource);
+ EXPECT_EQ(cr.audio.value().selected_stream.stream.codec_name, "opus");
+ EXPECT_EQ(cr.audio.value().selected_stream.stream.channels, 2);
+
+ EXPECT_TRUE(cr.video);
+ EXPECT_EQ(cr.video.value().receiver_config.sender_ssrc, 19088745u);
+ EXPECT_EQ(cr.video.value().receiver_config.receiver_ssrc, 19088746u);
+ EXPECT_EQ(cr.video.value().receiver_config.channels, 1);
+ EXPECT_EQ(cr.video.value().receiver_config.rtp_timebase, 90000);
+
+ // We should have chosen vp8
+ EXPECT_EQ(cr.video.value().selected_stream.stream.index, 31338);
+ EXPECT_EQ(cr.video.value().selected_stream.stream.type,
+ Stream::Type::kVideoSource);
+ EXPECT_EQ(cr.video.value().selected_stream.stream.codec_name, "vp8");
+ EXPECT_EQ(cr.video.value().selected_stream.stream.channels, 1);
+ });
+ EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(1);
+
+ message_port->ReceiveMessage(kValidOfferMessage);
+
+ const auto& messages = message_port->posted_messages();
+ ASSERT_EQ(1u, messages.size());
+
+ auto message_body = json::Parse(messages[0]);
+ EXPECT_TRUE(message_body.is_value());
+ const Json::Value answer = std::move(message_body.value());
+
+ EXPECT_EQ("ANSWER", answer["type"].asString());
+ EXPECT_EQ(1337, answer["seqNum"].asInt());
+ EXPECT_EQ("ok", answer["result"].asString());
+
+ const Json::Value& answer_body = answer["answer"];
+ EXPECT_TRUE(answer_body.isObject());
+
+ // Spot check the answer body fields. We have more in depth testing
+ // of answer behavior in answer_messages_unittest, but here we can
+ // ensure that the ReceiverSession properly configured the answer.
+ EXPECT_EQ("mirroring", answer_body["castMode"].asString());
+ EXPECT_EQ(1337, answer_body["sendIndexes"][0].asInt());
+ EXPECT_EQ(31338, answer_body["sendIndexes"][1].asInt());
+ EXPECT_LT(0, answer_body["udpPort"].asInt());
+ EXPECT_GT(65535, answer_body["udpPort"].asInt());
+
+ // Get status should always be false, as we have no plans to implement it.
+ EXPECT_EQ(false, answer_body["receiverGetStatus"].asBool());
+
+ // Constraints and display should not be present with no preferences.
+ EXPECT_TRUE(answer_body["constraints"].isNull());
+ EXPECT_TRUE(answer_body["display"].isNull());
+}
+
+TEST_F(ReceiverSessionTest, CanNegotiateWithCustomCodecPreferences) {
+ auto message_port = std::make_unique<SimpleMessagePort>();
+ StrictMock<FakeClient> client;
+ auto environment = MakeEnvironment();
+ ReceiverSession session(
+ &client, environment.get(), message_port.get(),
+ ReceiverSession::Preferences{{ReceiverSession::VideoCodec::kVp9},
+ {ReceiverSession::AudioCodec::kOpus}});
+
+ EXPECT_CALL(client, OnNegotiated(&session, _))
+ .WillOnce([](const ReceiverSession* session,
+ ReceiverSession::ConfiguredReceivers cr) {
+ EXPECT_TRUE(cr.audio);
+ EXPECT_EQ(cr.audio.value().receiver_config.sender_ssrc, 19088747u);
+ EXPECT_EQ(cr.audio.value().receiver_config.receiver_ssrc, 19088748u);
+ EXPECT_EQ(cr.audio.value().receiver_config.channels, 2);
+ EXPECT_EQ(cr.audio.value().receiver_config.rtp_timebase, 48000);
+
+ EXPECT_TRUE(cr.video);
+ // We should have chosen vp9
+ EXPECT_EQ(cr.video.value().receiver_config.sender_ssrc, 19088743u);
+ EXPECT_EQ(cr.video.value().receiver_config.receiver_ssrc, 19088744u);
+ EXPECT_EQ(cr.video.value().receiver_config.channels, 1);
+ EXPECT_EQ(cr.video.value().receiver_config.rtp_timebase, 90000);
+ });
+ EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(1);
+ message_port->ReceiveMessage(kValidOfferMessage);
+}
+
+TEST_F(ReceiverSessionTest, CanNegotiateWithCustomConstraints) {
+ auto message_port = std::make_unique<SimpleMessagePort>();
+ StrictMock<FakeClient> client;
+
+ auto constraints = std::unique_ptr<Constraints>{new Constraints{
+ AudioConstraints{1, 2, 3, 4},
+ VideoConstraints{3.14159, Dimensions{320, 240, SimpleFraction{24, 1}},
+ Dimensions{1920, 1080, SimpleFraction{144, 1}}, 3000,
+ 90000000, std::chrono::milliseconds{1000}}}};
+
+ auto display = std::unique_ptr<DisplayDescription>{new DisplayDescription{
+ Dimensions{640, 480, SimpleFraction{60, 1}}, AspectRatio{16, 9},
+ AspectRatioConstraint::kFixed}};
+
+ auto environment = MakeEnvironment();
+ ReceiverSession session(
+ &client, environment.get(), message_port.get(),
+ ReceiverSession::Preferences{{ReceiverSession::VideoCodec::kVp9},
+ {ReceiverSession::AudioCodec::kOpus},
+ std::move(constraints),
+ std::move(display)});
+
+ EXPECT_CALL(client, OnNegotiated(&session, _)).Times(1);
+ EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(1);
+ message_port->ReceiveMessage(kValidOfferMessage);
+
+ const auto& messages = message_port->posted_messages();
+ EXPECT_EQ(1u, messages.size());
+
+ auto message_body = json::Parse(messages[0]);
+ ASSERT_TRUE(message_body.is_value());
+ const Json::Value answer = std::move(message_body.value());
+
+ const Json::Value& answer_body = answer["answer"];
+ ASSERT_TRUE(answer_body.isObject());
+
+ // Constraints and display should be valid with valid preferences.
+ ASSERT_FALSE(answer_body["constraints"].isNull());
+ ASSERT_FALSE(answer_body["display"].isNull());
+
+ const Json::Value& display_json = answer_body["display"];
+ EXPECT_EQ("16:9", display_json["aspectRatio"].asString());
+ EXPECT_EQ("60", display_json["dimensions"]["frameRate"].asString());
+ EXPECT_EQ(640, display_json["dimensions"]["width"].asInt());
+ EXPECT_EQ(480, display_json["dimensions"]["height"].asInt());
+ EXPECT_EQ("sender", display_json["scaling"].asString());
+
+ const Json::Value& constraints_json = answer_body["constraints"];
+ ASSERT_TRUE(constraints_json.isObject());
+
+ const Json::Value& audio = constraints_json["audio"];
+ ASSERT_TRUE(audio.isObject());
+ EXPECT_EQ(4, audio["maxBitRate"].asInt());
+ EXPECT_EQ(2, audio["maxChannels"].asInt());
+ EXPECT_EQ(0, audio["maxDelay"].asInt());
+ EXPECT_EQ(1, audio["maxSampleRate"].asInt());
+ EXPECT_EQ(3, audio["minBitRate"].asInt());
+
+ const Json::Value& video = constraints_json["video"];
+ ASSERT_TRUE(video.isObject());
+ EXPECT_EQ(90000000, video["maxBitRate"].asInt());
+ EXPECT_EQ(1000, video["maxDelay"].asInt());
+ EXPECT_EQ("144", video["maxDimensions"]["frameRate"].asString());
+ EXPECT_EQ(1920, video["maxDimensions"]["width"].asInt());
+ EXPECT_EQ(1080, video["maxDimensions"]["height"].asInt());
+ EXPECT_DOUBLE_EQ(3.14159, video["maxPixelsPerSecond"].asDouble());
+ EXPECT_EQ(3000, video["minBitRate"].asInt());
+ EXPECT_EQ("24", video["minDimensions"]["frameRate"].asString());
+ EXPECT_EQ(320, video["minDimensions"]["width"].asInt());
+ EXPECT_EQ(240, video["minDimensions"]["height"].asInt());
+}
+
+TEST_F(ReceiverSessionTest, HandlesNoValidAudioStream) {
+ auto message_port = std::make_unique<SimpleMessagePort>();
+ StrictMock<FakeClient> client;
+ auto environment = MakeEnvironment();
+ ReceiverSession session(&client, environment.get(), message_port.get(),
+ ReceiverSession::Preferences{});
+
+ EXPECT_CALL(client, OnNegotiated(&session, _)).Times(1);
+ EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(1);
+
+ message_port->ReceiveMessage(kNoAudioOfferMessage);
+ const auto& messages = message_port->posted_messages();
+ EXPECT_EQ(1u, messages.size());
+
+ auto message_body = json::Parse(messages[0]);
+ EXPECT_TRUE(message_body.is_value());
+ const Json::Value& answer_body = message_body.value()["answer"];
+ EXPECT_TRUE(answer_body.isObject());
+
+ // Should still select video stream.
+ EXPECT_EQ(1u, answer_body["sendIndexes"].size());
+ EXPECT_EQ(31338, answer_body["sendIndexes"][0].asInt());
+ EXPECT_EQ(1u, answer_body["ssrcs"].size());
+ EXPECT_EQ(19088746, answer_body["ssrcs"][0].asInt());
+}
+
+TEST_F(ReceiverSessionTest, HandlesNoValidVideoStream) {
+ auto message_port = std::make_unique<SimpleMessagePort>();
+ StrictMock<FakeClient> client;
+ auto environment = MakeEnvironment();
+ ReceiverSession session(&client, environment.get(), message_port.get(),
+ ReceiverSession::Preferences{});
+
+ EXPECT_CALL(client, OnNegotiated(&session, _)).Times(1);
+ EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(1);
+
+ message_port->ReceiveMessage(kNoVideoOfferMessage);
+ const auto& messages = message_port->posted_messages();
+ EXPECT_EQ(1u, messages.size());
+
+ auto message_body = json::Parse(messages[0]);
+ EXPECT_TRUE(message_body.is_value());
+ const Json::Value& answer_body = message_body.value()["answer"];
+ EXPECT_TRUE(answer_body.isObject());
+
+ // Should still select audio stream.
+ EXPECT_EQ(1u, answer_body["sendIndexes"].size());
+ EXPECT_EQ(1337, answer_body["sendIndexes"][0].asInt());
+ EXPECT_EQ(1u, answer_body["ssrcs"].size());
+ EXPECT_EQ(19088748, answer_body["ssrcs"][0].asInt());
+}
+
+TEST_F(ReceiverSessionTest, HandlesNoValidStreams) {
+ auto message_port = std::make_unique<SimpleMessagePort>();
+ StrictMock<FakeClient> client;
+
+ auto environment = MakeEnvironment();
+ ReceiverSession session(&client, environment.get(), message_port.get(),
+ ReceiverSession::Preferences{});
+
+ // We shouldn't call OnNegotiated if we failed to negotiate any streams.
+ EXPECT_CALL(client, OnNegotiated(&session, _)).Times(0);
+ EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(0);
+
+ message_port->ReceiveMessage(kNoAudioOrVideoOfferMessage);
+ const auto& messages = message_port->posted_messages();
+ EXPECT_EQ(1u, messages.size());
+
+ auto message_body = json::Parse(messages[0]);
+ ExpectIsErrorAnswerMessage(message_body);
+}
+
+TEST_F(ReceiverSessionTest, HandlesMalformedOffer) {
+ auto message_port = std::make_unique<SimpleMessagePort>();
+ StrictMock<FakeClient> client;
+ auto environment = MakeEnvironment();
+ ReceiverSession session(&client, environment.get(), message_port.get(),
+ ReceiverSession::Preferences{});
+
+ // We shouldn't call OnNegotiated if we failed to negotiate any streams.
+ // Note that unlike when we simply don't select any streams, when the offer
+ // is actually completely invalid we call OnError.
+ EXPECT_CALL(client, OnNegotiated(&session, _)).Times(0);
+ EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(0);
+ EXPECT_CALL(client, OnError(&session, Error(Error::Code::kJsonParseError)))
+ .Times(1);
+
+ message_port->ReceiveMessage(kInvalidJsonOfferMessage);
+}
+
+TEST_F(ReceiverSessionTest, NotifiesReceiverDestruction) {
+ auto message_port = std::make_unique<SimpleMessagePort>();
+ StrictMock<FakeClient> client;
+ auto environment = MakeEnvironment();
+ ReceiverSession session(&client, environment.get(), message_port.get(),
+ ReceiverSession::Preferences{});
+
+ EXPECT_CALL(client, OnNegotiated(&session, _)).Times(2);
+ EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(2);
+
+ message_port->ReceiveMessage(kNoAudioOfferMessage);
+ message_port->ReceiveMessage(kValidOfferMessage);
+}
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/receiver_unittest.cc
index 2426376aed8..826af0b0663 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/receiver_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/receiver_unittest.cc
@@ -8,6 +8,7 @@
#include <algorithm>
#include <array>
+#include <utility>
#include <vector>
#include "absl/types/span.h"
@@ -15,6 +16,7 @@
#include "cast/streaming/constants.h"
#include "cast/streaming/encoded_frame.h"
#include "cast/streaming/frame_crypto.h"
+#include "cast/streaming/mock_environment.h"
#include "cast/streaming/receiver_packet_router.h"
#include "cast/streaming/rtcp_common.h"
#include "cast/streaming/rtcp_session.h"
@@ -35,18 +37,6 @@
#include "platform/test/fake_task_runner.h"
#include "util/logging.h"
-using openscreen::Error;
-using openscreen::ErrorOr;
-using openscreen::IPAddress;
-using openscreen::IPEndpoint;
-using openscreen::platform::Clock;
-using openscreen::platform::ClockNowFunctionPtr;
-using openscreen::platform::FakeClock;
-using openscreen::platform::FakeTaskRunner;
-using openscreen::platform::TaskRunner;
-using openscreen::platform::UdpPacket;
-using openscreen::platform::UdpSocket;
-
using std::chrono::duration_cast;
using std::chrono::microseconds;
using std::chrono::milliseconds;
@@ -58,8 +48,8 @@ using testing::Gt;
using testing::Invoke;
using testing::SaveArg;
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
// Receiver configuration.
@@ -80,8 +70,12 @@ constexpr milliseconds kTargetPlayoutDelayChange{800};
constexpr RtpPayloadType kRtpPayloadType = RtpPayloadType::kVideoVp8;
constexpr int kMaxRtpPacketSize = 64;
-// A simulated one-way network delay.
-constexpr auto kOneWayNetworkDelay = milliseconds(23);
+// A simulated one-way network delay, and round-trip network delay.
+constexpr auto kOneWayNetworkDelay = milliseconds(3);
+constexpr auto kRoundTripNetworkDelay = 2 * kOneWayNetworkDelay;
+static_assert(kRoundTripNetworkDelay < kTargetPlayoutDelay &&
+ kRoundTripNetworkDelay < kTargetPlayoutDelayChange,
+ "Network delay must be smaller than target playout delay.");
// An EncodedFrame for unit testing, one of a sequence of simulated frames, each
// of 10 ms duration. The first frame will be a key frame; and any later frames
@@ -180,10 +174,11 @@ class MockSender : public CompoundRtcpParser::Client {
UdpPacket packet_to_send(packet_and_report_id.first.begin(),
packet_and_report_id.first.end());
packet_to_send.set_source(sender_endpoint_);
- task_runner_->PostTask(
+ task_runner_->PostTaskWithDelay(
[receiver = receiver_, packet = std::move(packet_to_send)]() mutable {
receiver->OnRead(nullptr, ErrorOr<UdpPacket>(std::move(packet)));
- });
+ },
+ kOneWayNetworkDelay);
return packet_and_report_id.second;
}
@@ -224,10 +219,11 @@ class MockSender : public CompoundRtcpParser::Client {
rtp_packetizer_.GeneratePacket(frame_being_sent_, packet_id, buffer);
UdpPacket packet_to_send(span.begin(), span.end());
packet_to_send.set_source(sender_endpoint_);
- task_runner_->PostTask(
+ task_runner_->PostTaskWithDelay(
[receiver = receiver_, packet = std::move(packet_to_send)]() mutable {
receiver->OnRead(nullptr, ErrorOr<UdpPacket>(std::move(packet)));
- });
+ },
+ kOneWayNetworkDelay);
}
}
@@ -263,19 +259,6 @@ class MockSender : public CompoundRtcpParser::Client {
EncryptedFrame frame_being_sent_;
};
-// An Environment that can intercept all packet sends. ReceiverTest will connect
-// the SendPacket() method calls to the MockSender.
-class MockEnvironment : public Environment {
- public:
- MockEnvironment(ClockNowFunctionPtr now_function, TaskRunner* task_runner)
- : Environment(now_function, task_runner) {}
-
- ~MockEnvironment() override = default;
-
- // Used for intercepting packet sends from the implementation under test.
- MOCK_METHOD1(SendPacket, void(absl::Span<const uint8_t> packet));
-};
-
class MockConsumer : public Receiver::Consumer {
public:
MOCK_METHOD1(OnFramesReady, void(int next_frame_buffer_size));
@@ -294,14 +277,21 @@ class ReceiverTest : public testing::Test {
/* .receiver_ssrc = */ kReceiverSsrc,
/* .rtp_timebase = */ kRtpTimebase,
/* .channels = */ 2,
+ /* .target_playout_delay = */ kTargetPlayoutDelay,
/* .aes_secret_key = */ kAesKey,
- /* .aes_iv_mask = */ kCastIvMask},
- kTargetPlayoutDelay),
+ /* .aes_iv_mask = */ kCastIvMask}),
sender_(&task_runner_, &env_) {
env_.set_socket_error_handler(
[](Error error) { ASSERT_TRUE(error.ok()) << error; });
ON_CALL(env_, SendPacket(_))
- .WillByDefault(Invoke(&sender_, &MockSender::OnPacketFromReceiver));
+ .WillByDefault(Invoke([this](absl::Span<const uint8_t> packet) {
+ task_runner_.PostTaskWithDelay(
+ [sender = &sender_, copy_of_packet = std::vector<uint8_t>(
+ packet.begin(), packet.end())]() mutable {
+ sender->OnPacketFromReceiver(std::move(copy_of_packet));
+ },
+ kOneWayNetworkDelay);
+ }));
receiver_.SetConsumer(&consumer_);
}
@@ -314,6 +304,22 @@ class ReceiverTest : public testing::Test {
void AdvanceClockAndRunTasks(Clock::duration delta) { clock_.Advance(delta); }
void RunTasksUntilIdle() { task_runner_.RunTasksUntilIdle(); }
+ // Sends the initial Sender Report with lip-sync timing information to
+ // "unblock" the Receiver, and confirms the Receiver immediately replies with
+ // a corresponding Receiver Report.
+ void ExchangeInitialReportPackets() {
+ const Clock::time_point start_time = FakeClock::now();
+ sender_.SendSenderReport(start_time, SimulatedFrame::GetRtpStartTime());
+ AdvanceClockAndRunTasks(
+ kOneWayNetworkDelay); // Transmit report to Receiver.
+ // The Receiver will immediately reply with a Receiver Report.
+ EXPECT_CALL(sender_,
+ OnReceiverCheckpoint(FrameId::leader(), kTargetPlayoutDelay))
+ .Times(1);
+ AdvanceClockAndRunTasks(kOneWayNetworkDelay); // Transmit reply to Sender.
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ }
+
// Consume one frame from the Receiver, and verify that it is the same as the
// |sent_frame|. Exception: The |reference_time| is the playout time on the
// Receiver's end, while it refers to the capture time on the Sender's end.
@@ -371,7 +377,7 @@ TEST_F(ReceiverTest, ReceivesAndSendsRtcpPackets) {
EXPECT_CALL(*sender(), OnReceiverReport(_))
.WillOnce(SaveArg<0>(&receiver_report));
EXPECT_CALL(*sender(),
- OnReceiverCheckpoint(FrameId::first() - 1, kTargetPlayoutDelay))
+ OnReceiverCheckpoint(FrameId::leader(), kTargetPlayoutDelay))
.Times(1);
// Have the MockSender send a Sender Report with lip-sync timing information.
@@ -380,7 +386,8 @@ TEST_F(ReceiverTest, ReceivesAndSendsRtcpPackets) {
RtpTimeTicks::FromTimeSinceOrigin(seconds(1), kRtpTimebase);
const StatusReportId sender_report_id =
sender()->SendSenderReport(sender_reference_time, sender_rtp_timestamp);
- AdvanceClockAndRunTasks(kOneWayNetworkDelay);
+
+ AdvanceClockAndRunTasks(kRoundTripNetworkDelay);
// Expect the MockSender got back a Receiver Report that includes its SSRC and
// the last Sender Report ID.
@@ -389,7 +396,7 @@ TEST_F(ReceiverTest, ReceivesAndSendsRtcpPackets) {
EXPECT_EQ(sender_report_id, receiver_report.last_status_report_id);
// Confirm the clock offset math: Since the Receiver and MockSender share the
- // same underlying FakeClock, the Receiver should be 10ms ahead of the Sender,
+ // same underlying FakeClock, the Receiver should be ahead of the Sender,
// which reflects the simulated one-way network packet travel time (of the
// Sender Report).
//
@@ -422,11 +429,8 @@ TEST_F(ReceiverTest, ReceivesAndSendsRtcpPackets) {
// out of order, but such that each frame is completely received in-order. Also,
// confirms that target playout delay changes are processed/applied correctly.
TEST_F(ReceiverTest, ReceivesFramesInOrder) {
- // Send the initial Sender Report with lip-sync timing information to
- // "unblock" the Receiver.
const Clock::time_point start_time = FakeClock::now();
- sender()->SendSenderReport(start_time, SimulatedFrame::GetRtpStartTime());
- AdvanceClockAndRunTasks(kOneWayNetworkDelay);
+ ExchangeInitialReportPackets();
EXPECT_CALL(*consumer(), OnFramesReady(Gt(0))).Times(10);
for (int i = 0; i <= 9; ++i) {
@@ -442,10 +446,15 @@ TEST_F(ReceiverTest, ReceivesFramesInOrder) {
const int permutation = (i % 2) ? i : 0;
sender()->SendRtpPackets(sender()->GetAllPacketIds(permutation));
+ AdvanceClockAndRunTasks(kRoundTripNetworkDelay);
+
// The Receiver should immediately ACK once it has received all the RTP
// packets to complete the frame.
- RunTasksUntilIdle();
testing::Mock::VerifyAndClearExpectations(sender());
+
+ // Advance to next frame transmission time.
+ AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration -
+ kRoundTripNetworkDelay);
}
// When the Receiver has all of the frames and they are complete, it should
@@ -466,11 +475,8 @@ TEST_F(ReceiverTest, ReceivesFramesInOrder) {
// order, and issues the appropriate ACK/NACK feedback to the Sender as it
// realizes what it has and what it's missing.
TEST_F(ReceiverTest, ReceivesFramesOutOfOrder) {
- // Send the initial Sender Report with lip-sync timing information to
- // "unblock" the Receiver.
const Clock::time_point start_time = FakeClock::now();
- sender()->SendSenderReport(start_time, SimulatedFrame::GetRtpStartTime());
- AdvanceClockAndRunTasks(kOneWayNetworkDelay);
+ ExchangeInitialReportPackets();
constexpr static int kOutOfOrderFrames[] = {3, 4, 2, 0, 1};
for (int i : kOutOfOrderFrames) {
@@ -479,7 +485,7 @@ TEST_F(ReceiverTest, ReceivesFramesOutOfOrder) {
case 3: {
// Note that frame 4 will not yet be known to the Receiver, and so it
// should not be mentioned in any of the feedback for this case.
- EXPECT_CALL(*sender(), OnReceiverCheckpoint(FrameId::first() - 1,
+ EXPECT_CALL(*sender(), OnReceiverCheckpoint(FrameId::leader(),
kTargetPlayoutDelay))
.Times(AtLeast(1));
EXPECT_CALL(
@@ -498,7 +504,7 @@ TEST_F(ReceiverTest, ReceivesFramesOutOfOrder) {
}
case 4: {
- EXPECT_CALL(*sender(), OnReceiverCheckpoint(FrameId::first() - 1,
+ EXPECT_CALL(*sender(), OnReceiverCheckpoint(FrameId::leader(),
kTargetPlayoutDelay))
.Times(AtLeast(1));
EXPECT_CALL(*sender(),
@@ -517,7 +523,7 @@ TEST_F(ReceiverTest, ReceivesFramesOutOfOrder) {
}
case 2: {
- EXPECT_CALL(*sender(), OnReceiverCheckpoint(FrameId::first() - 1,
+ EXPECT_CALL(*sender(), OnReceiverCheckpoint(FrameId::leader(),
kTargetPlayoutDelay))
.Times(AtLeast(1));
EXPECT_CALL(*sender(), OnReceiverHasFrames(std::vector<FrameId>(
@@ -567,11 +573,13 @@ TEST_F(ReceiverTest, ReceivesFramesOutOfOrder) {
sender()->SetFrameBeingSent(SimulatedFrame(start_time, i));
sender()->SendRtpPackets(sender()->GetAllPacketIds(i));
+ AdvanceClockAndRunTasks(kRoundTripNetworkDelay);
+
// While there are known incomplete frames, the Receiver should send RTCP
// packets more frequently than the default "ping" interval. Thus, advancing
// the clock by this much should result in several feedback reports
// transmitted to the Sender.
- AdvanceClockAndRunTasks(kRtcpReportInterval);
+ AdvanceClockAndRunTasks(kRtcpReportInterval - kRoundTripNetworkDelay);
testing::Mock::VerifyAndClearExpectations(sender());
testing::Mock::VerifyAndClearExpectations(consumer());
@@ -585,11 +593,8 @@ TEST_F(ReceiverTest, ReceivesFramesOutOfOrder) {
// by sending a Picture Loss Indicator (PLI) to the Sender, and then will
// automatically stop sending the PLI once a key frame has been received.
TEST_F(ReceiverTest, RequestsKeyFrameToRectifyPictureLoss) {
- // Send the initial Sender Report with lip-sync timing information to
- // "unblock" the Receiver.
const Clock::time_point start_time = FakeClock::now();
- sender()->SendSenderReport(start_time, SimulatedFrame::GetRtpStartTime());
- AdvanceClockAndRunTasks(kOneWayNetworkDelay);
+ ExchangeInitialReportPackets();
// Send and Receive three frames in-order, normally.
for (int i = 0; i <= 2; ++i) {
@@ -599,9 +604,12 @@ TEST_F(ReceiverTest, RequestsKeyFrameToRectifyPictureLoss) {
.Times(1);
sender()->SetFrameBeingSent(SimulatedFrame(start_time, i));
sender()->SendRtpPackets(sender()->GetAllPacketIds(0));
- RunTasksUntilIdle();
+ AdvanceClockAndRunTasks(kRoundTripNetworkDelay);
testing::Mock::VerifyAndClearExpectations(sender());
testing::Mock::VerifyAndClearExpectations(consumer());
+ // Advance to next frame transmission time.
+ AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration -
+ kRoundTripNetworkDelay);
}
ConsumeAndVerifyFrames(0, 2, start_time);
@@ -609,7 +617,7 @@ TEST_F(ReceiverTest, RequestsKeyFrameToRectifyPictureLoss) {
// decoder failure). Ensure the Sender is immediately notified.
EXPECT_CALL(*sender(), OnReceiverIndicatesPictureLoss()).Times(1);
receiver()->RequestKeyFrame();
- RunTasksUntilIdle();
+ AdvanceClockAndRunTasks(kOneWayNetworkDelay); // Propagate request to Sender.
testing::Mock::VerifyAndClearExpectations(sender());
// The Sender sends another frame that is not a key frame and, upon receipt,
@@ -618,10 +626,10 @@ TEST_F(ReceiverTest, RequestsKeyFrameToRectifyPictureLoss) {
EXPECT_CALL(*sender(),
OnReceiverCheckpoint(FrameId::first() + 3, kTargetPlayoutDelay))
.Times(1);
- EXPECT_CALL(*sender(), OnReceiverIndicatesPictureLoss()).Times(1);
+ EXPECT_CALL(*sender(), OnReceiverIndicatesPictureLoss()).Times(AtLeast(1));
sender()->SetFrameBeingSent(SimulatedFrame(start_time, 3));
sender()->SendRtpPackets(sender()->GetAllPacketIds(0));
- RunTasksUntilIdle();
+ AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration - kOneWayNetworkDelay);
testing::Mock::VerifyAndClearExpectations(sender());
testing::Mock::VerifyAndClearExpectations(consumer());
ConsumeAndVerifyFrames(3, 3, start_time);
@@ -639,7 +647,7 @@ TEST_F(ReceiverTest, RequestsKeyFrameToRectifyPictureLoss) {
key_frame.referenced_frame_id = key_frame.frame_id;
sender()->SetFrameBeingSent(key_frame);
sender()->SendRtpPackets(sender()->GetAllPacketIds(0));
- RunTasksUntilIdle();
+ AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration);
testing::Mock::VerifyAndClearExpectations(sender());
testing::Mock::VerifyAndClearExpectations(consumer());
@@ -647,7 +655,7 @@ TEST_F(ReceiverTest, RequestsKeyFrameToRectifyPictureLoss) {
// RequestKeyFrame() should not set the PLI condition again.
EXPECT_CALL(*sender(), OnReceiverIndicatesPictureLoss()).Times(0);
receiver()->RequestKeyFrame();
- RunTasksUntilIdle();
+ AdvanceClockAndRunTasks(kOneWayNetworkDelay);
testing::Mock::VerifyAndClearExpectations(sender());
// After consuming the requested key frame, the client should be able to set
@@ -655,7 +663,7 @@ TEST_F(ReceiverTest, RequestsKeyFrameToRectifyPictureLoss) {
ConsumeAndVerifyFrame(key_frame);
EXPECT_CALL(*sender(), OnReceiverIndicatesPictureLoss()).Times(1);
receiver()->RequestKeyFrame();
- RunTasksUntilIdle();
+ AdvanceClockAndRunTasks(kOneWayNetworkDelay);
testing::Mock::VerifyAndClearExpectations(sender());
}
@@ -663,11 +671,8 @@ TEST_F(ReceiverTest, RequestsKeyFrameToRectifyPictureLoss) {
// full (i.e., when the consumer is not pulling them out of the queue). Since
// the Receiver will stop ACK'ing frames, the Sender will become stalled.
TEST_F(ReceiverTest, EatsItsFill) {
- // Send the initial Sender Report with lip-sync timing information to
- // "unblock" the Receiver.
const Clock::time_point start_time = FakeClock::now();
- sender()->SendSenderReport(start_time, SimulatedFrame::GetRtpStartTime());
- AdvanceClockAndRunTasks(kOneWayNetworkDelay);
+ ExchangeInitialReportPackets();
// Send and Receive the maximum possible number of frames in-order, normally.
for (int i = 0; i < kMaxUnackedFrames; ++i) {
@@ -678,7 +683,7 @@ TEST_F(ReceiverTest, EatsItsFill) {
.Times(1);
sender()->SetFrameBeingSent(SimulatedFrame(start_time, i));
sender()->SendRtpPackets(sender()->GetAllPacketIds(0));
- RunTasksUntilIdle();
+ AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration);
testing::Mock::VerifyAndClearExpectations(sender());
testing::Mock::VerifyAndClearExpectations(consumer());
}
@@ -692,11 +697,11 @@ TEST_F(ReceiverTest, EatsItsFill) {
EXPECT_CALL(*sender(),
OnReceiverCheckpoint(FrameId::first() + (ignored_frame - 1),
kTargetPlayoutDelayChange))
- .Times(AtLeast(1));
+ .Times(AtLeast(0));
EXPECT_CALL(*sender(), OnReceiverIsMissingPackets(_)).Times(0);
sender()->SetFrameBeingSent(SimulatedFrame(start_time, ignored_frame));
sender()->SendRtpPackets(sender()->GetAllPacketIds(0));
- AdvanceClockAndRunTasks(kRtcpReportInterval);
+ AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration);
testing::Mock::VerifyAndClearExpectations(sender());
testing::Mock::VerifyAndClearExpectations(consumer());
}
@@ -706,7 +711,7 @@ TEST_F(ReceiverTest, EatsItsFill) {
ConsumeAndVerifyFrames(0, 0, start_time);
int no_longer_ignored_frame = ignored_frame;
++ignored_frame;
- EXPECT_CALL(*consumer(), OnFramesReady(Gt(0))).Times(1);
+ EXPECT_CALL(*consumer(), OnFramesReady(Gt(0))).Times(AtLeast(1));
EXPECT_CALL(*sender(),
OnReceiverCheckpoint(FrameId::first() + no_longer_ignored_frame,
kTargetPlayoutDelayChange))
@@ -716,10 +721,11 @@ TEST_F(ReceiverTest, EatsItsFill) {
sender()->SetFrameBeingSent(
SimulatedFrame(start_time, no_longer_ignored_frame));
sender()->SendRtpPackets(sender()->GetAllPacketIds(0));
+ AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration);
// This second frame should be ignored, however.
sender()->SetFrameBeingSent(SimulatedFrame(start_time, ignored_frame));
sender()->SendRtpPackets(sender()->GetAllPacketIds(0));
- AdvanceClockAndRunTasks(kRtcpReportInterval);
+ AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration);
testing::Mock::VerifyAndClearExpectations(sender());
testing::Mock::VerifyAndClearExpectations(consumer());
}
@@ -728,11 +734,8 @@ TEST_F(ReceiverTest, EatsItsFill) {
// but only as inter-frame data dependency requirements permit, and only if no
// target playout delay change information would have been missed.
TEST_F(ReceiverTest, DropsLateFrames) {
- // Send the initial Sender Report with lip-sync timing information to
- // "unblock" the Receiver.
const Clock::time_point start_time = FakeClock::now();
- sender()->SendSenderReport(start_time, SimulatedFrame::GetRtpStartTime());
- AdvanceClockAndRunTasks(kOneWayNetworkDelay);
+ ExchangeInitialReportPackets();
// Before any packets have been sent/received, the Receiver should indicate no
// frames are ready.
@@ -766,8 +769,8 @@ TEST_F(ReceiverTest, DropsLateFrames) {
// is not exercising the logic meaningfully.
ASSERT_LE(size_t{3}, sender()->GetAllPacketIds(0).size());
sender()->SendRtpPackets({FramePacketId{1}});
+ AdvanceClockAndRunTasks(SimulatedFrame::kFrameDuration);
}
- RunTasksUntilIdle();
testing::Mock::VerifyAndClearExpectations(consumer());
testing::Mock::VerifyAndClearExpectations(sender());
EXPECT_EQ(Receiver::kNoFramesReady, receiver()->AdvanceToNextFrame());
@@ -782,7 +785,7 @@ TEST_F(ReceiverTest, DropsLateFrames) {
sender()->SetFrameBeingSent(frames[i]);
sender()->SendRtpPackets(sender()->GetAllPacketIds(0));
}
- RunTasksUntilIdle();
+ AdvanceClockAndRunTasks(kRoundTripNetworkDelay);
testing::Mock::VerifyAndClearExpectations(consumer());
testing::Mock::VerifyAndClearExpectations(sender());
EXPECT_EQ(Receiver::kNoFramesReady, receiver()->AdvanceToNextFrame());
@@ -800,7 +803,7 @@ TEST_F(ReceiverTest, DropsLateFrames) {
sender()->SetFrameBeingSent(frames[i]);
sender()->SendRtpPackets({FramePacketId{0}});
}
- RunTasksUntilIdle();
+ AdvanceClockAndRunTasks(kRoundTripNetworkDelay);
testing::Mock::VerifyAndClearExpectations(consumer());
testing::Mock::VerifyAndClearExpectations(sender());
EXPECT_EQ(Receiver::kNoFramesReady, receiver()->AdvanceToNextFrame());
@@ -816,7 +819,7 @@ TEST_F(ReceiverTest, DropsLateFrames) {
.Times(1);
sender()->SetFrameBeingSent(frames[5]);
sender()->SendRtpPackets({FramePacketId{0}});
- RunTasksUntilIdle();
+ AdvanceClockAndRunTasks(kRoundTripNetworkDelay);
// Note: Consuming Frame 6 will trigger the checkpoint advancement, since the
// call to AdvanceToNextFrame() contains the frame skipping/dropping logic.
ConsumeAndVerifyFrame(frames[6]);
@@ -826,7 +829,7 @@ TEST_F(ReceiverTest, DropsLateFrames) {
// After consuming Frame 6, the Receiver knows Frame 7 is also available and
// should have scheduled an immediate task to notify the Consumer of this.
EXPECT_CALL(*consumer(), OnFramesReady(Gt(0))).Times(1);
- RunTasksUntilIdle();
+ AdvanceClockAndRunTasks(kOneWayNetworkDelay);
testing::Mock::VerifyAndClearExpectations(consumer());
// Now consume Frame 7. This shouldn't trigger any further checkpoint
@@ -834,10 +837,11 @@ TEST_F(ReceiverTest, DropsLateFrames) {
EXPECT_CALL(*consumer(), OnFramesReady(_)).Times(0);
EXPECT_CALL(*sender(), OnReceiverCheckpoint(_, _)).Times(0);
ConsumeAndVerifyFrame(frames[7]);
+ AdvanceClockAndRunTasks(kOneWayNetworkDelay);
testing::Mock::VerifyAndClearExpectations(consumer());
testing::Mock::VerifyAndClearExpectations(sender());
}
} // namespace
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.cc b/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.cc
index 30d587a24ed..696849bb501 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.cc
@@ -9,11 +9,8 @@
#include "cast/streaming/packet_util.h"
#include "util/saturate_cast.h"
-using openscreen::saturate_cast;
-using openscreen::platform::Clock;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
RtcpCommonHeader::RtcpCommonHeader() = default;
RtcpCommonHeader::~RtcpCommonHeader() = default;
@@ -30,6 +27,9 @@ void RtcpCommonHeader::AppendFields(absl::Span<uint8_t>* buffer) const {
FieldBitmask<int>(kRtcpReportCountFieldNumBits));
byte0 |= with.report_count;
break;
+ case RtcpPacketType::kSourceDescription:
+ OSP_UNIMPLEMENTED();
+ break;
case RtcpPacketType::kApplicationDefined:
case RtcpPacketType::kPayloadSpecific:
switch (with.subtype) {
@@ -242,5 +242,5 @@ absl::optional<RtcpReportBlock> RtcpReportBlock::ParseOne(
RtcpSenderReport::RtcpSenderReport() = default;
RtcpSenderReport::~RtcpSenderReport() = default;
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.h b/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.h
index 3b6959c0e67..25e2c2ed7e5 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.h
@@ -18,8 +18,8 @@
#include "cast/streaming/rtp_time.h"
#include "cast/streaming/ssrc.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
struct RtcpCommonHeader {
RtcpCommonHeader();
@@ -116,8 +116,7 @@ struct RtcpReportBlock {
// Convenience helper to convert the given |local_clock_delay| to the
// RtcpReportBlock::Delay timebase, then clamp and assign it to
// |delay_since_last_report|.
- void SetDelaySinceLastReport(
- openscreen::platform::Clock::duration local_clock_delay);
+ void SetDelaySinceLastReport(Clock::duration local_clock_delay);
// Serializes this report block in the first |kRtcpReportBlockSize| bytes of
// the given |buffer| and adjusts |buffer| to point to the first byte after
@@ -139,7 +138,7 @@ struct RtcpSenderReport {
// common reference clock shared by all RTP streams; 2) the RTP timestamp on
// the media capture/playout timeline. Together, these are used by a Receiver
// to achieve A/V synchronization across RTP streams for playout.
- openscreen::platform::Clock::time_point reference_time{};
+ Clock::time_point reference_time{};
RtpTimeTicks rtp_timestamp;
// The total number of RTP packets transmitted since the start of the session
@@ -178,7 +177,7 @@ struct PacketNack {
}
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_RTCP_COMMON_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtcp_common_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/rtcp_common_unittest.cc
index 8d088c2c972..d593e4f02ee 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/rtcp_common_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/rtcp_common_unittest.cc
@@ -11,10 +11,8 @@
#include "gtest/gtest.h"
#include "platform/api/time.h"
-using openscreen::platform::Clock;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
template <typename T>
@@ -308,5 +306,5 @@ TEST(RtcpCommonTest, ComputesDelayForReportBlocks) {
}
} // namespace
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtcp_session.cc b/chromium/third_party/openscreen/src/cast/streaming/rtcp_session.cc
index a8b8ab6a0ef..c9bf34ffbbf 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/rtcp_session.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/rtcp_session.cc
@@ -6,12 +6,12 @@
#include "util/logging.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
RtcpSession::RtcpSession(Ssrc sender_ssrc,
Ssrc receiver_ssrc,
- openscreen::platform::Clock::time_point start_time)
+ Clock::time_point start_time)
: sender_ssrc_(sender_ssrc),
receiver_ssrc_(receiver_ssrc),
ntp_converter_(start_time) {
@@ -22,5 +22,5 @@ RtcpSession::RtcpSession(Ssrc sender_ssrc,
RtcpSession::~RtcpSession() = default;
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtcp_session.h b/chromium/third_party/openscreen/src/cast/streaming/rtcp_session.h
index d896cb946bd..0c0b6a45364 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/rtcp_session.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/rtcp_session.h
@@ -8,8 +8,8 @@
#include "cast/streaming/ntp_time.h"
#include "cast/streaming/ssrc.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Session-level configuration and shared components for the RTCP messaging
// associated with a single Cast RTP stream. Multiple packet serialization and
@@ -21,7 +21,7 @@ class RtcpSession {
// world" wall time.
RtcpSession(Ssrc sender_ssrc,
Ssrc receiver_ssrc,
- openscreen::platform::Clock::time_point start_time);
+ Clock::time_point start_time);
~RtcpSession();
Ssrc sender_ssrc() const { return sender_ssrc_; }
@@ -36,7 +36,7 @@ class RtcpSession {
NtpTimeConverter ntp_converter_;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_RTCP_SESSION_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_defines.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_defines.cc
index 0c3591299af..215b9cfcec0 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/rtp_defines.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_defines.cc
@@ -4,8 +4,8 @@
#include "cast/streaming/rtp_defines.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
bool IsRtpPayloadType(uint8_t raw_byte) {
switch (static_cast<RtpPayloadType>(raw_byte)) {
@@ -31,6 +31,7 @@ bool IsRtcpPacketType(uint8_t raw_byte) {
switch (static_cast<RtcpPacketType>(raw_byte)) {
case RtcpPacketType::kSenderReport:
case RtcpPacketType::kReceiverReport:
+ case RtcpPacketType::kSourceDescription:
case RtcpPacketType::kApplicationDefined:
case RtcpPacketType::kPayloadSpecific:
case RtcpPacketType::kExtendedReports:
@@ -42,5 +43,5 @@ bool IsRtcpPacketType(uint8_t raw_byte) {
return false;
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_defines.h b/chromium/third_party/openscreen/src/cast/streaming/rtp_defines.h
index 94294cd6394..335c06a04d1 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/rtp_defines.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_defines.h
@@ -7,8 +7,8 @@
#include <stdint.h>
+namespace openscreen {
namespace cast {
-namespace streaming {
// Note: Cast Streaming uses a subset of the messages in the RTP/RTCP
// specification, but also adds some of its own extensions. See:
@@ -90,6 +90,7 @@ enum class RtpPayloadType : uint8_t {
kVideoVp8 = 100,
kVideoH264 = 101,
kVideoVarious = 102, // Codec being used is not fixed.
+ kVideoLast = 102,
// Some AndroidTV receivers require the payload type for audio to be 127, and
// video to be 96; regardless of the codecs actually being used. This is
@@ -141,6 +142,7 @@ enum class RtcpPacketType : uint8_t {
kSenderReport = 200,
kReceiverReport = 201,
+ kSourceDescription = 202,
kApplicationDefined = 204,
kPayloadSpecific = 206,
kExtendedReports = 207,
@@ -338,7 +340,7 @@ constexpr int kRtcpReceiverReferenceTimeReportBlockSize = 8;
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
constexpr int kRtcpPictureLossIndicatorHeaderSize = 8;
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_RTP_DEFINES_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.cc
index 5fbf005abeb..b7a838e3923 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.cc
@@ -10,10 +10,8 @@
#include "cast/streaming/packet_util.h"
#include "util/logging.h"
-using openscreen::ReadBigEndian;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
RtpPacketParser::RtpPacketParser(Ssrc sender_ssrc)
: sender_ssrc_(sender_ssrc), highest_rtp_frame_id_(FrameId::first()) {}
@@ -114,5 +112,5 @@ absl::optional<RtpPacketParser::ParseResult> RtpPacketParser::Parse(
RtpPacketParser::ParseResult::ParseResult() = default;
RtpPacketParser::ParseResult::~ParseResult() = default;
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.h b/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.h
index b2be4c5394b..b8ce126f42d 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser.h
@@ -14,8 +14,8 @@
#include "cast/streaming/rtp_time.h"
#include "cast/streaming/ssrc.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Parses RTP packets for all frames in the same Cast RTP stream. One
// RtpPacketParser instance should be used for all RTP packets having the same
@@ -72,7 +72,7 @@ class RtpPacketParser {
FrameId highest_rtp_frame_id_;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_RTP_PACKET_PARSER_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_fuzzer.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_fuzzer.cc
index 24169ce7a7a..73bde613e78 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_fuzzer.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_fuzzer.cc
@@ -7,8 +7,8 @@
#include "cast/streaming/rtp_packet_parser.h"
extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
- using cast::streaming::RtpPacketParser;
- using cast::streaming::Ssrc;
+ using openscreen::cast::RtpPacketParser;
+ using openscreen::cast::Ssrc;
constexpr Ssrc kSenderSsrcInSeedCorpus = 0x01020304;
RtpPacketParser parser(kSenderSsrcInSeedCorpus);
diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_unittest.cc
index 927521c001c..5f9db875560 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_packet_parser_unittest.cc
@@ -8,11 +8,8 @@
#include "gtest/gtest.h"
#include "util/big_endian.h"
-using openscreen::ReadBigEndian;
-using openscreen::WriteBigEndian;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
// Tests that a simple packet for a key frame can be parsed.
@@ -309,5 +306,5 @@ TEST(RtpPacketParserTest, RejectsPacketWithBadFramePacketIds) {
}
} // namespace
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.cc
index 675a4a97cc5..610791fcb43 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.cc
@@ -14,10 +14,8 @@
#include "util/integer_division.h"
#include "util/logging.h"
-using openscreen::platform::Clock;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
@@ -124,7 +122,7 @@ absl::Span<uint8_t> RtpPacketizer::GeneratePacket(const EncryptedFrame& frame,
int RtpPacketizer::ComputeNumberOfPackets(const EncryptedFrame& frame) const {
// The total number of packets is computed by assuming the payload will be
// split-up across as few packets as possible.
- int num_packets = openscreen::DividePositivesRoundingUp(
+ int num_packets = DividePositivesRoundingUp(
static_cast<int>(frame.data.size()), max_payload_size());
// Edge case: There must always be at least one packet, even when there are no
// payload bytes. Some audio codecs, for example, use zero bytes to represent
@@ -135,5 +133,5 @@ int RtpPacketizer::ComputeNumberOfPackets(const EncryptedFrame& frame) const {
return num_packets <= int{kMaxAllowedFramePacketId} ? num_packets : -1;
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.h b/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.h
index a3811948b5f..7e4f59b77bc 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer.h
@@ -12,8 +12,8 @@
#include "cast/streaming/rtp_defines.h"
#include "cast/streaming/ssrc.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Transforms a logical sequence of EncryptedFrames into RTP packets for
// transmission. A single instance of RtpPacketizer should be used for all the
@@ -47,6 +47,15 @@ class RtpPacketizer {
// packetized.
int ComputeNumberOfPackets(const EncryptedFrame& frame) const;
+ // See rtp_defines.h for wire-format diagram.
+ static constexpr int kBaseRtpHeaderSize =
+ // Plus one byte, because this implementation always includes the 8-bit
+ // Reference Frame ID field.
+ kRtpPacketMinValidSize + 1;
+ static constexpr int kAdaptiveLatencyHeaderSize = 4;
+ static constexpr int kMaxRtpHeaderSize =
+ kBaseRtpHeaderSize + kAdaptiveLatencyHeaderSize;
+
private:
int max_payload_size() const {
// Start with the configured max packet size, then subtract reserved space
@@ -64,18 +73,9 @@ class RtpPacketizer {
// re-transmitted, must have different sequence numbers (within wrap-around
// concerns) per the RTP spec.
uint16_t sequence_number_;
-
- // See rtp_defines.h for wire-format diagram.
- static constexpr int kBaseRtpHeaderSize =
- // Plus one byte, because this implementation always includes the 8-bit
- // Reference Frame ID field.
- kRtpPacketMinValidSize + 1;
- static constexpr int kAdaptiveLatencyHeaderSize = 4;
- static constexpr int kMaxRtpHeaderSize =
- kBaseRtpHeaderSize + kAdaptiveLatencyHeaderSize;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_RTP_PACKETIZER_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer_unittest.cc
index 52b460de2dd..d99eefa011e 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer_unittest.cc
@@ -11,8 +11,8 @@
#include "cast/streaming/ssrc.h"
#include "gtest/gtest.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
constexpr RtpPayloadType kPayloadType = RtpPayloadType::kAudioOpus;
@@ -42,7 +42,7 @@ class RtpPacketizerTest : public testing::Test {
frame.frame_id = frame_id;
frame.referenced_frame_id = is_key_frame ? frame_id : (frame_id - 1);
frame.rtp_timestamp = RtpTimeTicks() + RtpTimeDelta::FromTicks(987);
- frame.reference_time = openscreen::platform::Clock::now();
+ frame.reference_time = Clock::now();
frame.new_playout_delay = new_playout_delay;
std::unique_ptr<uint8_t[]> buffer(new uint8_t[payload_size]);
@@ -203,5 +203,5 @@ TEST_F(RtpPacketizerTest, GeneratesPacketForRetransmission) {
}
} // namespace
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_time.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_time.cc
index 3f73715eb57..ee6078d45aa 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/rtp_time.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_time.cc
@@ -6,8 +6,8 @@
#include <sstream>
+namespace openscreen {
namespace cast {
-namespace streaming {
std::ostream& operator<<(std::ostream& out, const RtpTimeDelta rhs) {
if (rhs.value_ >= 0)
@@ -21,5 +21,5 @@ std::ostream& operator<<(std::ostream& out, const RtpTimeTicks rhs) {
return out << "RTP@" << rhs.value_;
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_time.h b/chromium/third_party/openscreen/src/cast/streaming/rtp_time.h
index 0a7f5d86458..536e7dd3ed1 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/rtp_time.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_time.h
@@ -11,12 +11,14 @@
#include <cmath>
#include <limits>
#include <sstream>
+#include <type_traits>
#include "cast/streaming/expanded_value_base.h"
#include "platform/api/time.h"
+#include "util/saturate_cast.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Forward declarations (see below).
class RtpTimeDelta;
@@ -155,25 +157,15 @@ class RtpTimeDelta : public ExpandedValueBase<int64_t, RtpTimeDelta> {
constexpr int64_t value() const { return value_; }
template <typename Rep>
- static Rep ToNearestRepresentativeValue(double ticks) {
- if (ticks <= std::numeric_limits<Rep>::min()) {
- return std::numeric_limits<Rep>::min();
- } else if (ticks >= std::numeric_limits<Rep>::max()) {
- return std::numeric_limits<Rep>::max();
- }
+ static std::enable_if_t<std::is_floating_point<Rep>::value, Rep>
+ ToNearestRepresentativeValue(double ticks) {
+ return Rep(ticks);
+ }
- static_assert(
- std::is_floating_point<Rep>::value ||
- (std::is_integral<Rep>::value &&
- sizeof(Rep) <= sizeof(decltype(llround(ticks)))),
- "Rep must be an integer (<= 64 bits) or a floating-point type.");
- if (std::is_floating_point<Rep>::value) {
- return Rep(ticks);
- }
- if (sizeof(Rep) <= sizeof(decltype(lround(ticks)))) {
- return Rep(lround(ticks));
- }
- return Rep(llround(ticks));
+ template <typename Rep>
+ static std::enable_if_t<std::is_integral<Rep>::value, Rep>
+ ToNearestRepresentativeValue(double ticks) {
+ return rounded_saturate_cast<Rep>(ticks);
}
};
@@ -262,7 +254,7 @@ class RtpTimeTicks : public ExpandedValueBase<int64_t, RtpTimeTicks> {
constexpr int64_t value() const { return value_; }
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_RTP_TIME_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_time_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_time_unittest.cc
index 44312e7bfaa..8427a5ea201 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/rtp_time_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_time_unittest.cc
@@ -6,8 +6,8 @@
#include "gtest/gtest.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Tests that conversions between std::chrono durations and RtpTimeDelta are
// accurate. Note that this implicitly tests the conversions to/from
@@ -71,5 +71,5 @@ TEST(RtpTimeDeltaTest, ConversionToAndFromDurations) {
.ToDuration<microseconds>(kTimebase));
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender.cc b/chromium/third_party/openscreen/src/cast/streaming/sender.cc
new file mode 100644
index 00000000000..fa3268a9f68
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/sender.cc
@@ -0,0 +1,541 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/streaming/sender.h"
+
+#include <algorithm>
+#include <ratio> // NOLINT
+
+#include "cast/streaming/session_config.h"
+#include "util/logging.h"
+#include "util/std_util.h"
+
+namespace openscreen {
+namespace cast {
+
+using std::chrono::duration_cast;
+using std::chrono::microseconds;
+using std::chrono::milliseconds;
+
+using openscreen::operator<<; // For std::chrono::duration logging.
+
+Sender::Sender(Environment* environment,
+ SenderPacketRouter* packet_router,
+ const SessionConfig& config,
+ RtpPayloadType rtp_payload_type)
+ : packet_router_(packet_router),
+ rtcp_session_(config.sender_ssrc,
+ config.receiver_ssrc,
+ environment->now()),
+ rtcp_parser_(&rtcp_session_, this),
+ sender_report_builder_(&rtcp_session_),
+ rtp_packetizer_(rtp_payload_type,
+ config.sender_ssrc,
+ packet_router_->max_packet_size()),
+ rtp_timebase_(config.rtp_timebase),
+ crypto_(config.aes_secret_key, config.aes_iv_mask),
+ target_playout_delay_(config.target_playout_delay) {
+ OSP_DCHECK(packet_router_);
+ OSP_DCHECK_NE(rtcp_session_.sender_ssrc(), rtcp_session_.receiver_ssrc());
+ OSP_DCHECK_GT(rtp_timebase_, 0);
+ OSP_DCHECK(target_playout_delay_ > milliseconds::zero());
+
+ pending_sender_report_.reference_time = SenderPacketRouter::kNever;
+
+ packet_router_->OnSenderCreated(rtcp_session_.receiver_ssrc(), this);
+}
+
+Sender::~Sender() {
+ packet_router_->OnSenderDestroyed(rtcp_session_.receiver_ssrc());
+}
+
+void Sender::SetObserver(Sender::Observer* observer) {
+ observer_ = observer;
+}
+
+int Sender::GetInFlightFrameCount() const {
+ return num_frames_in_flight_;
+}
+
+Clock::duration Sender::GetInFlightMediaDuration(
+ RtpTimeTicks next_frame_rtp_timestamp) const {
+ if (num_frames_in_flight_ == 0) {
+ return Clock::duration::zero(); // No frames are currently in-flight.
+ }
+
+ const PendingFrameSlot& oldest_slot = *get_slot_for(checkpoint_frame_id_ + 1);
+ // Note: The oldest slot's frame cannot have been canceled because the
+ // protocol does not allow ACK'ing this particular frame without also moving
+ // the checkpoint forward. See "CST2 feedback" discussion in rtp_defines.h.
+ OSP_DCHECK(oldest_slot.is_active_for_frame(checkpoint_frame_id_ + 1));
+
+ return (next_frame_rtp_timestamp - oldest_slot.frame->rtp_timestamp)
+ .ToDuration<Clock::duration>(rtp_timebase_);
+}
+
+Clock::duration Sender::GetMaxInFlightMediaDuration() const {
+ // Assumption: The total amount of allowed in-flight media should equal the
+ // half of the playout delay window, plus the amount of time it takes to
+ // receive an ACK from the Receiver.
+ //
+ // Why half of the playout delay window? It's assumed here that capture and
+ // media encoding, which occur before EnqueueFrame() is called, are executing
+ // within the first half of the playout delay window. This leaves the second
+ // half for executing all network transmits/re-transmits, plus decoding and
+ // play-out at the Receiver.
+ return (target_playout_delay_ / 2) + (round_trip_time_ / 2);
+}
+
+bool Sender::NeedsKeyFrame() const {
+ return last_enqueued_key_frame_id_ <= picture_lost_at_frame_id_;
+}
+
+FrameId Sender::GetNextFrameId() const {
+ return last_enqueued_frame_id_ + 1;
+}
+
+Sender::EnqueueFrameResult Sender::EnqueueFrame(const EncodedFrame& frame) {
+ // Assume the fields of the |frame| have all been set correctly, with
+ // monotonically increasing timestamps and a valid pointer to the data.
+ OSP_DCHECK_EQ(frame.frame_id, GetNextFrameId());
+ OSP_DCHECK_GE(frame.referenced_frame_id, FrameId::first());
+ if (frame.frame_id != FrameId::first()) {
+ OSP_DCHECK_GT(frame.rtp_timestamp, pending_sender_report_.rtp_timestamp);
+ OSP_DCHECK_GT(frame.reference_time, pending_sender_report_.reference_time);
+ }
+ OSP_DCHECK(frame.data.data());
+
+ // Check whether enqueuing the frame would exceed the design limit for the
+ // span of FrameIds. Even if |num_frames_in_flight_| is less than
+ // kMaxUnackedFrames, it's the span of FrameIds that is restricted.
+ if ((frame.frame_id - checkpoint_frame_id_) > kMaxUnackedFrames) {
+ return REACHED_ID_SPAN_LIMIT;
+ }
+
+ // Check whether enqueuing the frame would exceed the current maximum media
+ // duration limit.
+ if (GetInFlightMediaDuration(frame.rtp_timestamp) >
+ GetMaxInFlightMediaDuration()) {
+ return MAX_DURATION_IN_FLIGHT;
+ }
+
+ // Encrypt the frame and initialize the slot tracking its sending.
+ PendingFrameSlot* const slot = get_slot_for(frame.frame_id);
+ OSP_DCHECK(!slot->frame);
+ slot->frame = crypto_.Encrypt(frame);
+ const int packet_count = rtp_packetizer_.ComputeNumberOfPackets(*slot->frame);
+ if (packet_count <= 0) {
+ slot->frame.reset();
+ return PAYLOAD_TOO_LARGE;
+ }
+ slot->send_flags.Resize(packet_count, YetAnotherBitVector::SET);
+ slot->packet_sent_times.assign(packet_count, SenderPacketRouter::kNever);
+
+ // Officially record the "enqueue."
+ ++num_frames_in_flight_;
+ last_enqueued_frame_id_ = slot->frame->frame_id;
+ OSP_DCHECK_LE(num_frames_in_flight_,
+ last_enqueued_frame_id_ - checkpoint_frame_id_);
+ if (slot->frame->dependency == EncodedFrame::KEY_FRAME) {
+ last_enqueued_key_frame_id_ = slot->frame->frame_id;
+ }
+
+ // Update the target playout delay, if necessary.
+ if (slot->frame->new_playout_delay > milliseconds::zero()) {
+ target_playout_delay_ = slot->frame->new_playout_delay;
+ playout_delay_change_at_frame_id_ = slot->frame->frame_id;
+ }
+
+ // Update the lip-sync information for the next Sender Report.
+ pending_sender_report_.reference_time = slot->frame->reference_time;
+ pending_sender_report_.rtp_timestamp = slot->frame->rtp_timestamp;
+
+ // If the round trip time hasn't been computed yet, immediately send a RTCP
+ // packet (i.e., before the RTP packets are sent). The RTCP packet will
+ // provide a Sender Report which contains the required lip-sync information
+ // the Receiver needs for timing the media playout.
+ //
+ // Detail: Working backwards, if the round trip time is not known, then this
+ // Sender has never processed a Receiver Report. Thus, the Receiver has never
+ // provided a Receiver Report, which it can only do after having processed a
+ // Sender Report from this Sender. Thus, this Sender really needs to send
+ // that, right now!
+ if (round_trip_time_ == Clock::duration::zero()) {
+ packet_router_->RequestRtcpSend(rtcp_session_.receiver_ssrc());
+ }
+
+ // Re-activate RTP sending if it was suspended.
+ packet_router_->RequestRtpSend(rtcp_session_.receiver_ssrc());
+
+ return OK;
+}
+
+void Sender::OnReceivedRtcpPacket(Clock::time_point arrival_time,
+ absl::Span<const uint8_t> packet) {
+ rtcp_packet_arrival_time_ = arrival_time;
+ // This call to Parse() invoke zero or more of the OnReceiverXYZ() methods in
+ // the current call stack:
+ if (rtcp_parser_.Parse(packet, last_enqueued_frame_id_)) {
+ packet_router_->OnRtcpReceived(arrival_time, round_trip_time_);
+ }
+}
+
+absl::Span<uint8_t> Sender::GetRtcpPacketForImmediateSend(
+ Clock::time_point send_time,
+ absl::Span<uint8_t> buffer) {
+ if (pending_sender_report_.reference_time == SenderPacketRouter::kNever) {
+ // Cannot send a report if one is not available (i.e., a frame has never
+ // been enqueued).
+ return buffer.subspan(0, 0);
+ }
+
+ // The Sender Report to be sent is a snapshot of the "pending Sender Report,"
+ // but with its timestamp fields modified. First, the reference time is set to
+ // the RTCP packet's send time. Then, the corresponding RTP timestamp is
+ // translated to match (for lip-sync).
+ RtcpSenderReport sender_report = pending_sender_report_;
+ sender_report.reference_time = send_time;
+ sender_report.rtp_timestamp += RtpTimeDelta::FromDuration(
+ sender_report.reference_time - pending_sender_report_.reference_time,
+ rtp_timebase_);
+
+ return sender_report_builder_.BuildPacket(sender_report, buffer).first;
+}
+
+absl::Span<uint8_t> Sender::GetRtpPacketForImmediateSend(
+ Clock::time_point send_time,
+ absl::Span<uint8_t> buffer) {
+ ChosenPacket chosen = ChooseNextRtpPacketNeedingSend();
+
+ // If no packets need sending (i.e., all packets have been sent at least once
+ // and do not need to be re-sent yet), check whether a Kickstart packet should
+ // be sent. It's possible that there has been complete packet loss of some
+ // frames, and the Receiver may not be aware of the existence of the latest
+ // frame(s). Kickstarting is the only way the Receiver can discover the newer
+ // frames it doesn't know about.
+ if (!chosen) {
+ const ChosenPacketAndWhen kickstart = ChooseKickstartPacket();
+ if (kickstart.when > send_time) {
+ // Nothing to send, so return "empty" signal to the packet router. The
+ // packet router will suspend RTP sending until this Sender explicitly
+ // resumes it.
+ return buffer.subspan(0, 0);
+ }
+ chosen = kickstart;
+ OSP_DCHECK(chosen);
+ }
+
+ const absl::Span<uint8_t> result = rtp_packetizer_.GeneratePacket(
+ *chosen.slot->frame, chosen.packet_id, buffer);
+ chosen.slot->send_flags.Clear(chosen.packet_id);
+ chosen.slot->packet_sent_times[chosen.packet_id] = send_time;
+
+ ++pending_sender_report_.send_packet_count;
+ // According to RFC3550, the octet count does not include the RTP header. The
+ // following is just a good approximation, however, because the header size
+ // will very infrequently be 4 bytes greater (see
+ // RtpPacketizer::kAdaptiveLatencyHeaderSize). No known Cast Streaming
+ // Receiver implementations use this for anything, and so this should be fine.
+ const int approximate_octet_count =
+ static_cast<int>(result.size()) - RtpPacketizer::kBaseRtpHeaderSize;
+ OSP_DCHECK_GE(approximate_octet_count, 0);
+ pending_sender_report_.send_octet_count += approximate_octet_count;
+
+ return result;
+}
+
+Clock::time_point Sender::GetRtpResumeTime() {
+ if (ChooseNextRtpPacketNeedingSend()) {
+ return Alarm::kImmediately;
+ }
+ return ChooseKickstartPacket().when;
+}
+
+void Sender::OnReceiverReferenceTimeAdvanced(Clock::time_point reference_time) {
+ // Not used.
+}
+
+void Sender::OnReceiverReport(const RtcpReportBlock& receiver_report) {
+ OSP_DCHECK_NE(rtcp_packet_arrival_time_, SenderPacketRouter::kNever);
+
+ const Clock::duration total_delay =
+ rtcp_packet_arrival_time_ -
+ sender_report_builder_.GetRecentReportTime(
+ receiver_report.last_status_report_id, rtcp_packet_arrival_time_);
+ const auto non_network_delay =
+ duration_cast<Clock::duration>(receiver_report.delay_since_last_report);
+
+ // Round trip time measurement: This is the time elapsed since the Sender
+ // Report was sent, minus the time the Receiver did other stuff before sending
+ // the Receiver Report back.
+ //
+ // If the round trip time seems to be less than or equal to zero, assume clock
+ // imprecision by one or both peers caused a bad value to be calculated. The
+ // true value is likely very close to zero (i.e., this is ideal network
+ // behavior); and so just represent this as 75 µs, an optimistic
+ // wired-Ethernet LAN ping time.
+ constexpr auto kNearZeroRoundTripTime =
+ duration_cast<Clock::duration>(microseconds(75));
+ static_assert(kNearZeroRoundTripTime > Clock::duration::zero(),
+ "More precision in Clock::duration needed!");
+ const Clock::duration measurement =
+ std::max(total_delay - non_network_delay, kNearZeroRoundTripTime);
+
+ // Validate the measurement by using the current target playout delay as a
+ // "reasonable upper-bound." It's certainly possible that the actual network
+ // round-trip time could exceed the target playout delay, but that would mean
+ // the current network performance is totally inadequate for streaming anyway.
+ if (measurement > target_playout_delay_) {
+ OSP_LOG_WARN << "Invalidating a round-trip time measurement ("
+ << measurement
+ << ") since it exceeds the current target playout delay ("
+ << target_playout_delay_ << ").";
+ return;
+ }
+
+ // Measurements will typically have high variance. Use a simple smoothing
+ // filter to track a short-term average that changes less drastically.
+ if (round_trip_time_ == Clock::duration::zero()) {
+ round_trip_time_ = measurement;
+ } else {
+ // Arbitrary constant, to provide 1/8 weight to the new measurement, and 7/8
+ // weight to the old estimate, which seems to work well for de-noising the
+ // estimate.
+ constexpr int kInertia = 7;
+ round_trip_time_ =
+ (kInertia * round_trip_time_ + measurement) / (kInertia + 1);
+ }
+ // TODO(miu): Add tracing event here to note the updated RTT.
+}
+
+void Sender::OnReceiverIndicatesPictureLoss() {
+ // The Receiver will continue the PLI notifications until it has received a
+ // key frame. Thus, if a key frame is already in-flight, don't make a state
+ // change that would cause this Sender to force another expensive key frame.
+ if (checkpoint_frame_id_ < last_enqueued_key_frame_id_) {
+ return;
+ }
+
+ picture_lost_at_frame_id_ = checkpoint_frame_id_;
+
+ if (observer_) {
+ observer_->OnPictureLost();
+ }
+
+ // Note: It may seem that all pending frames should be canceled until
+ // EnqueueFrame() is called with a key frame. However:
+ //
+ // 1. The Receiver should still be the main authority on what frames/packets
+ // are being ACK'ed and NACK'ed.
+ //
+ // 2. It may be desirable for the Receiver to be "limping along" in the
+ // meantime. For example, video may be corrupted but mostly watchable,
+ // and so it's best for the Sender to continue sending the non-key frames
+ // until the Receiver indicates otherwise.
+}
+
+void Sender::OnReceiverCheckpoint(FrameId frame_id,
+ milliseconds playout_delay) {
+ if (frame_id > last_enqueued_frame_id_) {
+ OSP_LOG_ERROR
+ << "Ignoring checkpoint for " << latest_expected_frame_id_
+ << " because this Sender could not have sent any frames after "
+ << last_enqueued_frame_id_ << '.';
+ return;
+ }
+ // CompoundRtcpParser should guarantee this:
+ OSP_DCHECK(playout_delay >= milliseconds::zero());
+
+ while (checkpoint_frame_id_ < frame_id) {
+ ++checkpoint_frame_id_;
+ CancelPendingFrame(checkpoint_frame_id_);
+ }
+ latest_expected_frame_id_ = std::max(latest_expected_frame_id_, frame_id);
+
+ if (playout_delay != target_playout_delay_ &&
+ frame_id >= playout_delay_change_at_frame_id_) {
+ OSP_LOG_WARN << "Sender's target playout delay (" << target_playout_delay_
+ << ") disagrees with the Receiver's (" << playout_delay << ")";
+ }
+}
+
+void Sender::OnReceiverHasFrames(std::vector<FrameId> acks) {
+ OSP_DCHECK(!acks.empty() && AreElementsSortedAndUnique(acks));
+
+ if (acks.back() > last_enqueued_frame_id_) {
+ OSP_LOG_ERROR << "Ignoring individual frame ACKs: ACKing frame "
+ << latest_expected_frame_id_
+ << " is invalid because this Sender could not have sent any "
+ "frames after "
+ << last_enqueued_frame_id_ << '.';
+ return;
+ }
+
+ for (FrameId id : acks) {
+ CancelPendingFrame(id);
+ }
+ latest_expected_frame_id_ = std::max(latest_expected_frame_id_, acks.back());
+}
+
+void Sender::OnReceiverIsMissingPackets(std::vector<PacketNack> nacks) {
+ OSP_DCHECK(!nacks.empty() && AreElementsSortedAndUnique(nacks));
+ OSP_DCHECK_NE(rtcp_packet_arrival_time_, SenderPacketRouter::kNever);
+
+ // This is a point-in-time threshold that indicates whether each NACK will
+ // trigger a packet retransmit. The threshold is based on the network round
+ // trip time because a Receiver's NACK may have been issued while the needed
+ // packet was in-flight from the Sender. In such cases, the Receiver's NACK is
+ // likely stale and this Sender should not redundantly re-transmit the packet
+ // again.
+ const Clock::time_point too_recent_a_send_time =
+ rtcp_packet_arrival_time_ - round_trip_time_;
+
+ // Iterate over all the NACKs...
+ bool need_to_send = false;
+ for (auto nack_it = nacks.begin(); nack_it != nacks.end();) {
+ // Find the slot associated with the NACK's frame ID.
+ const FrameId frame_id = nack_it->frame_id;
+ PendingFrameSlot* slot = nullptr;
+ if (frame_id <= last_enqueued_frame_id_) {
+ PendingFrameSlot* const candidate_slot = get_slot_for(frame_id);
+ if (candidate_slot->is_active_for_frame(frame_id)) {
+ slot = candidate_slot;
+ }
+ }
+
+ // If no slot was found (i.e., the NACK is invalid) for the frame, skip-over
+ // all other NACKs for the same frame. While it seems to be a bug that the
+ // Receiver would attempt to NACK a frame that does not yet exist, this can
+ // happen in rare cases where RTCP packets arrive out-of-order (i.e., the
+ // network shuffled them).
+ if (!slot) {
+ // TODO(miu): Add tracing event here to record this.
+ for (++nack_it; nack_it != nacks.end() && nack_it->frame_id == frame_id;
+ ++nack_it)
+ ;
+ continue;
+ }
+
+ latest_expected_frame_id_ = std::max(latest_expected_frame_id_, frame_id);
+
+ const auto HandleIndividualNack = [&](FramePacketId packet_id) {
+ if (slot->packet_sent_times[packet_id] <= too_recent_a_send_time) {
+ slot->send_flags.Set(packet_id);
+ need_to_send = true;
+ }
+ };
+ const FramePacketId range_end = slot->packet_sent_times.size();
+ if (nack_it->packet_id == kAllPacketsLost) {
+ for (FramePacketId packet_id = 0; packet_id < range_end; ++packet_id) {
+ HandleIndividualNack(packet_id);
+ }
+ ++nack_it;
+ } else {
+ do {
+ if (nack_it->packet_id < range_end) {
+ HandleIndividualNack(nack_it->packet_id);
+ } else {
+ OSP_LOG_WARN
+ << "Ignoring NACK for packet that doesn't exist in frame "
+ << frame_id << ": " << static_cast<int>(nack_it->packet_id);
+ }
+ ++nack_it;
+ } while (nack_it != nacks.end() && nack_it->frame_id == frame_id);
+ }
+ }
+
+ if (need_to_send) {
+ packet_router_->RequestRtpSend(rtcp_session_.receiver_ssrc());
+ }
+}
+
+Sender::ChosenPacket Sender::ChooseNextRtpPacketNeedingSend() {
+ // Find the oldest packet needing to be sent (or re-sent).
+ for (FrameId frame_id = checkpoint_frame_id_ + 1;
+ frame_id <= last_enqueued_frame_id_; ++frame_id) {
+ PendingFrameSlot* const slot = get_slot_for(frame_id);
+ if (!slot->is_active_for_frame(frame_id)) {
+ continue; // Frame was canceled. None of its packets need to be sent.
+ }
+ const FramePacketId packet_id = slot->send_flags.FindFirstSet();
+ if (packet_id < slot->send_flags.size()) {
+ return {slot, packet_id};
+ }
+ }
+
+ return {}; // Nothing needs to be sent.
+}
+
+Sender::ChosenPacketAndWhen Sender::ChooseKickstartPacket() {
+ if (latest_expected_frame_id_ >= last_enqueued_frame_id_) {
+ // Since the Receiver must know about all of the frames currently queued, no
+ // Kickstart packet is necessary.
+ return {};
+ }
+
+ // The Kickstart packet is always in the last-enqueued frame, so that the
+ // Receiver will know about every frame the Sender has. However, which packet
+ // should be chosen? Any would do, since all packets contain the frame's total
+ // packet count. For historical reasons, all sender implementations have
+ // always just sent the last packet; and so that tradition is continued here.
+ ChosenPacketAndWhen chosen;
+ chosen.slot = get_slot_for(last_enqueued_frame_id_);
+ // Note: This frame cannot have been canceled since
+ // |latest_expected_frame_id_| hasn't yet reached this point.
+ OSP_DCHECK(chosen.slot->is_active_for_frame(last_enqueued_frame_id_));
+ chosen.packet_id = chosen.slot->send_flags.size() - 1;
+
+ const Clock::time_point time_last_sent =
+ chosen.slot->packet_sent_times[chosen.packet_id];
+ // Sanity-check: This method should not be called to choose a packet while
+ // there are still unsent packets.
+ OSP_DCHECK_NE(time_last_sent, SenderPacketRouter::kNever);
+
+ // The desired Kickstart interval is a fraction of the total
+ // |target_playout_delay_|. The reason for the specific ratio here is based on
+ // lost knowledge (from legacy implementations); but it makes sense (i.e., to
+ // be a good "network citizen") to be less aggressive for larger playout delay
+ // windows, and more aggressive for shorter ones to avoid too-late packet
+ // arrivals.
+ using kWaitFraction = std::ratio<1, 20>;
+ const Clock::duration desired_kickstart_interval =
+ duration_cast<Clock::duration>(target_playout_delay_) *
+ kWaitFraction::num / kWaitFraction::den;
+ // The actual interval used is increased, if current network performance
+ // warrants waiting longer. Don't send a Kickstart packet until no NACKs
+ // have been received for two network round-trip periods.
+ constexpr int kLowerBoundRoundTrips = 2;
+ const Clock::duration kickstart_interval = std::max(
+ desired_kickstart_interval, round_trip_time_ * kLowerBoundRoundTrips);
+ chosen.when = time_last_sent + kickstart_interval;
+
+ return chosen;
+}
+
+void Sender::CancelPendingFrame(FrameId frame_id) {
+ PendingFrameSlot* const slot = get_slot_for(frame_id);
+ if (!slot->is_active_for_frame(frame_id)) {
+ return; // Frame was already canceled.
+ }
+
+ packet_router_->OnPayloadReceived(
+ slot->frame->data.size(), rtcp_packet_arrival_time_, round_trip_time_);
+
+ slot->frame.reset();
+ OSP_DCHECK_GT(num_frames_in_flight_, 0);
+ --num_frames_in_flight_;
+ if (observer_) {
+ observer_->OnFrameCanceled(frame_id);
+ }
+}
+
+void Sender::Observer::OnFrameCanceled(FrameId frame_id) {}
+void Sender::Observer::OnPictureLost() {}
+Sender::Observer::~Observer() = default;
+
+Sender::PendingFrameSlot::PendingFrameSlot() = default;
+Sender::PendingFrameSlot::~PendingFrameSlot() = default;
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender.h b/chromium/third_party/openscreen/src/cast/streaming/sender.h
new file mode 100644
index 00000000000..48f651c04de
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/sender.h
@@ -0,0 +1,318 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_STREAMING_SENDER_H_
+#define CAST_STREAMING_SENDER_H_
+
+#include <stdint.h>
+
+#include <array>
+#include <chrono> // NOLINT
+#include <vector>
+
+#include "absl/types/span.h"
+#include "cast/streaming/compound_rtcp_parser.h"
+#include "cast/streaming/constants.h"
+#include "cast/streaming/frame_crypto.h"
+#include "cast/streaming/frame_id.h"
+#include "cast/streaming/rtp_defines.h"
+#include "cast/streaming/rtp_packetizer.h"
+#include "cast/streaming/rtp_time.h"
+#include "cast/streaming/sender_packet_router.h"
+#include "cast/streaming/sender_report_builder.h"
+#include "platform/api/time.h"
+#include "util/yet_another_bit_vector.h"
+
+namespace openscreen {
+namespace cast {
+
+class Environment;
+struct SessionConfig;
+
+// The Cast Streaming Sender, a peer corresponding to some Cast Streaming
+// Receiver at the other end of a network link. See class level comments for
+// Receiver for a high-level overview.
+//
+// The Sender is the peer responsible for enqueuing EncodedFrames for streaming,
+// guaranteeing their delivery to a Receiver, and handling feedback events from
+// a Receiver. Some feedback events are used for managing the Sender's internal
+// queue of in-flight frames, requesting network packet re-transmits, etc.;
+// while others are exposed via the Sender's public interface. For example,
+// sometimes the Receiver signals that it needs a a key frame to resolve a
+// picture loss condition, and the modules upstream of the Sender (e.g., where
+// encoding happens) should call NeedsKeyFrame() to check for, and handle that.
+//
+// There are usually one or two Senders in a streaming session, one for audio
+// and one for video. Both senders work with the same SenderPacketRouter
+// instance to schedule their transmission of packets, and provide the necessary
+// metrics for estimating bandwidth utilization and availability.
+//
+// It is the responsibility of upstream code modules to handle congestion
+// control. With respect to this Sender, that means the media encoding bit rate
+// should be throttled based on network bandwidth availability. This Sender does
+// not do any throttling, only flow-control. In other words, this Sender can
+// only manage its in-flight queue of frames, and if that queue grows too large,
+// it will eventually reject further enqueuing.
+//
+// General usage: A client should check the in-flight media duration frequently
+// to decide when to pause encoding, to avoid wasting system resources on
+// encoding frames that will likely be rejected by the Sender. The client should
+// also frequently call NeedsKeyFrame() and, when this returns true, direct its
+// encoder to produce a key frame soon. Finally, when using EnqueueFrame(), an
+// EncodedFrame struct should be prepared with its frame_id field set to
+// whatever GetNextFrameId() returns. Please see method comments for
+// more-detailed usage info.
+class Sender final : public SenderPacketRouter::Sender,
+ public CompoundRtcpParser::Client {
+ public:
+ // Interface for receiving notifications about events of possible interest.
+ // Handling each of these is optional, but some may be mandatory for certain
+ // applications (see method comments below).
+ class Observer {
+ public:
+ // Called when a frame was canceled. "Canceled" means that the Receiver has
+ // either acknowledged successful receipt of the frame or has decided to
+ // skip over it. Note: Frame cancellations may occur out-of-order.
+ virtual void OnFrameCanceled(FrameId frame_id);
+
+ // Called when a Receiver begins reporting picture loss, and there is no key
+ // frame currently enqueued in the Sender. The application should enqueue a
+ // key frame as soon as possible. Note: An application that pauses frame
+ // sending (e.g., screen mirroring when the screen is not changing) should
+ // use this notification to send an out-of-band "refresh frame," encoded as
+ // a key frame.
+ virtual void OnPictureLost();
+
+ protected:
+ virtual ~Observer();
+ };
+
+ // Result codes for EnqueueFrame().
+ enum EnqueueFrameResult {
+ // The frame has been queued for sending.
+ OK,
+
+ // The frame's payload was too large. This is typically triggered when
+ // submitting a payload of several dozen megabytes or more. This result code
+ // likely indicates some kind of upstream bug.
+ PAYLOAD_TOO_LARGE,
+
+ // The span of FrameIds is too large. Cast Streaming's protocol design
+ // imposes a limit in the maximum difference between the highest-valued
+ // in-flight FrameId and the least-valued one.
+ REACHED_ID_SPAN_LIMIT,
+
+ // Too-large a media duration is in-flight. Enqueuing another frame would
+ // automatically cause late play-out at the Receiver.
+ MAX_DURATION_IN_FLIGHT,
+ };
+
+ // Constructs a Sender that attaches to the given |environment|-provided
+ // resources and |packet_router|. The |config| contains the settings that were
+ // agreed-upon by both sides from the OFFER/ANSWER exchange (i.e., the part of
+ // the overall end-to-end connection process that occurs before Cast Streaming
+ // is started). The |rtp_payload_type| does not affect the behavior of this
+ // Sender. It is simply passed along to a Receiver in the RTP packet stream.
+ Sender(Environment* environment,
+ SenderPacketRouter* packet_router,
+ const SessionConfig& config,
+ RtpPayloadType rtp_payload_type);
+
+ ~Sender() final;
+
+ Ssrc ssrc() const { return rtcp_session_.sender_ssrc(); }
+ int rtp_timebase() const { return rtp_timebase_; }
+
+ // Sets an observer for receiving notifications. Call with nullptr to stop
+ // observing.
+ void SetObserver(Observer* observer);
+
+ // Returns the number of frames currently in-flight. This is only meant to be
+ // informative. Clients should use GetInFlightMediaDuration() to make
+ // throttling decisions.
+ int GetInFlightFrameCount() const;
+
+ // Returns the total media duration of the frames currently in-flight,
+ // assuming the next not-yet-enqueued frame will have the given RTP timestamp.
+ // For a better user experience, the result should be compared to
+ // GetMaxInFlightMediaDuration(), and media encoding should be throttled down
+ // before additional EnqueueFrame() calls would cause this to reach the
+ // current maximum limit.
+ Clock::duration GetInFlightMediaDuration(
+ RtpTimeTicks next_frame_rtp_timestamp) const;
+
+ // Return the maximum acceptable in-flight media duration, given the current
+ // target playout delay setting and end-to-end network/system conditions.
+ Clock::duration GetMaxInFlightMediaDuration() const;
+
+ // Returns true if the Receiver requires a key frame. Note that this will
+ // return true until a key frame is accepted by EnqueueFrame(). Thus, when
+ // encoding is pipelined, care should be taken to instruct the encoder to
+ // produce just ONE forced key frame.
+ bool NeedsKeyFrame() const;
+
+ // Returns the next FrameId, the one after the frame enqueued by the last call
+ // to EnqueueFrame(). Note that the next call to EnqueueFrame() assumes this
+ // frame ID be used.
+ FrameId GetNextFrameId() const;
+
+ // Enqueues the given |frame| for sending as soon as possible. Returns OK if
+ // the frame is accepted, and some time later Observer::OnFrameCanceled() will
+ // be called once it is no longer in-flight.
+ //
+ // All fields of the |frame| must be set to valid values: the |frame_id| must
+ // be the same as GetNextFrameId(); both the |rtp_timestamp| and
+ // |reference_time| fields must be monotonically increasing relative to the
+ // prior frame; and the frame's |data| pointer must be set.
+ [[nodiscard]] EnqueueFrameResult EnqueueFrame(const EncodedFrame& frame);
+
+ private:
+ // Tracking/Storage for frames that are ready-to-send, and until they are
+ // fully received at the other end.
+ struct PendingFrameSlot {
+ // The frame to send, or nullopt if this slot is not in use.
+ absl::optional<EncryptedFrame> frame;
+
+ // Represents which packets need to be sent. Elements are indexed by
+ // FramePacketId. A set bit means a packet needs to be sent (or re-sent).
+ YetAnotherBitVector send_flags;
+
+ // The time when each of the packets was last sent, or
+ // |SenderPacketRouter::kNever| if the packet has not been sent yet.
+ // Elements are indexed by FramePacketId. This is used to avoid
+ // re-transmitting any given packet too frequently.
+ std::vector<Clock::time_point> packet_sent_times;
+
+ PendingFrameSlot();
+ ~PendingFrameSlot();
+
+ bool is_active_for_frame(FrameId frame_id) const {
+ return frame && frame->frame_id == frame_id;
+ }
+ };
+
+ // Return value from the ChooseXYZ() helper methods.
+ struct ChosenPacket {
+ PendingFrameSlot* slot = nullptr;
+ FramePacketId packet_id{};
+
+ explicit operator bool() const { return !!slot; }
+ };
+
+ // An extension of ChosenPacket that also includes the point-in-time when the
+ // packet should be sent.
+ struct ChosenPacketAndWhen : public ChosenPacket {
+ Clock::time_point when = SenderPacketRouter::kNever;
+ };
+
+ // SenderPacketRouter::Sender implementation.
+ void OnReceivedRtcpPacket(Clock::time_point arrival_time,
+ absl::Span<const uint8_t> packet) final;
+ absl::Span<uint8_t> GetRtcpPacketForImmediateSend(
+ Clock::time_point send_time,
+ absl::Span<uint8_t> buffer) final;
+ absl::Span<uint8_t> GetRtpPacketForImmediateSend(
+ Clock::time_point send_time,
+ absl::Span<uint8_t> buffer) final;
+ Clock::time_point GetRtpResumeTime() final;
+
+ // CompoundRtcpParser::Client implementation.
+ void OnReceiverReferenceTimeAdvanced(Clock::time_point reference_time) final;
+ void OnReceiverReport(const RtcpReportBlock& receiver_report) final;
+ void OnReceiverIndicatesPictureLoss() final;
+ void OnReceiverCheckpoint(FrameId frame_id,
+ std::chrono::milliseconds playout_delay) final;
+ void OnReceiverHasFrames(std::vector<FrameId> acks) final;
+ void OnReceiverIsMissingPackets(std::vector<PacketNack> nacks) final;
+
+ // Helper to choose which packet to send, from those that have been flagged as
+ // "need to send." Returns a "false" result if nothing needs to be sent.
+ ChosenPacket ChooseNextRtpPacketNeedingSend();
+
+ // Helper that returns the packet that should be used to kick-start the
+ // Receiver, and the time at which the packet should be sent. Returns a kNever
+ // result if kick-starting is not needed.
+ ChosenPacketAndWhen ChooseKickstartPacket();
+
+ // Cancels the given frame once it is known to have been fully received (i.e.,
+ // based on the ACK feedback from the Receiver in a RTCP packet). This clears
+ // the corresponding entry in |pending_frames_| and notifies the Observer.
+ void CancelPendingFrame(FrameId frame_id);
+
+ // Inline helper to return the slot that would contain the tracking info for
+ // the given |frame_id|.
+ const PendingFrameSlot* get_slot_for(FrameId frame_id) const {
+ return &pending_frames_[(frame_id - FrameId::first()) %
+ pending_frames_.size()];
+ }
+ PendingFrameSlot* get_slot_for(FrameId frame_id) {
+ return &pending_frames_[(frame_id - FrameId::first()) %
+ pending_frames_.size()];
+ }
+
+ SenderPacketRouter* const packet_router_;
+ RtcpSession rtcp_session_;
+ CompoundRtcpParser rtcp_parser_;
+ SenderReportBuilder sender_report_builder_;
+ RtpPacketizer rtp_packetizer_;
+ const int rtp_timebase_;
+ FrameCrypto crypto_;
+
+ // Ring buffer of PendingFrameSlots. The frame having FrameId x will always
+ // be slotted at position x % pending_frames_.size(). Use get_slot_for() to
+ // access the correct slot for a given FrameId.
+ std::array<PendingFrameSlot, kMaxUnackedFrames> pending_frames_{};
+
+ // A count of the number of frames in-flight (i.e., the number of active
+ // entries in |pending_frames_|).
+ int num_frames_in_flight_ = 0;
+
+ // The ID of the last frame enqueued.
+ FrameId last_enqueued_frame_id_ = FrameId::leader();
+
+ // Indicates that all of the packets for all frames up to and including this
+ // FrameId have been successfully received (or otherwise do not need to be
+ // re-transmitted).
+ FrameId checkpoint_frame_id_ = FrameId::leader();
+
+ // The ID of the latest frame the Receiver seems to be aware of.
+ FrameId latest_expected_frame_id_ = FrameId::leader();
+
+ // The target playout delay for the last-enqueued frame. This is auto-updated
+ // when a frame is enqueued that changes the delay.
+ std::chrono::milliseconds target_playout_delay_;
+ FrameId playout_delay_change_at_frame_id_ = FrameId::first();
+
+ // The exact arrival time of the last RTCP packet.
+ Clock::time_point rtcp_packet_arrival_time_ = SenderPacketRouter::kNever;
+
+ // The near-term average round trip time. This is updated with each Sender
+ // Report → Receiver Report round trip. This is initially zero, indicating the
+ // round trip time has not been measured yet.
+ Clock::duration round_trip_time_{0};
+
+ // Maintain current stats in a Sender Report that is ready for sending at any
+ // time. This includes up-to-date lip-sync information, and packet and byte
+ // count stats.
+ RtcpSenderReport pending_sender_report_;
+
+ // These are used to determine whether a key frame needs to be sent to the
+ // Receiver. When the Receiver provides a picture loss notification, the
+ // current checkpoint frame ID is stored in |picture_lost_at_frame_id_|. Then,
+ // while |last_enqueued_key_frame_id_| is less than or equal to
+ // |picture_lost_at_frame_id_|, the Sender knows it still needs to send a key
+ // frame to resolve the picture loss condition. In all other cases, the
+ // Receiver is either in a good state or is in the process of receiving the
+ // key frame that will make that happen.
+ FrameId picture_lost_at_frame_id_ = FrameId::leader();
+ FrameId last_enqueued_key_frame_id_ = FrameId::leader();
+
+ // The current observer (optional).
+ Observer* observer_ = nullptr;
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_STREAMING_SENDER_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.cc b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.cc
new file mode 100644
index 00000000000..870b085a861
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.cc
@@ -0,0 +1,274 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/streaming/sender_packet_router.h"
+
+#include <algorithm>
+#include <utility>
+
+#include "cast/streaming/constants.h"
+#include "cast/streaming/packet_util.h"
+#include "util/logging.h"
+#include "util/saturate_cast.h"
+#include "util/stringprintf.h"
+
+namespace openscreen {
+namespace cast {
+
+using std::chrono::duration_cast;
+using std::chrono::milliseconds;
+using std::chrono::seconds;
+
+SenderPacketRouter::SenderPacketRouter(Environment* environment,
+ int max_burst_bitrate)
+ : SenderPacketRouter(
+ environment,
+ ComputeMaxPacketsPerBurst(max_burst_bitrate,
+ environment->GetMaxPacketSize(),
+ kDefaultBurstInterval),
+ kDefaultBurstInterval) {}
+
+SenderPacketRouter::SenderPacketRouter(Environment* environment,
+ int max_packets_per_burst,
+ milliseconds burst_interval)
+ : BandwidthEstimator(max_packets_per_burst,
+ burst_interval,
+ environment->now()),
+ environment_(environment),
+ packet_buffer_size_(environment->GetMaxPacketSize()),
+ packet_buffer_(new uint8_t[packet_buffer_size_]),
+ max_packets_per_burst_(max_packets_per_burst),
+ burst_interval_(burst_interval),
+ max_burst_bitrate_(ComputeMaxBurstBitrate(packet_buffer_size_,
+ max_packets_per_burst_,
+ burst_interval_)),
+ alarm_(environment_->now_function(), environment_->task_runner()) {
+ OSP_DCHECK(environment_);
+ OSP_DCHECK_GT(packet_buffer_size_, kRequiredNetworkPacketSize);
+}
+
+SenderPacketRouter::~SenderPacketRouter() {
+ OSP_DCHECK(senders_.empty());
+}
+
+void SenderPacketRouter::OnSenderCreated(Ssrc receiver_ssrc, Sender* sender) {
+ OSP_DCHECK(FindEntry(receiver_ssrc) == senders_.end());
+ senders_.push_back(SenderEntry{receiver_ssrc, sender, kNever, kNever});
+
+ if (senders_.size() == 1) {
+ environment_->ConsumeIncomingPackets(this);
+ } else {
+ // Sort the list of Senders so that they are iterated in priority order.
+ std::sort(senders_.begin(), senders_.end());
+ }
+}
+
+void SenderPacketRouter::OnSenderDestroyed(Ssrc receiver_ssrc) {
+ const auto it = FindEntry(receiver_ssrc);
+ OSP_DCHECK(it != senders_.end());
+ senders_.erase(it);
+
+ // If there are no longer any Senders, suspend receiving RTCP packets.
+ if (senders_.empty()) {
+ environment_->DropIncomingPackets();
+ }
+}
+
+void SenderPacketRouter::RequestRtcpSend(Ssrc receiver_ssrc) {
+ const auto it = FindEntry(receiver_ssrc);
+ OSP_DCHECK(it != senders_.end());
+ it->next_rtcp_send_time = Alarm::kImmediately;
+ ScheduleNextBurst();
+}
+
+void SenderPacketRouter::RequestRtpSend(Ssrc receiver_ssrc) {
+ const auto it = FindEntry(receiver_ssrc);
+ OSP_DCHECK(it != senders_.end());
+ it->next_rtp_send_time = Alarm::kImmediately;
+ ScheduleNextBurst();
+}
+
+void SenderPacketRouter::OnReceivedPacket(const IPEndpoint& source,
+ Clock::time_point arrival_time,
+ std::vector<uint8_t> packet) {
+ // If the packet did not come from the expected endpoint, ignore it.
+ OSP_DCHECK_NE(source.port, uint16_t{0});
+ if (source != environment_->remote_endpoint()) {
+ return;
+ }
+
+ // Determine which Sender to dispatch the packet to. Senders may only receive
+ // RTCP packets from Receivers. Log a warning containing a pretty-printed dump
+ // if the packet is not an RTCP packet.
+ const std::pair<ApparentPacketType, Ssrc> seems_like =
+ InspectPacketForRouting(packet);
+ if (seems_like.first != ApparentPacketType::RTCP) {
+ constexpr int kMaxPartiaHexDumpSize = 96;
+ OSP_LOG_WARN << "UNKNOWN packet of " << packet.size()
+ << " bytes. Partial hex dump: "
+ << HexEncode(absl::Span<const uint8_t>(packet).subspan(
+ 0, kMaxPartiaHexDumpSize));
+ return;
+ }
+ const auto it = FindEntry(seems_like.second);
+ if (it != senders_.end()) {
+ it->sender->OnReceivedRtcpPacket(arrival_time, std::move(packet));
+ }
+}
+
+SenderPacketRouter::SenderEntries::iterator SenderPacketRouter::FindEntry(
+ Ssrc receiver_ssrc) {
+ return std::find_if(senders_.begin(), senders_.end(),
+ [receiver_ssrc](const SenderEntry& entry) {
+ return entry.receiver_ssrc == receiver_ssrc;
+ });
+}
+
+void SenderPacketRouter::ScheduleNextBurst() {
+ // Determine the next burst time by scanning for the earliest of the
+ // next-scheduled send times for each Sender.
+ const Clock::time_point earliest_allowed_burst_time =
+ last_burst_time_ + burst_interval_;
+ Clock::time_point next_burst_time = kNever;
+ for (const SenderEntry& entry : senders_) {
+ const auto next_send_time =
+ std::min(entry.next_rtcp_send_time, entry.next_rtp_send_time);
+ if (next_send_time >= next_burst_time) {
+ continue;
+ }
+ if (next_send_time <= earliest_allowed_burst_time) {
+ next_burst_time = earliest_allowed_burst_time;
+ // No need to continue, since |next_burst_time| cannot become any earlier.
+ break;
+ }
+ next_burst_time = next_send_time;
+ }
+
+ // Schedule the alarm for the next burst time unless none of the Senders has
+ // anything to send.
+ if (next_burst_time == kNever) {
+ alarm_.Cancel();
+ } else {
+ alarm_.Schedule([this] { SendBurstOfPackets(); }, next_burst_time);
+ }
+}
+
+void SenderPacketRouter::SendBurstOfPackets() {
+ // Treat RTCP packets as "critical priority," and so there is no upper limit
+ // on the number to send. Practically, this will always be limited by the
+ // number of Senders; so, this won't be a huge number of packets.
+ const Clock::time_point burst_time = environment_->now();
+ const int num_rtcp_packets_sent = SendJustTheRtcpPackets(burst_time);
+ // Now send all the RTP packets, up to the maximum number allowed in a burst.
+ // Higher priority Senders' RTP packets are sent first.
+ const int num_rtp_packets_sent = SendJustTheRtpPackets(
+ burst_time, max_packets_per_burst_ - num_rtcp_packets_sent);
+ last_burst_time_ = burst_time;
+
+ BandwidthEstimator::OnBurstComplete(
+ num_rtcp_packets_sent + num_rtp_packets_sent, burst_time);
+
+ ScheduleNextBurst();
+}
+
+int SenderPacketRouter::SendJustTheRtcpPackets(Clock::time_point send_time) {
+ int num_sent = 0;
+ for (SenderEntry& entry : senders_) {
+ if (entry.next_rtcp_send_time > send_time) {
+ continue;
+ }
+
+ // Note: Only one RTCP packet is sent from the same Sender in the same
+ // burst. This is because RTCP packets are supposed to always contain the
+ // most up-to-date Sender state. Having multiple RTCP packets in the same
+ // burst would mean that all but the last one are old/irrelevant snapshots
+ // of Sender state, and this would just thrash/confuse the Receiver.
+ const absl::Span<uint8_t> packet =
+ entry.sender->GetRtcpPacketForImmediateSend(
+ send_time,
+ absl::Span<uint8_t>(packet_buffer_.get(), packet_buffer_size_));
+ if (!packet.empty()) {
+ environment_->SendPacket(packet);
+ entry.next_rtcp_send_time = send_time + kRtcpReportInterval;
+ ++num_sent;
+ }
+ }
+
+ return num_sent;
+}
+
+int SenderPacketRouter::SendJustTheRtpPackets(Clock::time_point send_time,
+ int num_packets_to_send) {
+ int num_sent = 0;
+ for (SenderEntry& entry : senders_) {
+ if (num_sent >= num_packets_to_send) {
+ break;
+ }
+ if (entry.next_rtp_send_time > send_time) {
+ continue;
+ }
+
+ for (; num_sent < num_packets_to_send; ++num_sent) {
+ const absl::Span<uint8_t> packet =
+ entry.sender->GetRtpPacketForImmediateSend(
+ send_time,
+ absl::Span<uint8_t>(packet_buffer_.get(), packet_buffer_size_));
+ if (packet.empty()) {
+ break;
+ }
+ environment_->SendPacket(packet);
+ }
+ entry.next_rtp_send_time = entry.sender->GetRtpResumeTime();
+ }
+
+ return num_sent;
+}
+
+namespace {
+constexpr int kBitsPerByte = 8;
+constexpr auto kOneSecondInMilliseconds =
+ duration_cast<milliseconds>(seconds(1));
+} // namespace
+
+// static
+int SenderPacketRouter::ComputeMaxPacketsPerBurst(int max_burst_bitrate,
+ int packet_size,
+ milliseconds burst_interval) {
+ OSP_DCHECK_GT(max_burst_bitrate, 0);
+ OSP_DCHECK_GT(packet_size, 0);
+ OSP_DCHECK_GT(burst_interval, milliseconds(0));
+ OSP_DCHECK_LE(burst_interval, kOneSecondInMilliseconds);
+
+ const int max_packets_per_second =
+ max_burst_bitrate / kBitsPerByte / packet_size;
+ const int bursts_per_second = kOneSecondInMilliseconds / burst_interval;
+ return std::max(max_packets_per_second / bursts_per_second, 1);
+}
+
+// static
+int SenderPacketRouter::ComputeMaxBurstBitrate(int packet_size,
+ int max_packets_per_burst,
+ milliseconds burst_interval) {
+ OSP_DCHECK_GT(packet_size, 0);
+ OSP_DCHECK_GT(max_packets_per_burst, 0);
+ OSP_DCHECK_GT(burst_interval, milliseconds(0));
+ OSP_DCHECK_LE(burst_interval, kOneSecondInMilliseconds);
+
+ const int64_t max_bits_per_burst =
+ int64_t{packet_size} * kBitsPerByte * max_packets_per_burst;
+ const int bursts_per_second = kOneSecondInMilliseconds / burst_interval;
+ return saturate_cast<int>(max_bits_per_burst * bursts_per_second);
+}
+
+SenderPacketRouter::Sender::~Sender() = default;
+
+// static
+constexpr int SenderPacketRouter::kDefaultMaxBurstBitrate;
+// static
+constexpr milliseconds SenderPacketRouter::kDefaultBurstInterval;
+// static
+constexpr Clock::time_point SenderPacketRouter::kNever;
+
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.h b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.h
new file mode 100644
index 00000000000..e73e73eca3b
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.h
@@ -0,0 +1,196 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_STREAMING_SENDER_PACKET_ROUTER_H_
+#define CAST_STREAMING_SENDER_PACKET_ROUTER_H_
+
+#include <stdint.h>
+
+#include <chrono> // NOLINT
+#include <memory>
+#include <vector>
+
+#include "absl/types/span.h"
+#include "cast/streaming/bandwidth_estimator.h"
+#include "cast/streaming/environment.h"
+#include "cast/streaming/ssrc.h"
+#include "platform/api/time.h"
+#include "util/alarm.h"
+
+namespace openscreen {
+namespace cast {
+
+// Manages network packet transmission for one or more Senders, directing each
+// inbound packet to a specific Sender instance, pacing the transmission of
+// outbound packets, and employing network bandwidth/availability monitoring and
+// congestion control.
+//
+// Instead of just sending packets whenever they want, Senders must request
+// transmission from the SenderPacketRouter. The router then calls-back to each
+// Sender, in the near future, when it has allocated an available time slice for
+// transmission. The Sender is allowed to decide, at that exact moment, which
+// packet most needs to be sent.
+//
+// Pacing strategy: Packets are sent in bursts. This allows the platform
+// (operating system) to collect many small packets into a short-term buffer,
+// which allows for optimizations at the link layer. For example, multiple
+// packets can be sent together as one larger transmission unit, and this can be
+// critical for good performance over shared-medium networks (such as 802.11
+// WiFi). https://en.wikipedia.org/wiki/Frame-bursting
+class SenderPacketRouter : public BandwidthEstimator,
+ public Environment::PacketConsumer {
+ public:
+ class Sender {
+ public:
+ // Called to provide the Sender with what looks like a RTCP packet meant for
+ // it specifically (among other Senders) to process. |arrival_time|
+ // indicates when the packet arrived (i.e., when it was received from the
+ // platform).
+ virtual void OnReceivedRtcpPacket(Clock::time_point arrival_time,
+ absl::Span<const uint8_t> packet) = 0;
+
+ // Populates the given |buffer| with a RTCP/RTP packet that will be sent
+ // immediately. Returns the portion of |buffer| contaning the packet, or an
+ // empty Span if nothing is ready to send.
+ virtual absl::Span<uint8_t> GetRtcpPacketForImmediateSend(
+ Clock::time_point send_time,
+ absl::Span<uint8_t> buffer) = 0;
+ virtual absl::Span<uint8_t> GetRtpPacketForImmediateSend(
+ Clock::time_point send_time,
+ absl::Span<uint8_t> buffer) = 0;
+
+ // Returns the point-in-time at which RTP sending should resume, or kNever
+ // if it should be suspended until an explicit call to RequestRtpSend(). The
+ // implementation may return a value on or before "now" to indicate an
+ // immediate resume is desired.
+ virtual Clock::time_point GetRtpResumeTime() = 0;
+
+ protected:
+ virtual ~Sender();
+ };
+
+ // Constructs an instance with default burst parameters appropriate for the
+ // given |max_burst_bitrate|.
+ explicit SenderPacketRouter(Environment* environment,
+ int max_burst_bitrate = kDefaultMaxBurstBitrate);
+
+ // Constructs an instance with specific burst parameters. The maximum bitrate
+ // will be computed based on these (and Environment::GetMaxPacketSize()).
+ SenderPacketRouter(Environment* environment,
+ int max_packets_per_burst,
+ std::chrono::milliseconds burst_interval);
+
+ ~SenderPacketRouter();
+
+ int max_packet_size() const { return packet_buffer_size_; }
+ int max_burst_bitrate() const { return max_burst_bitrate_; }
+
+ // Called from a Sender constructor/destructor to register/deregister a Sender
+ // instance that processes RTP/RTCP packets from a Receiver having the given
+ // SSRC.
+ void OnSenderCreated(Ssrc receiver_ssrc, Sender* client);
+ void OnSenderDestroyed(Ssrc receiver_ssrc);
+
+ // Requests an immediate send of a RTCP packet, and then RTCP sending will
+ // repeat at regular intervals (see kRtcpSendInterval) until the Sender is
+ // de-registered.
+ void RequestRtcpSend(Ssrc receiver_ssrc);
+
+ // Requests an immediate send of a RTP packet. RTP sending will continue until
+ // the Sender stops providing packet data.
+ //
+ // See also: Sender::GetRtpResumeTime().
+ void RequestRtpSend(Ssrc receiver_ssrc);
+
+ // A reasonable default maximum bitrate for bursting. Congestion control
+ // should always be employed to limit the Senders' sustained/average outbound
+ // data volume for "fair" use of the network.
+ static constexpr int kDefaultMaxBurstBitrate = 24 << 20; // 24 megabits/sec
+
+ // The minimum amount of time between burst-sends. The methodology by which
+ // this value was determined is lost knowledge, but is likely the result of
+ // experimentation with various network and operating system configurations.
+ // This value came from the original Chrome Cast Streaming implementation.
+ static constexpr std::chrono::milliseconds kDefaultBurstInterval{10};
+
+ // A special time_point value representing "never."
+ static constexpr Clock::time_point kNever = Clock::time_point::max();
+
+ private:
+ struct SenderEntry {
+ Ssrc receiver_ssrc;
+ Sender* sender;
+ Clock::time_point next_rtcp_send_time;
+ Clock::time_point next_rtp_send_time;
+
+ // Entries are ordered by the transmission priority (high→low), as implied
+ // by their SSRC. See ssrc.h for details.
+ bool operator<(const SenderEntry& other) const {
+ return ComparePriority(receiver_ssrc, other.receiver_ssrc) < 0;
+ }
+ };
+
+ using SenderEntries = std::vector<SenderEntry>;
+
+ // Environment::PacketConsumer implementation.
+ void OnReceivedPacket(const IPEndpoint& source,
+ Clock::time_point arrival_time,
+ std::vector<uint8_t> packet) final;
+
+ // Helper to return an iterator pointing to the entry corresponding to the
+ // given |receiver_ssrc|, or "end" if not found.
+ SenderEntries::iterator FindEntry(Ssrc receiver_ssrc);
+
+ // Examine the next send time for all Senders, and decide whether to schedule
+ // a burst-send.
+ void ScheduleNextBurst();
+
+ // Performs a burst-send of packets. This is called whevener the Alarm fires.
+ void SendBurstOfPackets();
+
+ // Send an RTCP packet from each Sender that has one ready, and return the
+ // number of packets sent.
+ int SendJustTheRtcpPackets(Clock::time_point send_time);
+
+ // Send zero or more RTP packets from each Sender, up to a maximum of
+ // |num_packets_to_send|, and return the number of packets sent.
+ int SendJustTheRtpPackets(Clock::time_point send_time,
+ int num_packets_to_send);
+
+ // Returns the maximum number of packets to send in one burst, based on the
+ // given parameters.
+ static int ComputeMaxPacketsPerBurst(
+ int max_burst_bitrate,
+ int packet_size,
+ std::chrono::milliseconds burst_interval);
+
+ // Returns the maximum bitrate inferred by the given parameters.
+ static int ComputeMaxBurstBitrate(int packet_size,
+ int max_packets_per_burst,
+ std::chrono::milliseconds burst_interval);
+
+ Environment* const environment_;
+ const int packet_buffer_size_;
+ const std::unique_ptr<uint8_t[]> packet_buffer_;
+ const int max_packets_per_burst_;
+ const std::chrono::milliseconds burst_interval_;
+ const int max_burst_bitrate_;
+
+ // Schedules the task that calls back into this SenderPacketRouter at a later
+ // time to send the next burst of packets.
+ Alarm alarm_;
+
+ // The current list of Senders and their timing information. This is
+ // maintained in order of the priority implied by the Sender SSRC's.
+ SenderEntries senders_;
+
+ // The last time a burst of packets was sent. This is used to determine the
+ // next burst time.
+ Clock::time_point last_burst_time_ = Clock::time_point::min();
+};
+
+} // namespace cast
+} // namespace openscreen
+
+#endif // CAST_STREAMING_SENDER_PACKET_ROUTER_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router_unittest.cc
new file mode 100644
index 00000000000..75196bdbe68
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router_unittest.cc
@@ -0,0 +1,616 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/streaming/sender_packet_router.h"
+
+#include "cast/streaming/constants.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "platform/base/ip_address.h"
+#include "platform/test/fake_clock.h"
+#include "platform/test/fake_task_runner.h"
+#include "util/big_endian.h"
+#include "util/logging.h"
+
+using std::chrono::milliseconds;
+using std::chrono::seconds;
+
+using testing::_;
+using testing::Invoke;
+using testing::Mock;
+using testing::Return;
+
+namespace openscreen {
+namespace cast {
+namespace {
+
+const IPEndpoint kRemoteEndpoint{
+ // Use a random IPv6 address in the range reserved for "documentation
+ // purposes."
+ IPAddress::Parse("2001:db8:0d93:69c2:fd1a:49a6:a7c0:e8a6").value(), 25476};
+
+const IPEndpoint kUnexpectedEndpoint{
+ IPAddress::Parse("2001:db8:0d93:69c2:fd1a:49a6:a7c0:e8a7").value(), 25476};
+
+// Limited burst parameters to simplify unit testing.
+constexpr int kMaxPacketsPerBurst = 3;
+constexpr auto kBurstInterval = milliseconds(10);
+
+constexpr Ssrc kAudioReceiverSsrc = 2;
+constexpr Ssrc kVideoReceiverSsrc = 32;
+
+const uint8_t kGarbagePacket[] = {
+ 0x42, 0x61, 0x16, 0x17, 0x26, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69,
+ 0x6e, 0x67, 0x2f, 0x63, 0x61, 0x73, 0x74, 0x2f, 0x63, 0x6f, 0x6d, 0x70,
+ 0x6f, 0x75, 0x6e, 0x64, 0x5f, 0x72, 0x74, 0x63, 0x70, 0x5f};
+
+// clang-format off
+const uint8_t kValidAudioRtcpPacket[] = {
+ 0b10000000, // Version=2, Padding=no, ReportCount=0.
+ 201, // RTCP Packet type byte.
+ 0x00, 0x01, // Length of remainder of packet, in 32-bit words.
+ 0x00, 0x00, 0x00, 0x02, // Receiver SSRC.
+};
+
+const uint8_t kValidAudioRtpPacket[] = {
+ 0b10000000, // Version/Padding byte.
+ 96, // Payload type byte.
+ 0xbe, 0xef, // Sequence number.
+ 9, 8, 7, 6, // RTP timestamp.
+ 0, 0, 0, 2, // SSRC.
+ 0b10000000, // Is key frame, no extensions.
+ 5, // Frame ID.
+ 0xa, 0xb, // Packet ID.
+ 0xa, 0xc, // Max packet ID.
+ 0xf, 0xe, 0xd, 0xc, 0xb, 0xa, 0x9, 0x8, // Payload.
+};
+// clang-format on
+
+// Returns a copy of an |original| RTCP packet, but with its send-to SSRC
+// modified to the given |alternate_ssrc|.
+std::vector<uint8_t> MakeRtcpPacketWithAlternateReceiverSsrc(
+ absl::Span<const uint8_t> original,
+ Ssrc alternate_ssrc) {
+ constexpr int kOffsetToSsrcField = 4;
+ std::vector<uint8_t> out(original.begin(), original.end());
+ OSP_CHECK_GE(out.size(), kOffsetToSsrcField + sizeof(uint32_t));
+ WriteBigEndian(uint32_t{alternate_ssrc}, out.data() + kOffsetToSsrcField);
+ return out;
+}
+
+// Serializes the |flag| and |send_time| into the front of |buffer| so the tests
+// can make unique packets and confirm their identities after passing through
+// various components.
+absl::Span<uint8_t> MakeFakePacketWithFlag(char flag,
+ Clock::time_point send_time,
+ absl::Span<uint8_t> buffer) {
+ const Clock::duration::rep ticks = send_time.time_since_epoch().count();
+ const auto packet_size = sizeof(ticks) + sizeof(flag);
+ buffer = buffer.subspan(0, packet_size);
+ OSP_CHECK_EQ(buffer.size(), packet_size);
+ WriteBigEndian(ticks, buffer.data());
+ buffer[sizeof(ticks)] = flag;
+ return buffer;
+}
+
+// Same as MakeFakePacketWithFlag(), but for tests that don't use the flag.
+absl::Span<uint8_t> MakeFakePacket(Clock::time_point send_time,
+ absl::Span<uint8_t> buffer) {
+ return MakeFakePacketWithFlag('?', send_time, buffer);
+}
+
+// Returns the flag that was placed in the given |fake_packet|, or '?' if
+// unknown.
+char ParseFlag(absl::Span<const uint8_t> fake_packet) {
+ constexpr auto kFlagOffset = sizeof(Clock::duration::rep);
+ if (fake_packet.size() == (kFlagOffset + sizeof(char))) {
+ return static_cast<char>(fake_packet[kFlagOffset]);
+ }
+ return '?';
+}
+
+// Deserializes and returns the timestamp that was placed in the given |packet|
+// by MakeFakePacketWithFlag().
+Clock::time_point ParseTimestamp(absl::Span<const uint8_t> fake_packet) {
+ Clock::duration::rep ticks = 0;
+ if (fake_packet.size() >= sizeof(ticks)) {
+ ticks = ReadBigEndian<Clock::duration::rep>(fake_packet.data());
+ }
+ return Clock::time_point() + Clock::duration(ticks);
+}
+
+// Returns an empty version of |buffer|.
+absl::Span<uint8_t> ToEmptyPacketBuffer(Clock::time_point send_time,
+ absl::Span<uint8_t> buffer) {
+ return buffer.subspan(0, 0);
+}
+
+class MockEnvironment : public Environment {
+ public:
+ MockEnvironment(ClockNowFunctionPtr now_function, TaskRunner* task_runner)
+ : Environment(now_function, task_runner) {}
+
+ ~MockEnvironment() override = default;
+
+ MOCK_METHOD1(SendPacket, void(absl::Span<const uint8_t> packet));
+};
+
+class MockSender : public SenderPacketRouter::Sender {
+ public:
+ MockSender() = default;
+ ~MockSender() override = default;
+
+ MOCK_METHOD(void,
+ OnReceivedRtcpPacket,
+ (Clock::time_point arrival_time,
+ absl::Span<const uint8_t> packet),
+ (override));
+ MOCK_METHOD(absl::Span<uint8_t>,
+ GetRtcpPacketForImmediateSend,
+ (Clock::time_point send_time, absl::Span<uint8_t> buffer),
+ (override));
+ MOCK_METHOD(absl::Span<uint8_t>,
+ GetRtpPacketForImmediateSend,
+ (Clock::time_point send_time, absl::Span<uint8_t> buffer),
+ (override));
+ MOCK_METHOD(Clock::time_point, GetRtpResumeTime, (), (override));
+};
+
+class SenderPacketRouterTest : public testing::Test {
+ public:
+ SenderPacketRouterTest()
+ : clock_(Clock::now()),
+ task_runner_(&clock_),
+ env_(&FakeClock::now, &task_runner_),
+ router_(&env_, kMaxPacketsPerBurst, kBurstInterval) {
+ env_.set_socket_error_handler(
+ [](Error error) { ASSERT_TRUE(error.ok()) << error; });
+ }
+
+ ~SenderPacketRouterTest() override = default;
+
+ MockEnvironment* env() { return &env_; }
+ SenderPacketRouter* router() { return &router_; }
+ MockSender* audio_sender() { return &audio_sender_; }
+ MockSender* video_sender() { return &video_sender_; }
+
+ void SimulatePacketArrivedNow(const IPEndpoint& source,
+ absl::Span<const uint8_t> packet) {
+ static_cast<Environment::PacketConsumer*>(&router_)->OnReceivedPacket(
+ source, env_.now(), std::vector<uint8_t>(packet.begin(), packet.end()));
+ }
+
+ void AdvanceClockAndRunTasks(Clock::duration delta) { clock_.Advance(delta); }
+ void RunTasksUntilIdle() { task_runner_.RunTasksUntilIdle(); }
+
+ private:
+ FakeClock clock_;
+ FakeTaskRunner task_runner_;
+ testing::NiceMock<MockEnvironment> env_;
+ SenderPacketRouter router_;
+ testing::NiceMock<MockSender> audio_sender_;
+ testing::NiceMock<MockSender> video_sender_;
+};
+
+// Tests that the SenderPacketRouter is correctly configured from the specific
+// burst parameters that were passed to its constructor. This confirms internal
+// calculations based on these parameters.
+TEST_F(SenderPacketRouterTest, IsConfiguredFromBurstParameters) {
+ EXPECT_EQ(env()->GetMaxPacketSize(), router()->max_packet_size());
+
+ // The following lower-bound/upper-bound values were hand-calculated based on
+ // the arguments that were passed to the SenderPacketRouter constructor, and
+ // assuming a packet size anywhere from 256 bytes to one megabyte.
+ //
+ // The exact value for max_burst_bitrate() is not known here because
+ // Environment::GetMaxPacketSize() depends on the platform and network medium.
+ // To test for an exact value would require duplicating the math in
+ // SenderPacketRouter::ComputeMaxBurstBitrate() here (and then *what* would we
+ // be testing?).
+ EXPECT_LE(614400, router()->max_burst_bitrate());
+ EXPECT_GE(2147483647, router()->max_burst_bitrate());
+}
+
+TEST_F(SenderPacketRouterTest, IgnoresPacketsFromUnexpectedSources) {
+ env()->set_remote_endpoint(kRemoteEndpoint);
+ router()->OnSenderCreated(kAudioReceiverSsrc, audio_sender());
+ EXPECT_CALL(*audio_sender(), OnReceivedRtcpPacket(_, _)).Times(0);
+ SimulatePacketArrivedNow(kUnexpectedEndpoint,
+ absl::Span<const uint8_t>(kValidAudioRtcpPacket));
+ router()->OnSenderDestroyed(kAudioReceiverSsrc);
+}
+
+TEST_F(SenderPacketRouterTest, IgnoresInboundPacketsContainingGarbage) {
+ env()->set_remote_endpoint(kRemoteEndpoint);
+ router()->OnSenderCreated(kAudioReceiverSsrc, audio_sender());
+ EXPECT_CALL(*audio_sender(), OnReceivedRtcpPacket(_, _)).Times(0);
+ SimulatePacketArrivedNow(kUnexpectedEndpoint,
+ absl::Span<const uint8_t>(kGarbagePacket));
+ SimulatePacketArrivedNow(kRemoteEndpoint,
+ absl::Span<const uint8_t>(kGarbagePacket));
+ router()->OnSenderDestroyed(kAudioReceiverSsrc);
+}
+
+// Note: RTP packets should be ignored since it wouldn't make sense for a
+// Receiver to stream media to a Sender.
+TEST_F(SenderPacketRouterTest, IgnoresInboundRtpPackets) {
+ env()->set_remote_endpoint(kRemoteEndpoint);
+ router()->OnSenderCreated(kAudioReceiverSsrc, audio_sender());
+ EXPECT_CALL(*audio_sender(), OnReceivedRtcpPacket(_, _)).Times(0);
+ SimulatePacketArrivedNow(kUnexpectedEndpoint,
+ absl::Span<const uint8_t>(kValidAudioRtpPacket));
+ SimulatePacketArrivedNow(kRemoteEndpoint,
+ absl::Span<const uint8_t>(kValidAudioRtpPacket));
+ router()->OnSenderDestroyed(kAudioReceiverSsrc);
+}
+
+TEST_F(SenderPacketRouterTest, IgnoresInboundRtcpPacketsFromUnknownReceivers) {
+ env()->set_remote_endpoint(kRemoteEndpoint);
+ router()->OnSenderCreated(kAudioReceiverSsrc, audio_sender());
+ const std::vector<uint8_t> rtcp_packet_not_for_me =
+ MakeRtcpPacketWithAlternateReceiverSsrc(kValidAudioRtcpPacket,
+ kAudioReceiverSsrc + 1);
+ EXPECT_CALL(*audio_sender(), OnReceivedRtcpPacket(_, _)).Times(0);
+ SimulatePacketArrivedNow(kUnexpectedEndpoint,
+ absl::Span<const uint8_t>(rtcp_packet_not_for_me));
+ SimulatePacketArrivedNow(kRemoteEndpoint,
+ absl::Span<const uint8_t>(rtcp_packet_not_for_me));
+ router()->OnSenderDestroyed(kAudioReceiverSsrc);
+}
+
+// Tests that the SenderPacketRouter forwards packets from Receivers to the
+// appropriate Sender.
+TEST_F(SenderPacketRouterTest, RoutesRTCPPacketsFromReceivers) {
+ EXPECT_CALL(*env(), SendPacket(_)).Times(0);
+
+ const absl::Span<const uint8_t> audio_rtcp_packet(kValidAudioRtcpPacket);
+ std::vector<uint8_t> video_rtcp_packet =
+ MakeRtcpPacketWithAlternateReceiverSsrc(audio_rtcp_packet,
+ kVideoReceiverSsrc);
+
+ env()->set_remote_endpoint(kRemoteEndpoint);
+ router()->OnSenderCreated(kAudioReceiverSsrc, audio_sender());
+
+ // It should route a valid audio RTCP packet to the audio Sender, and ignore a
+ // valid video RTCP packet (since the video Sender is not yet known to the
+ // SenderPacketRouter).
+ {
+ Clock::time_point arrival_time{};
+ std::vector<uint8_t> received_packet;
+ EXPECT_CALL(*audio_sender(), OnReceivedRtcpPacket(_, _))
+ .WillOnce(Invoke(
+ [&](Clock::time_point when, absl::Span<const uint8_t> packet) {
+ arrival_time = when;
+ received_packet.assign(packet.begin(), packet.end());
+ }));
+ EXPECT_CALL(*video_sender(), OnReceivedRtcpPacket(_, _)).Times(0);
+
+ const Clock::time_point expected_arrival_time = env()->now();
+ SimulatePacketArrivedNow(kRemoteEndpoint, audio_rtcp_packet);
+ SimulatePacketArrivedNow(kRemoteEndpoint, video_rtcp_packet);
+
+ Mock::VerifyAndClear(audio_sender());
+ EXPECT_EQ(expected_arrival_time, arrival_time);
+ EXPECT_EQ(audio_rtcp_packet, received_packet);
+
+ Mock::VerifyAndClear(video_sender());
+ }
+
+ AdvanceClockAndRunTasks(seconds(1));
+
+ // Register the video Sender with the router. Now, confirm audio RTCP packets
+ // still go to the audio Sender and video RTCP packets go to the video Sender.
+ router()->OnSenderCreated(kVideoReceiverSsrc, video_sender());
+ {
+ Clock::time_point audio_arrival_time{}, video_arrival_time{};
+ std::vector<uint8_t> received_audio_packet, received_video_packet;
+ EXPECT_CALL(*audio_sender(), OnReceivedRtcpPacket(_, _))
+ .WillOnce(Invoke(
+ [&](Clock::time_point when, absl::Span<const uint8_t> packet) {
+ audio_arrival_time = when;
+ received_audio_packet.assign(packet.begin(), packet.end());
+ }));
+ EXPECT_CALL(*video_sender(), OnReceivedRtcpPacket(_, _))
+ .WillOnce(Invoke(
+ [&](Clock::time_point when, absl::Span<const uint8_t> packet) {
+ video_arrival_time = when;
+ received_video_packet.assign(packet.begin(), packet.end());
+ }));
+
+ const Clock::time_point expected_audio_arrival_time = env()->now();
+ SimulatePacketArrivedNow(kRemoteEndpoint, audio_rtcp_packet);
+
+ AdvanceClockAndRunTasks(milliseconds(11));
+
+ const Clock::time_point expected_video_arrival_time = env()->now();
+ SimulatePacketArrivedNow(kRemoteEndpoint, video_rtcp_packet);
+
+ Mock::VerifyAndClear(audio_sender());
+ EXPECT_EQ(expected_audio_arrival_time, audio_arrival_time);
+ EXPECT_EQ(audio_rtcp_packet, received_audio_packet);
+
+ Mock::VerifyAndClear(video_sender());
+ EXPECT_EQ(expected_video_arrival_time, video_arrival_time);
+ EXPECT_EQ(video_rtcp_packet, received_video_packet);
+ }
+
+ router()->OnSenderDestroyed(kAudioReceiverSsrc);
+ router()->OnSenderDestroyed(kVideoReceiverSsrc);
+}
+
+// Tests that the SenderPacketRouter schedules periodic RTCP packet sends,
+// starting once the Sender requests the first RTCP send.
+TEST_F(SenderPacketRouterTest, SchedulesPeriodicTransmissionOfRTCPPackets) {
+ env()->set_remote_endpoint(kRemoteEndpoint);
+ router()->OnSenderCreated(kAudioReceiverSsrc, audio_sender());
+
+ constexpr int kNumIterations = 5;
+
+ EXPECT_CALL(*audio_sender(), OnReceivedRtcpPacket(_, _)).Times(0);
+ EXPECT_CALL(*audio_sender(), GetRtcpPacketForImmediateSend(_, _))
+ .Times(kNumIterations)
+ .WillRepeatedly(Invoke(&MakeFakePacket));
+ EXPECT_CALL(*audio_sender(), GetRtpPacketForImmediateSend(_, _)).Times(0);
+ ON_CALL(*audio_sender(), GetRtpResumeTime())
+ .WillByDefault(Return(SenderPacketRouter::kNever));
+
+ // Capture every packet sent for analysis at the end of this test.
+ std::vector<std::vector<uint8_t>> packets_sent;
+ EXPECT_CALL(*env(), SendPacket(_))
+ .WillRepeatedly(Invoke([&](absl::Span<const uint8_t> packet) {
+ packets_sent.emplace_back(packet.begin(), packet.end());
+ }));
+
+ const Clock::time_point first_send_time = env()->now();
+ router()->RequestRtcpSend(kAudioReceiverSsrc);
+ RunTasksUntilIdle(); // The first RTCP packet should be sent immediately.
+ for (int i = 1; i < kNumIterations; ++i) {
+ AdvanceClockAndRunTasks(kRtcpReportInterval);
+ }
+
+ // Ensure each RTCP packet was sent and in-sequence.
+ Mock::VerifyAndClear(env());
+ ASSERT_EQ(kNumIterations, static_cast<int>(packets_sent.size()));
+ for (int i = 0; i < kNumIterations; ++i) {
+ const Clock::time_point expected_send_time =
+ first_send_time + i * kRtcpReportInterval;
+ EXPECT_EQ(expected_send_time, ParseTimestamp(packets_sent[i]));
+ }
+
+ router()->OnSenderDestroyed(kAudioReceiverSsrc);
+}
+
+// Tests that the SenderPacketRouter schedules RTP packet bursts from a single
+// Sender.
+TEST_F(SenderPacketRouterTest, SchedulesAndTransmitsRTPBursts) {
+ env()->set_remote_endpoint(kRemoteEndpoint);
+ router()->OnSenderCreated(kVideoReceiverSsrc, video_sender());
+
+ // Capture every packet sent for analysis at the end of this test.
+ std::vector<std::vector<uint8_t>> packets_sent;
+ EXPECT_CALL(*env(), SendPacket(_))
+ .WillRepeatedly(Invoke([&](absl::Span<const uint8_t> packet) {
+ packets_sent.emplace_back(packet.begin(), packet.end());
+ }));
+
+ // Simulate a typical video Sender RTP at-startup sending sequence: First, at
+ // t=0ms, the Sender wants to send its large 10-packet key frame. This will
+ // require four bursts, since only 3 packets can be sent per burst.
+ //
+ // While the first frame is being sent, a smaller 4-packet frame is enqueued,
+ // and the Sender will want to start sending this immediately after the first
+ // frame. Part of this second frame will be sent in the fourth burst, and the
+ // rest in the fifth burst.
+ //
+ // After the fifth burst, the Sender will schedule a "kickstart packet" for
+ // 25ms later. However, when the SenderPacketRouter later asks the Sender for
+ // that packet, the Sender will change its mind and decide not to send
+ // anything.
+ //
+ // At t=100ms, the next frame of video is enqueued in the Sender and it
+ // requests that RTP sending resume for that. This is a small 1-packet frame.
+ const Clock::time_point start_time = env()->now();
+ int num_get_rtp_calls = 0;
+ EXPECT_CALL(*video_sender(), GetRtpPacketForImmediateSend(_, _))
+ .Times(14 + 2)
+ .WillRepeatedly(
+ Invoke([&](Clock::time_point send_time, absl::Span<uint8_t> buffer) {
+ ++num_get_rtp_calls;
+
+ // 14 packets are sent: The first through fourth bursts send three
+ // packets each, and the fifth burst sends two.
+ if (num_get_rtp_calls <= 14) {
+ return MakeFakePacket(send_time, buffer);
+ }
+
+ // 2 "done signals" are then sent: One is at the end of the fifth
+ // burst, one is for a "nothing to send" sixth burst.
+ return ToEmptyPacketBuffer(send_time, buffer);
+ }));
+ const Clock::time_point kickstart_time =
+ start_time + 4 * kBurstInterval + milliseconds(25);
+ int num_get_resume_calls = 0;
+ EXPECT_CALL(*video_sender(), GetRtpResumeTime())
+ .Times(4 + 1 + 1)
+ .WillRepeatedly(Invoke([&] {
+ ++num_get_resume_calls;
+
+ // After each of the first through fourth bursts, the Sender wants to
+ // transmit more right away.
+ if (num_get_resume_calls <= 4) {
+ return env()->now();
+ }
+
+ // After the fifth burst, the Sender requests resuming for kickstart
+ // later.
+ if (num_get_resume_calls == 5) {
+ return kickstart_time;
+ }
+
+ // After the sixth burst, the Sender pauses RTP sending indefinitely.
+ return SenderPacketRouter::kNever;
+ }));
+ router()->RequestRtpSend(kVideoReceiverSsrc);
+ // Execute first burst.
+ RunTasksUntilIdle();
+ // Execute second through fifth bursts.
+ for (int i = 1; i <= 4; ++i) {
+ AdvanceClockAndRunTasks(kBurstInterval);
+ }
+ // Execute the sixth burst at the kickstart time.
+ AdvanceClockAndRunTasks(kickstart_time - env()->now());
+ Mock::VerifyAndClear(video_sender());
+
+ // Now, resume RTP sending for one more 1-packet frame, and then pause RTP
+ // sending again.
+ EXPECT_CALL(*video_sender(), GetRtpPacketForImmediateSend(_, _))
+ .WillOnce(Invoke(&MakeFakePacket)) // Frame 2, only packet.
+ .WillOnce(Invoke(&ToEmptyPacketBuffer)); // Done for now.
+ // After the seventh burst, the Sender pauses RTP sending again.
+ EXPECT_CALL(*video_sender(), GetRtpResumeTime())
+ .WillOnce(Return(SenderPacketRouter::kNever));
+ // Advance to the resume time. Nothing should happen until RequestRtpSend() is
+ // called.
+ const Clock::time_point resume_time = start_time + milliseconds(100);
+ AdvanceClockAndRunTasks(resume_time - env()->now());
+ router()->RequestRtpSend(kVideoReceiverSsrc);
+ // Execute seventh burst.
+ RunTasksUntilIdle();
+ // Run for one more second, but nothing should be happening since sending is
+ // paused.
+ AdvanceClockAndRunTasks(seconds(1));
+ Mock::VerifyAndClear(video_sender());
+
+ // Confirm 15 packets got sent and contain the expected data (which tracks
+ // when they were sent).
+ ASSERT_EQ(15, static_cast<int>(packets_sent.size()));
+ Clock::time_point expected_time;
+ int packet_idx = 0;
+ // First burst through fourth burst.
+ for (int burst_number = 0; burst_number < 4; ++burst_number) {
+ expected_time = start_time + burst_number * kBurstInterval;
+ EXPECT_EQ(expected_time, ParseTimestamp(packets_sent[packet_idx++]));
+ EXPECT_EQ(expected_time, ParseTimestamp(packets_sent[packet_idx++]));
+ EXPECT_EQ(expected_time, ParseTimestamp(packets_sent[packet_idx++]));
+ }
+ // Fifth burst.
+ expected_time += kBurstInterval;
+ EXPECT_EQ(expected_time, ParseTimestamp(packets_sent[packet_idx++]));
+ EXPECT_EQ(expected_time, ParseTimestamp(packets_sent[packet_idx++]));
+ // Seventh burst (sixth burst sent nothing).
+ EXPECT_EQ(resume_time, ParseTimestamp(packets_sent[packet_idx++]));
+
+ router()->OnSenderDestroyed(kVideoReceiverSsrc);
+}
+
+// Tests that the SenderPacketRouter schedules packet sends based on transmit
+// prority: RTCP before RTP, and the audio Sender's packets before the video
+// Sender's.
+TEST_F(SenderPacketRouterTest, SchedulesAndTransmitsAccountingForPriority) {
+ env()->set_remote_endpoint(kRemoteEndpoint);
+ ASSERT_LT(ComparePriority(kAudioReceiverSsrc, kVideoReceiverSsrc), 0);
+ router()->OnSenderCreated(kVideoReceiverSsrc, video_sender());
+ router()->OnSenderCreated(kAudioReceiverSsrc, audio_sender());
+
+ // Capture every packet sent for analysis at the end of this test.
+ std::vector<std::vector<uint8_t>> packets_sent;
+ EXPECT_CALL(*env(), SendPacket(_))
+ .WillRepeatedly(Invoke([&](absl::Span<const uint8_t> packet) {
+ packets_sent.emplace_back(packet.begin(), packet.end());
+ }));
+
+ // These indicate how often one packet will be sent from each Sender.
+ constexpr Clock::duration kAudioRtpInterval = milliseconds(10);
+ constexpr Clock::duration kVideoRtpInterval = milliseconds(33);
+
+ // Note: The priority flags used in this test ('0'..'3') indicate
+ // lowest-to-highest priority.
+ EXPECT_CALL(*audio_sender(), GetRtcpPacketForImmediateSend(_, _))
+ .WillRepeatedly(
+ Invoke([](Clock::time_point send_time, absl::Span<uint8_t> buffer) {
+ return MakeFakePacketWithFlag('3', send_time, buffer);
+ }));
+ int num_audio_get_rtp_calls = 0;
+ EXPECT_CALL(*audio_sender(), GetRtpPacketForImmediateSend(_, _))
+ .WillRepeatedly(
+ Invoke([&](Clock::time_point send_time, absl::Span<uint8_t> buffer) {
+ // Alternate between returning a single packet and a "done for now"
+ // signal.
+ ++num_audio_get_rtp_calls;
+ if (num_audio_get_rtp_calls % 2) {
+ return MakeFakePacketWithFlag('1', send_time, buffer);
+ }
+ return buffer.subspan(0, 0);
+ }));
+ EXPECT_CALL(*video_sender(), GetRtcpPacketForImmediateSend(_, _))
+ .WillRepeatedly(
+ Invoke([](Clock::time_point send_time, absl::Span<uint8_t> buffer) {
+ return MakeFakePacketWithFlag('2', send_time, buffer);
+ }));
+ int num_video_get_rtp_calls = 0;
+ EXPECT_CALL(*video_sender(), GetRtpPacketForImmediateSend(_, _))
+ .WillRepeatedly(
+ Invoke([&](Clock::time_point send_time, absl::Span<uint8_t> buffer) {
+ // Alternate between returning a single packet and a "done for now"
+ // signal.
+ ++num_video_get_rtp_calls;
+ if (num_video_get_rtp_calls % 2) {
+ return MakeFakePacketWithFlag('0', send_time, buffer);
+ }
+ return buffer.subspan(0, 0);
+ }));
+ EXPECT_CALL(*audio_sender(), GetRtpResumeTime()).WillRepeatedly(Invoke([&] {
+ return env()->now() + kAudioRtpInterval;
+ }));
+ EXPECT_CALL(*video_sender(), GetRtpResumeTime()).WillRepeatedly(Invoke([&] {
+ return env()->now() + kVideoRtpInterval;
+ }));
+
+ // Request starting both RTCP and RTP sends for both Senders, in a random
+ // order.
+ router()->RequestRtcpSend(kVideoReceiverSsrc);
+ router()->RequestRtpSend(kAudioReceiverSsrc);
+ router()->RequestRtcpSend(kAudioReceiverSsrc);
+ router()->RequestRtpSend(kVideoReceiverSsrc);
+
+ // Run the SenderPacketRouter for 3 seconds.
+ constexpr Clock::duration kSimulationDuration = seconds(3);
+ constexpr Clock::duration kSimulationStepPeriod = milliseconds(1);
+ const Clock::time_point start_time = env()->now();
+ RunTasksUntilIdle();
+ const Clock::time_point end_time = start_time + kSimulationDuration;
+ while (env()->now() <= end_time) {
+ AdvanceClockAndRunTasks(kSimulationStepPeriod);
+ }
+
+ // Examine the packets that were actually sent, and confirm that the priority
+ // ordering was maintained.
+ ASSERT_EQ(384, static_cast<int>(packets_sent.size()));
+ // The very first packet sent should be an audio RTCP packet.
+ EXPECT_EQ('3', ParseFlag(packets_sent[0]));
+ EXPECT_EQ(start_time, ParseTimestamp(packets_sent[0]));
+ // Scan the rest, checking that packets sent in the same burst (i.e., having
+ // the same send timestamp) were sent in priority order.
+ char last_priority_flag = '3';
+ Clock::time_point last_timestamp = start_time;
+ for (int i = 1; i < static_cast<int>(packets_sent.size()) &&
+ !testing::Test::HasFailure();
+ ++i) {
+ const char priority_flag = ParseFlag(packets_sent[i]);
+ const Clock::time_point timestamp = ParseTimestamp(packets_sent[i]);
+ EXPECT_LE(last_timestamp, timestamp) << "packet[" << i << ']';
+ if (timestamp == last_timestamp) {
+ EXPECT_GT(last_priority_flag, priority_flag) << "packet[" << i << ']';
+ }
+ last_priority_flag = priority_flag;
+ last_timestamp = timestamp;
+ }
+
+ router()->OnSenderDestroyed(kVideoReceiverSsrc);
+ router()->OnSenderDestroyed(kAudioReceiverSsrc);
+}
+
+} // namespace
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.cc b/chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.cc
index 56d100d8b17..5e230da45a4 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.cc
@@ -7,8 +7,8 @@
#include "cast/streaming/packet_util.h"
#include "util/logging.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
SenderReportBuilder::SenderReportBuilder(RtcpSession* session)
: session_(session) {
@@ -52,5 +52,33 @@ std::pair<absl::Span<uint8_t>, StatusReportId> SenderReportBuilder::BuildPacket(
ToStatusReportId(ntp_timestamp));
}
-} // namespace streaming
+Clock::time_point SenderReportBuilder::GetRecentReportTime(
+ StatusReportId report_id,
+ Clock::time_point on_or_before) const {
+ // Assumption: The |report_id| is the middle 32 bits of a 64-bit NtpTimestamp.
+ static_assert(ToStatusReportId(NtpTimestamp{0x0192a3b4c5d6e7f8}) ==
+ StatusReportId{0xa3b4c5d6},
+ "FIXME: ToStatusReportId() implementation changed.");
+
+ // Compute the maximum possible NtpTimestamp. Then, use its uppermost 16 bits
+ // and the 32 bits from the report_id to produce a reconstructed NtpTimestamp.
+ const NtpTimestamp max_timestamp =
+ session_->ntp_converter().ToNtpTimestamp(on_or_before);
+ // max_timestamp: HH......
+ // report_id: LLLL
+ // ↓↓ ↙↙↙↙
+ // reconstructed: HHLLLL00
+ NtpTimestamp reconstructed = (max_timestamp & (uint64_t{0xffff} << 48)) |
+ (static_cast<uint64_t>(report_id) << 16);
+ // If the reconstructed timestamp is greater than the maximum one, rollover
+ // of the lower 48 bits occurred. Subtract one from the upper 16 bits to
+ // rectify that.
+ if (reconstructed > max_timestamp) {
+ reconstructed -= uint64_t{1} << 48;
+ }
+
+ return session_->ntp_converter().ToLocalTime(reconstructed);
+}
+
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.h b/chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.h
index 724e7342628..c5367783a95 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/sender_report_builder.h
@@ -13,9 +13,10 @@
#include "cast/streaming/rtcp_common.h"
#include "cast/streaming/rtcp_session.h"
#include "cast/streaming/rtp_defines.h"
+#include "platform/api/time.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Builds RTCP packets containing one Sender Report.
class SenderReportBuilder {
@@ -31,6 +32,11 @@ class SenderReportBuilder {
const RtcpSenderReport& sender_report,
absl::Span<uint8_t> buffer) const;
+ // Returns the approximate reference time from a recently-built Sender Report,
+ // based on the given |report_id| and maximum possible reference time.
+ Clock::time_point GetRecentReportTime(StatusReportId report_id,
+ Clock::time_point on_or_before) const;
+
// The required size (in bytes) of the buffer passed to BuildPacket().
static constexpr int kRequiredBufferSize =
kRtcpCommonHeaderSize + kRtcpSenderReportSize + kRtcpReportBlockSize;
@@ -39,7 +45,7 @@ class SenderReportBuilder {
RtcpSession* const session_;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_SENDER_REPORT_BUILDER_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.cc b/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.cc
index 4770dd6bacd..d917486371e 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.cc
@@ -7,8 +7,8 @@
#include "cast/streaming/packet_util.h"
#include "util/logging.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
SenderReportParser::SenderReportWithId::SenderReportWithId() = default;
SenderReportParser::SenderReportWithId::~SenderReportWithId() = default;
@@ -71,5 +71,5 @@ SenderReportParser::Parse(absl::Span<const uint8_t> buffer) {
return sender_report;
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.h b/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.h
index df7f89a0e25..3c66829c6b4 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser.h
@@ -12,8 +12,8 @@
#include "cast/streaming/rtp_defines.h"
#include "cast/streaming/rtp_time.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Parses RTCP packets from a Sender to extract Sender Reports. Ignores anything
// else, since that is all a Receiver would be interested in.
@@ -45,7 +45,7 @@ class SenderReportParser {
RtpTimeTicks last_parsed_rtp_timestamp_;
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_SENDER_REPORT_PARSER_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser_fuzzer.cc b/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser_fuzzer.cc
index 8b9d8e74057..c79eaa6aa72 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser_fuzzer.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/sender_report_parser_fuzzer.cc
@@ -8,10 +8,10 @@
#include "platform/api/time.h"
extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
- using cast::streaming::RtcpSenderReport;
- using cast::streaming::RtcpSession;
- using cast::streaming::SenderReportParser;
- using cast::streaming::Ssrc;
+ using openscreen::cast::RtcpSenderReport;
+ using openscreen::cast::RtcpSession;
+ using openscreen::cast::SenderReportParser;
+ using openscreen::cast::Ssrc;
constexpr Ssrc kSenderSsrcInSeedCorpus = 1;
constexpr Ssrc kReceiverSsrcInSeedCorpus = 2;
@@ -24,7 +24,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wexit-time-destructors"
static RtcpSession session(kSenderSsrcInSeedCorpus, kReceiverSsrcInSeedCorpus,
- openscreen::platform::Clock::now());
+ openscreen::Clock::now());
static SenderReportParser parser(&session);
#pragma clang diagnostic pop
diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_report_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/sender_report_unittest.cc
index bad2cf0afa4..f9ca8298d0c 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/sender_report_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/sender_report_unittest.cc
@@ -9,10 +9,12 @@
#include "cast/streaming/sender_report_parser.h"
#include "gtest/gtest.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
+using openscreen::operator<<;
+
constexpr Ssrc kSenderSsrc{1};
constexpr Ssrc kReceiverSsrc{2};
@@ -25,8 +27,7 @@ class SenderReportTest : public testing::Test {
}
private:
- RtcpSession session_{kSenderSsrc, kReceiverSsrc,
- openscreen::platform::Clock::now()};
+ RtcpSession session_{kSenderSsrc, kReceiverSsrc, Clock::now()};
SenderReportBuilder builder_{&session_};
SenderReportParser parser_{&session_};
};
@@ -120,7 +121,7 @@ TEST_F(SenderReportTest, BuildPackets) {
const bool with_report_block = (i == 1);
RtcpSenderReport original;
- original.reference_time = openscreen::platform::Clock::now();
+ original.reference_time = Clock::now();
original.rtp_timestamp = RtpTimeTicks() + RtpTimeDelta::FromTicks(5);
original.send_packet_count = 55;
original.send_octet_count = 20044;
@@ -161,6 +162,32 @@ TEST_F(SenderReportTest, BuildPackets) {
}
}
+TEST_F(SenderReportTest, ComputesTimePointsFromReportIds) {
+ // Note: The time_points can be off by up to 16 µs because of the loss of
+ // precision caused by truncating the NtpTimestamps into StatusReportIds.
+ constexpr std::chrono::microseconds kEpsilon{16};
+
+ // Test a sampling of time points over the last 65536 seconds to confirm the
+ // rollover correction logic is working.
+ Clock::time_point on_or_before = Clock::now() + std::chrono::seconds(65536);
+ constexpr int kNumIterations = 16;
+ constexpr int kSecondsPerStep = 4096;
+ for (int i = 0; i < kNumIterations; ++i) {
+ const Clock::time_point expected_time =
+ on_or_before - std::chrono::seconds(i * kSecondsPerStep);
+ const auto report_id =
+ ToStatusReportId(ntp_converter().ToNtpTimestamp(expected_time));
+ const Clock::time_point report_time =
+ builder()->GetRecentReportTime(report_id, on_or_before);
+ EXPECT_GE(on_or_before, report_time);
+ const auto absolute_difference = (expected_time < report_time)
+ ? (report_time - expected_time)
+ : (expected_time - report_time);
+ EXPECT_LE(absolute_difference, kEpsilon)
+ << expected_time << " vs " << report_time;
+ }
+}
+
} // namespace
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/sender_unittest.cc
new file mode 100644
index 00000000000..1c93fceed8c
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/streaming/sender_unittest.cc
@@ -0,0 +1,1138 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/streaming/sender.h"
+
+#include <stdint.h>
+
+#include <algorithm>
+#include <array>
+#include <chrono> // NOLINT
+#include <limits>
+#include <map>
+#include <set>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "cast/streaming/compound_rtcp_builder.h"
+#include "cast/streaming/constants.h"
+#include "cast/streaming/encoded_frame.h"
+#include "cast/streaming/frame_collector.h"
+#include "cast/streaming/frame_crypto.h"
+#include "cast/streaming/frame_id.h"
+#include "cast/streaming/mock_environment.h"
+#include "cast/streaming/packet_util.h"
+#include "cast/streaming/rtcp_session.h"
+#include "cast/streaming/rtp_defines.h"
+#include "cast/streaming/rtp_packet_parser.h"
+#include "cast/streaming/sender_packet_router.h"
+#include "cast/streaming/sender_report_parser.h"
+#include "cast/streaming/session_config.h"
+#include "cast/streaming/ssrc.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "platform/test/fake_clock.h"
+#include "platform/test/fake_task_runner.h"
+#include "util/alarm.h"
+#include "util/yet_another_bit_vector.h"
+
+using std::chrono::duration_cast;
+using std::chrono::microseconds;
+using std::chrono::milliseconds;
+using std::chrono::nanoseconds;
+using std::chrono::seconds;
+
+using testing::_;
+using testing::AtLeast;
+using testing::Invoke;
+using testing::InvokeWithoutArgs;
+using testing::Mock;
+using testing::NiceMock;
+using testing::Return;
+using testing::Sequence;
+
+namespace openscreen {
+namespace cast {
+namespace {
+
+// Sender configuration.
+constexpr Ssrc kSenderSsrc = 1;
+constexpr Ssrc kReceiverSsrc = 2;
+constexpr int kRtpTimebase = 48000;
+constexpr milliseconds kTargetPlayoutDelay{400};
+constexpr auto kAesKey =
+ std::array<uint8_t, 16>{{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
+ 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f}};
+constexpr auto kCastIvMask =
+ std::array<uint8_t, 16>{{0xf0, 0xe0, 0xd0, 0xc0, 0xb0, 0xa0, 0x90, 0x80,
+ 0x70, 0x60, 0x50, 0x40, 0x30, 0x20, 0x10, 0x00}};
+constexpr RtpPayloadType kRtpPayloadType = RtpPayloadType::kVideoVp8;
+
+// The number of RTP ticks advanced per frame, for 100 FPS media.
+constexpr int kRtpTicksPerFrame = kRtpTimebase / 100;
+
+// The number of milliseconds advanced per frame, for 100 FPS media.
+constexpr milliseconds kFrameDuration{1000 / 100};
+static_assert(kFrameDuration < (kTargetPlayoutDelay / 10),
+ "Kickstart test assumes frame duration is far less than the "
+ "playout delay.");
+
+// An Encoded frame that also holds onto its own copy of data.
+struct EncodedFrameWithBuffer : public EncodedFrame {
+ // |EncodedFrame::data| always points inside buffer.begin()...buffer.end().
+ std::vector<uint8_t> buffer;
+};
+
+// SenderPacketRouter configuration for these tests.
+constexpr int kNumPacketsPerBurst = 20;
+constexpr milliseconds kBurstInterval{10};
+
+// An arbitrary value, subtracted from "now," to specify the reference_time on
+// frames that are about to be enqueued. This simulates that capture+encode
+// happened in the past, before Sender::EnqueueFrame() is called.
+constexpr milliseconds kCaptureDelay{11};
+
+// In some tests, the computed time values could be off a little bit due to
+// imprecision in certain wire-format timestamp types. The following macro
+// behaves just like Gtest's EXPECT_NEAR(), but works with all the time types
+// too.
+#define EXPECT_NEARLY_EQUAL(duration_a, duration_b, epsilon) \
+ if ((duration_a) >= (duration_b)) { \
+ EXPECT_LE((duration_a), (duration_b) + (epsilon)); \
+ } else { \
+ EXPECT_GE((duration_a), (duration_b) - (epsilon)); \
+ }
+
+// Simulates UDP/IPv6 traffic in one direction (from Sender→Receiver, or
+// Receiver→Sender), with a settable amount of delay.
+class SimulatedNetworkPipe {
+ public:
+ SimulatedNetworkPipe(TaskRunner* task_runner,
+ Environment::PacketConsumer* remote)
+ : task_runner_(task_runner), remote_(remote) {
+ // Create a fake IPv6 address using the "documentative purposes" prefix
+ // concatenated with the |this| pointer.
+ std::array<uint16_t, 8> hextets{};
+ hextets[0] = 0x2001;
+ hextets[1] = 0x0db8;
+ auto* const this_pointer = this;
+ static_assert(sizeof(this_pointer) <= (6 * sizeof(uint16_t)), "");
+ memcpy(&hextets[2], &this_pointer, sizeof(this_pointer));
+ local_endpoint_ = IPEndpoint{IPAddress(hextets), 2344};
+ }
+
+ const IPEndpoint& local_endpoint() const { return local_endpoint_; }
+
+ Clock::duration network_delay() const { return network_delay_; }
+ void set_network_delay(Clock::duration delay) { network_delay_ = delay; }
+
+ // The caller needs to spin the task runner before |packet| will reach the
+ // other side.
+ void StartPacketTransmission(std::vector<uint8_t> packet) {
+ task_runner_->PostTaskWithDelay(
+ [this, packet = std::move(packet)]() mutable {
+ remote_->OnReceivedPacket(local_endpoint_, FakeClock::now(),
+ std::move(packet));
+ },
+ network_delay_);
+ }
+
+ private:
+ TaskRunner* const task_runner_;
+ Environment::PacketConsumer* const remote_;
+
+ IPEndpoint local_endpoint_;
+
+ // The amount of time for the packet to transmit over this simulated network
+ // pipe. Defaults to zero to simplify the tests that don't care about delays.
+ Clock::duration network_delay_{};
+};
+
+// Processes packets from the Sender under test, allowing unit tests to set
+// expectations for parsed RTP or RTCP packets, to confirm proper behavior of
+// the Sender.
+class MockReceiver : public Environment::PacketConsumer {
+ public:
+ explicit MockReceiver(SimulatedNetworkPipe* pipe_to_sender)
+ : pipe_to_sender_(pipe_to_sender),
+ rtcp_session_(kSenderSsrc, kReceiverSsrc, FakeClock::now()),
+ sender_report_parser_(&rtcp_session_),
+ rtcp_builder_(&rtcp_session_),
+ rtp_parser_(kSenderSsrc),
+ crypto_(kAesKey, kCastIvMask) {
+ rtcp_builder_.SetPlayoutDelay(kTargetPlayoutDelay);
+ }
+
+ ~MockReceiver() override = default;
+
+ // Simulate the Receiver ACK'ing all frames up to and including the
+ // |new_checkpoint|.
+ void SetCheckpointFrame(FrameId new_checkpoint) {
+ OSP_CHECK_GE(new_checkpoint, rtcp_builder_.checkpoint_frame());
+ rtcp_builder_.SetCheckpointFrame(new_checkpoint);
+ }
+
+ // Automatically advances the checkpoint based on what is found in
+ // |complete_frames_|, returning true if the checkpoint moved forward.
+ bool AutoAdvanceCheckpoint() {
+ const FrameId old_checkpoint = rtcp_builder_.checkpoint_frame();
+ FrameId new_checkpoint = old_checkpoint;
+ for (auto it = complete_frames_.upper_bound(old_checkpoint);
+ it != complete_frames_.end(); ++it) {
+ if (it->first != new_checkpoint + 1) {
+ break;
+ }
+ ++new_checkpoint;
+ }
+ if (new_checkpoint > old_checkpoint) {
+ rtcp_builder_.SetCheckpointFrame(new_checkpoint);
+ return true;
+ }
+ return false;
+ }
+
+ void SetPictureLossIndicator(bool picture_is_lost) {
+ rtcp_builder_.SetPictureLossIndicator(picture_is_lost);
+ }
+
+ void SetReceiverReport(StatusReportId reply_for,
+ RtcpReportBlock::Delay processing_delay) {
+ RtcpReportBlock receiver_report;
+ receiver_report.ssrc = kSenderSsrc;
+ receiver_report.last_status_report_id = reply_for;
+ receiver_report.delay_since_last_report = processing_delay;
+ rtcp_builder_.IncludeReceiverReportInNextPacket(receiver_report);
+ }
+
+ void SetNacksAndAcks(std::vector<PacketNack> packet_nacks,
+ std::vector<FrameId> frame_acks) {
+ rtcp_builder_.IncludeFeedbackInNextPacket(std::move(packet_nacks),
+ std::move(frame_acks));
+ }
+
+ // Builds and sends a RTCP packet containing one or more of: checkpoint, PLI,
+ // Receiver Report, NACKs, ACKs.
+ void TransmitRtcpFeedbackPacket() {
+ uint8_t buffer[kMaxRtpPacketSizeForIpv6UdpOnEthernet];
+ const absl::Span<uint8_t> packet =
+ rtcp_builder_.BuildPacket(FakeClock::now(), buffer);
+ pipe_to_sender_->StartPacketTransmission(
+ std::vector<uint8_t>(packet.begin(), packet.end()));
+ }
+
+ // Used by tests to simulate the Receiver not seeing specific packets come in
+ // from the network (e.g., because the network dropped the packets).
+ void SetIgnoreList(std::vector<PacketNack> ignore_list) {
+ ignore_list_ = ignore_list;
+ }
+
+ // Environment::PacketConsumer implementation.
+ //
+ // Called to process a packet from the Sender, simulating basic RTP frame
+ // collection and Sender Report parsing/handling.
+ void OnReceivedPacket(const IPEndpoint& source,
+ Clock::time_point arrival_time,
+ std::vector<uint8_t> packet) override {
+ const auto type_and_ssrc = InspectPacketForRouting(packet);
+ EXPECT_NE(ApparentPacketType::UNKNOWN, type_and_ssrc.first);
+ EXPECT_EQ(kSenderSsrc, type_and_ssrc.second);
+ if (type_and_ssrc.first == ApparentPacketType::RTP) {
+ const absl::optional<RtpPacketParser::ParseResult> part_of_frame =
+ rtp_parser_.Parse(packet);
+ ASSERT_TRUE(part_of_frame);
+
+ // Return early if simulating packet drops over the network.
+ if (std::find_if(ignore_list_.begin(), ignore_list_.end(),
+ [&](const PacketNack& baddie) {
+ return (
+ baddie.frame_id == part_of_frame->frame_id &&
+ (baddie.packet_id == kAllPacketsLost ||
+ baddie.packet_id == part_of_frame->packet_id));
+ }) != ignore_list_.end()) {
+ return;
+ }
+
+ OnRtpPacket(*part_of_frame);
+ CollectRtpPacket(*part_of_frame, std::move(packet));
+ } else if (type_and_ssrc.first == ApparentPacketType::RTCP) {
+ absl::optional<SenderReportParser::SenderReportWithId> report =
+ sender_report_parser_.Parse(packet);
+ ASSERT_TRUE(report);
+ OnSenderReport(*report);
+ }
+ }
+
+ std::map<FrameId, EncodedFrameWithBuffer> TakeCompleteFrames() {
+ std::map<FrameId, EncodedFrameWithBuffer> result;
+ result.swap(complete_frames_);
+ return result;
+ }
+
+ // Tests set expectations on these mocks to monitor events of interest, and/or
+ // invoke additional behaviors.
+ MOCK_METHOD1(OnRtpPacket,
+ void(const RtpPacketParser::ParseResult& parsed_packet));
+ MOCK_METHOD1(OnFrameComplete, void(FrameId frame_id));
+ MOCK_METHOD1(OnSenderReport,
+ void(const SenderReportParser::SenderReportWithId& report));
+
+ private:
+ // Collects the individual RTP packets until a whole frame can be formed, then
+ // calls OnFrameComplete(). Ignores extra RTP packets that are no longer
+ // needed.
+ void CollectRtpPacket(const RtpPacketParser::ParseResult& part_of_frame,
+ std::vector<uint8_t> packet) {
+ const FrameId frame_id = part_of_frame.frame_id;
+ if (complete_frames_.find(frame_id) != complete_frames_.end()) {
+ return;
+ }
+ FrameCollector& collector = incomplete_frames_[frame_id];
+ collector.set_frame_id(frame_id);
+ EXPECT_TRUE(collector.CollectRtpPacket(part_of_frame, &packet));
+ if (!collector.is_complete()) {
+ return;
+ }
+ const EncryptedFrame& encrypted = collector.PeekAtAssembledFrame();
+ EncodedFrameWithBuffer* const decrypted = &complete_frames_[frame_id];
+ // Note: Not setting decrypted->reference_time here since the logic around
+ // calculating the playout time is rather complex, and is definitely outside
+ // the scope of the testing being done in this module. Instead, end-to-end
+ // testing should exist elsewhere to confirm frame play-out times with real
+ // Receivers.
+ decrypted->buffer.resize(FrameCrypto::GetPlaintextSize(encrypted));
+ decrypted->data = absl::Span<uint8_t>(decrypted->buffer);
+ crypto_.Decrypt(encrypted, decrypted);
+ incomplete_frames_.erase(frame_id);
+ OnFrameComplete(frame_id);
+ }
+
+ SimulatedNetworkPipe* const pipe_to_sender_;
+ RtcpSession rtcp_session_;
+ SenderReportParser sender_report_parser_;
+ CompoundRtcpBuilder rtcp_builder_;
+ RtpPacketParser rtp_parser_;
+ FrameCrypto crypto_;
+
+ std::vector<PacketNack> ignore_list_;
+ std::map<FrameId, FrameCollector> incomplete_frames_;
+ std::map<FrameId, EncodedFrameWithBuffer> complete_frames_;
+};
+
+class MockObserver : public Sender::Observer {
+ public:
+ MOCK_METHOD1(OnFrameCanceled, void(FrameId frame_id));
+ MOCK_METHOD0(OnPictureLost, void());
+};
+
+class SenderTest : public testing::Test {
+ public:
+ SenderTest()
+ : fake_clock_(Clock::now()),
+ task_runner_(&fake_clock_),
+ sender_environment_(&FakeClock::now, &task_runner_),
+ sender_packet_router_(&sender_environment_,
+ kNumPacketsPerBurst,
+ kBurstInterval),
+ sender_(&sender_environment_,
+ &sender_packet_router_,
+ {/* .sender_ssrc = */ kSenderSsrc,
+ /* .receiver_ssrc = */ kReceiverSsrc,
+ /* .rtp_timebase = */ kRtpTimebase,
+ /* .channels = */ 2,
+ /* .target_playout_delay = */ kTargetPlayoutDelay,
+ /* .aes_secret_key = */ kAesKey,
+ /* .aes_iv_mask = */ kCastIvMask},
+ kRtpPayloadType),
+ receiver_to_sender_pipe_(&task_runner_, &sender_packet_router_),
+ receiver_(&receiver_to_sender_pipe_),
+ sender_to_receiver_pipe_(&task_runner_, &receiver_) {
+ sender_environment_.set_socket_error_handler(
+ [](Error error) { ASSERT_TRUE(error.ok()) << error; });
+ sender_environment_.set_remote_endpoint(
+ receiver_to_sender_pipe_.local_endpoint());
+ ON_CALL(sender_environment_, SendPacket(_))
+ .WillByDefault(Invoke([this](absl::Span<const uint8_t> packet) {
+ sender_to_receiver_pipe_.StartPacketTransmission(
+ std::vector<uint8_t>(packet.begin(), packet.end()));
+ }));
+ }
+
+ ~SenderTest() override = default;
+
+ Sender* sender() { return &sender_; }
+ MockReceiver* receiver() { return &receiver_; }
+
+ void SetReceiverToSenderNetworkDelay(Clock::duration delay) {
+ receiver_to_sender_pipe_.set_network_delay(delay);
+ }
+
+ void SetSenderToReceiverNetworkDelay(Clock::duration delay) {
+ sender_to_receiver_pipe_.set_network_delay(delay);
+ }
+
+ void SimulateExecution(Clock::duration how_long = Clock::duration::zero()) {
+ fake_clock_.Advance(how_long);
+ }
+
+ static void PopulateFramePayloadBuffer(int seed,
+ int num_bytes,
+ std::vector<uint8_t>* payload) {
+ payload->clear();
+ payload->reserve(num_bytes);
+ for (int i = 0; i < num_bytes; ++i) {
+ payload->push_back(static_cast<uint8_t>(seed + i));
+ }
+ }
+
+ static void PopulateFrameWithDefaults(FrameId frame_id,
+ Clock::time_point reference_time,
+ int seed,
+ int num_payload_bytes,
+ EncodedFrameWithBuffer* frame) {
+ frame->dependency = (frame_id == FrameId::first())
+ ? EncodedFrame::KEY_FRAME
+ : EncodedFrame::DEPENDS_ON_ANOTHER;
+ frame->frame_id = frame_id;
+ frame->referenced_frame_id = frame->frame_id;
+ if (frame_id != FrameId::first()) {
+ --frame->referenced_frame_id;
+ }
+ frame->rtp_timestamp =
+ RtpTimeTicks() + (RtpTimeDelta::FromTicks(kRtpTicksPerFrame) *
+ (frame_id - FrameId::first()));
+ frame->reference_time = reference_time;
+ PopulateFramePayloadBuffer(seed, num_payload_bytes, &frame->buffer);
+ frame->data = absl::Span<uint8_t>(frame->buffer);
+ }
+
+ // Confirms that all |sent_frames| exist in |received_frames|, with identical
+ // data and metadata.
+ static void ExpectFramesReceivedCorrectly(
+ absl::Span<EncodedFrameWithBuffer> sent_frames,
+ const std::map<FrameId, EncodedFrameWithBuffer> received_frames) {
+ ASSERT_EQ(sent_frames.size(), received_frames.size());
+
+ for (const EncodedFrameWithBuffer& sent_frame : sent_frames) {
+ SCOPED_TRACE(testing::Message()
+ << "Checking sent frame " << sent_frame.frame_id);
+ const auto received_it = received_frames.find(sent_frame.frame_id);
+ if (received_it == received_frames.end()) {
+ ADD_FAILURE() << "Did not receive frame.";
+ continue;
+ }
+ const EncodedFrame& received_frame = received_it->second;
+ EXPECT_EQ(sent_frame.dependency, received_frame.dependency);
+ EXPECT_EQ(sent_frame.referenced_frame_id,
+ received_frame.referenced_frame_id);
+ EXPECT_EQ(sent_frame.rtp_timestamp, received_frame.rtp_timestamp);
+ EXPECT_TRUE(sent_frame.data == received_frame.data);
+ }
+ }
+
+ private:
+ FakeClock fake_clock_;
+ FakeTaskRunner task_runner_;
+ NiceMock<MockEnvironment> sender_environment_;
+ SenderPacketRouter sender_packet_router_;
+ Sender sender_;
+ SimulatedNetworkPipe receiver_to_sender_pipe_;
+ NiceMock<MockReceiver> receiver_;
+ SimulatedNetworkPipe sender_to_receiver_pipe_;
+};
+
+// Tests that the Sender can send EncodedFrames over an ideal network (i.e., low
+// latency, no loss), and does so without having to transmit the same packet
+// twice.
+TEST_F(SenderTest, SendsFramesEfficiently) {
+ constexpr milliseconds kOneWayNetworkDelay{1};
+ SetSenderToReceiverNetworkDelay(kOneWayNetworkDelay);
+ SetReceiverToSenderNetworkDelay(kOneWayNetworkDelay);
+
+ // Expect that each packet is only sent once.
+ std::set<std::pair<FrameId, FramePacketId>> received_packets;
+ EXPECT_CALL(*receiver(), OnRtpPacket(_))
+ .WillRepeatedly(
+ Invoke([&](const RtpPacketParser::ParseResult& parsed_packet) {
+ std::pair<FrameId, FramePacketId> id(parsed_packet.frame_id,
+ parsed_packet.packet_id);
+ const auto insert_result = received_packets.insert(id);
+ EXPECT_TRUE(insert_result.second)
+ << "Received duplicate packet: " << id.first << ':'
+ << static_cast<int>(id.second);
+ }));
+
+ // Simulate normal frame ACK'ing behavior.
+ ON_CALL(*receiver(), OnFrameComplete(_)).WillByDefault(InvokeWithoutArgs([&] {
+ if (receiver()->AutoAdvanceCheckpoint()) {
+ receiver()->TransmitRtcpFeedbackPacket();
+ }
+ }));
+
+ NiceMock<MockObserver> observer;
+ EXPECT_CALL(observer, OnFrameCanceled(FrameId::first())).Times(1);
+ EXPECT_CALL(observer, OnFrameCanceled(FrameId::first() + 1)).Times(1);
+ EXPECT_CALL(observer, OnFrameCanceled(FrameId::first() + 2)).Times(1);
+ sender()->SetObserver(&observer);
+
+ EncodedFrameWithBuffer frames[3];
+ constexpr int kFrameDataSizes[] = {8196, 12, 1900};
+ for (int i = 0; i < 3; ++i) {
+ if (i == 0) {
+ EXPECT_TRUE(sender()->NeedsKeyFrame());
+ } else {
+ EXPECT_FALSE(sender()->NeedsKeyFrame());
+ }
+ PopulateFrameWithDefaults(FrameId::first() + i,
+ FakeClock::now() - kCaptureDelay, 0xbf - i,
+ kFrameDataSizes[i], &frames[i]);
+ ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frames[i]));
+ SimulateExecution(kFrameDuration);
+ }
+ SimulateExecution(kTargetPlayoutDelay);
+
+ ExpectFramesReceivedCorrectly(frames, receiver()->TakeCompleteFrames());
+}
+
+// Tests that the Sender correctly computes the current in-flight media
+// duration, a backlog signal for clients.
+TEST_F(SenderTest, ComputesInFlightMediaDuration) {
+ // With no frames enqueued, the in-flight media duration should be zero.
+ EXPECT_EQ(Clock::duration::zero(),
+ sender()->GetInFlightMediaDuration(RtpTimeTicks()));
+ EXPECT_EQ(Clock::duration::zero(),
+ sender()->GetInFlightMediaDuration(
+ RtpTimeTicks() + RtpTimeDelta::FromTicks(kRtpTicksPerFrame)));
+
+ // Enqueue a frame.
+ EncodedFrameWithBuffer frame;
+ PopulateFrameWithDefaults(FrameId::first(), FakeClock::now(), 0,
+ 13 /* bytes */, &frame);
+ ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frame));
+
+ // Now, the in-flight media duration should depend on the RTP timestamp of the
+ // next frame.
+ EXPECT_EQ(kFrameDuration, sender()->GetInFlightMediaDuration(
+ frame.rtp_timestamp +
+ RtpTimeDelta::FromTicks(kRtpTicksPerFrame)));
+ EXPECT_EQ(10 * kFrameDuration,
+ sender()->GetInFlightMediaDuration(
+ frame.rtp_timestamp +
+ RtpTimeDelta::FromTicks(10 * kRtpTicksPerFrame)));
+}
+
+// Tests that the Sender computes the maximum in-flight media duration based on
+// its analysis of current network conditions. By implication, this demonstrates
+// that the Sender is also measuring the network round-trip time.
+TEST_F(SenderTest, RespondsToNetworkLatencyChanges) {
+ // The expected maximum error in time calculations is one tick of the RTCP
+ // report block's delay type.
+ constexpr auto kEpsilon =
+ duration_cast<nanoseconds>(RtcpReportBlock::Delay(1));
+
+ // Before the Sender has the necessary information to compute the network
+ // round-trip time, GetMaxInFlightMediaDuration() will return half the target
+ // playout delay.
+ EXPECT_NEARLY_EQUAL(kTargetPlayoutDelay / 2,
+ sender()->GetMaxInFlightMediaDuration(), kEpsilon);
+
+ // No network is perfect. Simulate different one-way network delays.
+ constexpr milliseconds kOutboundDelay{2};
+ constexpr milliseconds kInboundDelay{4};
+ constexpr milliseconds kRoundTripDelay = kOutboundDelay + kInboundDelay;
+ SetSenderToReceiverNetworkDelay(kOutboundDelay);
+ SetReceiverToSenderNetworkDelay(kInboundDelay);
+
+ // Enqueue a frame in the Sender to start emitting periodic RTCP reports.
+ {
+ EncodedFrameWithBuffer frame;
+ PopulateFrameWithDefaults(FrameId::first(), FakeClock::now(), 0,
+ 1 /* byte */, &frame);
+ ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frame));
+ }
+
+ // Run one network round-trip from Sender→Receiver→Sender.
+ StatusReportId sender_report_id{};
+ EXPECT_CALL(*receiver(), OnSenderReport(_))
+ .WillOnce(Invoke(
+ [&](const SenderReportParser::SenderReportWithId& sender_report) {
+ sender_report_id = sender_report.report_id;
+ }));
+ // Simulate the passage of time for the Sender Report to reach the Receiver.
+ SimulateExecution(kOutboundDelay);
+ // The Receiver should have received the Sender Report at this point.
+ Mock::VerifyAndClearExpectations(receiver());
+ ASSERT_NE(StatusReportId{}, sender_report_id);
+ // Simulate the passage of time in the Receiver doing "other tasks" before
+ // replying back to the Sender. This delay is included in the Receiver Report
+ // so that the Sender can isolate the delays caused by the network.
+ constexpr milliseconds kReceiverProcessingDelay{2};
+ SimulateExecution(kReceiverProcessingDelay);
+ // Create the Receiver Report "reply," and simulate it being sent across the
+ // network, back to the Sender.
+ receiver()->SetReceiverReport(
+ sender_report_id,
+ duration_cast<RtcpReportBlock::Delay>(kReceiverProcessingDelay));
+ receiver()->TransmitRtcpFeedbackPacket();
+ SimulateExecution(kInboundDelay);
+
+ // At this point, the Sender should have computed the network round-trip time,
+ // and so GetMaxInFlightMediaDuration() will return half the target playout
+ // delay PLUS half the network round-trip time.
+ EXPECT_NEARLY_EQUAL(kTargetPlayoutDelay / 2 + kRoundTripDelay / 2,
+ sender()->GetMaxInFlightMediaDuration(), kEpsilon);
+
+ // Increase the outbound delay, which will increase the total round-trip time.
+ constexpr milliseconds kIncreasedOutboundDelay{6};
+ constexpr milliseconds kIncreasedRoundTripDelay =
+ kIncreasedOutboundDelay + kInboundDelay;
+ SetSenderToReceiverNetworkDelay(kIncreasedOutboundDelay);
+
+ // With increased network delay, run several more network round-trips. Expect
+ // the Sender to gradually converge towards the new network round-trip time.
+ constexpr int kNumReportIntervals = 50;
+ EXPECT_CALL(*receiver(), OnSenderReport(_))
+ .Times(kNumReportIntervals)
+ .WillRepeatedly(Invoke(
+ [&](const SenderReportParser::SenderReportWithId& sender_report) {
+ receiver()->SetReceiverReport(sender_report.report_id,
+ RtcpReportBlock::Delay::zero());
+ receiver()->TransmitRtcpFeedbackPacket();
+ }));
+ Clock::duration last_max = sender()->GetMaxInFlightMediaDuration();
+ for (int i = 0; i < kNumReportIntervals; ++i) {
+ SimulateExecution(kRtcpReportInterval);
+ const Clock::duration updated_value =
+ sender()->GetMaxInFlightMediaDuration();
+ EXPECT_LE(last_max, updated_value);
+ last_max = updated_value;
+ }
+ EXPECT_NEARLY_EQUAL(kTargetPlayoutDelay / 2 + kIncreasedRoundTripDelay / 2,
+ sender()->GetMaxInFlightMediaDuration(), kEpsilon);
+}
+
+// Tests that the Sender rejects frames if too large a span of FrameIds would be
+// in-flight at once.
+TEST_F(SenderTest, RejectsEnqueuingBeforeProtocolDesignLimit) {
+ // For this test, use 1000 FPS. This makes the frames all one millisecond
+ // apart to avoid triggering the media-duration rejection logic.
+ constexpr int kFramesPerSecond = 1000;
+ constexpr milliseconds kFrameDuration{1};
+
+ const auto OverrideRtpTimestamp = [](int frame_count, EncodedFrame* frame) {
+ const int ticks = frame_count * kRtpTimebase / kFramesPerSecond;
+ frame->rtp_timestamp = RtpTimeTicks() + RtpTimeDelta::FromTicks(ticks);
+ };
+
+ // Send the absolute design-limit maximum number of frames.
+ int frame_count = 0;
+ for (; frame_count < kMaxUnackedFrames; ++frame_count) {
+ EncodedFrameWithBuffer frame;
+ PopulateFrameWithDefaults(sender()->GetNextFrameId(), FakeClock::now(), 0,
+ 13 /* bytes */, &frame);
+ OverrideRtpTimestamp(frame_count, &frame);
+ ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frame));
+ SimulateExecution(kFrameDuration);
+ }
+
+ // Now, attempting to enqueue just one more frame should fail.
+ EncodedFrameWithBuffer one_frame_too_much;
+ PopulateFrameWithDefaults(sender()->GetNextFrameId(), FakeClock::now(), 0,
+ 13 /* bytes */, &one_frame_too_much);
+ OverrideRtpTimestamp(frame_count++, &one_frame_too_much);
+ EXPECT_EQ(Sender::REACHED_ID_SPAN_LIMIT,
+ sender()->EnqueueFrame(one_frame_too_much));
+ SimulateExecution(kFrameDuration);
+
+ // Now, simulate the Receiver ACKing the first frame, and enqueuing should
+ // then succeed again.
+ receiver()->SetCheckpointFrame(FrameId::first());
+ receiver()->TransmitRtcpFeedbackPacket();
+ SimulateExecution(); // RTCP transmitted to Sender.
+ EXPECT_EQ(Sender::OK, sender()->EnqueueFrame(one_frame_too_much));
+ SimulateExecution(kFrameDuration);
+
+ // Finally, attempting to enqueue another frame should fail again.
+ EncodedFrameWithBuffer another_frame_too_much;
+ PopulateFrameWithDefaults(sender()->GetNextFrameId(), FakeClock::now(), 0,
+ 13 /* bytes */, &another_frame_too_much);
+ OverrideRtpTimestamp(frame_count++, &another_frame_too_much);
+ EXPECT_EQ(Sender::REACHED_ID_SPAN_LIMIT,
+ sender()->EnqueueFrame(another_frame_too_much));
+ SimulateExecution(kFrameDuration);
+}
+
+// Tests that the Sender rejects frames if too-long a media duration is
+// in-flight. This is the Sender's primary flow control mechanism.
+TEST_F(SenderTest, RejectsEnqueuingIfTooLongMediaDurationIsInFlight) {
+ // For this test, use 20 FPS. This makes all frames 50 ms apart, which should
+ // make it easy to trigger the media-duration rejection logic.
+ constexpr int kFramesPerSecond = 20;
+ constexpr milliseconds kFrameDuration{50};
+
+ const auto OverrideRtpTimestamp = [](int frame_count, EncodedFrame* frame) {
+ const int ticks = frame_count * kRtpTimebase / kFramesPerSecond;
+ frame->rtp_timestamp = RtpTimeTicks() + RtpTimeDelta::FromTicks(ticks);
+ };
+
+ // Enqueue frames until one is rejected because the in-flight duration would
+ // be too high.
+ EncodedFrameWithBuffer frame;
+ int frame_count = 0;
+ for (; frame_count < kMaxUnackedFrames; ++frame_count) {
+ PopulateFrameWithDefaults(sender()->GetNextFrameId(), FakeClock::now(), 0,
+ 13 /* bytes */, &frame);
+ OverrideRtpTimestamp(frame_count, &frame);
+ const auto result = sender()->EnqueueFrame(frame);
+ SimulateExecution(kFrameDuration);
+ if (result == Sender::MAX_DURATION_IN_FLIGHT) {
+ break;
+ }
+ ASSERT_EQ(Sender::OK, result);
+ }
+
+ // Now, simulate the Receiver ACKing the first frame, and enqueuing should
+ // then succeed again.
+ receiver()->SetCheckpointFrame(FrameId::first());
+ receiver()->TransmitRtcpFeedbackPacket();
+ SimulateExecution(); // RTCP transmitted to Sender.
+ EXPECT_EQ(Sender::OK, sender()->EnqueueFrame(frame));
+ SimulateExecution(kFrameDuration);
+
+ // However, attempting to enqueue another frame should fail again.
+ EncodedFrameWithBuffer one_frame_too_much;
+ PopulateFrameWithDefaults(sender()->GetNextFrameId(), FakeClock::now(), 0,
+ 13 /* bytes */, &one_frame_too_much);
+ OverrideRtpTimestamp(++frame_count, &one_frame_too_much);
+ EXPECT_EQ(Sender::MAX_DURATION_IN_FLIGHT,
+ sender()->EnqueueFrame(one_frame_too_much));
+ SimulateExecution(kFrameDuration);
+}
+
+// Tests that the Sender propagates the Receiver's picture loss indicator to the
+// Observer::OnPictureLost(), and via calls to NeedsKeyFrame(); but only when
+// producing a key frame is absolutely necessary.
+TEST_F(SenderTest, ManagesReceiverPictureLossWorkflow) {
+ NiceMock<MockObserver> observer;
+ sender()->SetObserver(&observer);
+
+ // Send three frames...
+ EncodedFrameWithBuffer frames[6];
+ for (int i = 0; i < 3; ++i) {
+ if (i == 0) {
+ EXPECT_TRUE(sender()->NeedsKeyFrame());
+ } else {
+ EXPECT_FALSE(sender()->NeedsKeyFrame());
+ }
+ PopulateFrameWithDefaults(FrameId::first() + i,
+ FakeClock::now() - kCaptureDelay, 0,
+ 24 /* bytes */, &frames[i]);
+ ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frames[i]));
+ SimulateExecution(kFrameDuration);
+ }
+ SimulateExecution(kTargetPlayoutDelay);
+
+ // Simulate the Receiver ACK'ing the first three frames.
+ EXPECT_CALL(observer, OnFrameCanceled(FrameId::first())).Times(1);
+ EXPECT_CALL(observer, OnFrameCanceled(FrameId::first() + 1)).Times(1);
+ EXPECT_CALL(observer, OnFrameCanceled(FrameId::first() + 2)).Times(1);
+ EXPECT_CALL(observer, OnPictureLost()).Times(0);
+ receiver()->SetCheckpointFrame(frames[2].frame_id);
+ receiver()->TransmitRtcpFeedbackPacket();
+ SimulateExecution(); // RTCP transmitted to Sender.
+ Mock::VerifyAndClearExpectations(&observer);
+
+ // Simulate something going wrong in the Receiver, and have it report picture
+ // loss to the Sender. The Sender should then propagate this to its Observer
+ // and return true when NeedsKeyFrame() is called.
+ EXPECT_CALL(observer, OnFrameCanceled(_)).Times(0);
+ EXPECT_CALL(observer, OnPictureLost()).Times(1);
+ EXPECT_FALSE(sender()->NeedsKeyFrame());
+ receiver()->SetPictureLossIndicator(true);
+ receiver()->TransmitRtcpFeedbackPacket();
+ SimulateExecution(); // RTCP transmitted to Sender.
+ Mock::VerifyAndClearExpectations(&observer);
+ EXPECT_TRUE(sender()->NeedsKeyFrame());
+
+ // Send a non-key frame, and expect NeedsKeyFrame() still returns true. The
+ // Observer is not re-notified. This accounts for the case where a client's
+ // media encoder had frames in its processing pipeline before NeedsKeyFrame()
+ // began returning true.
+ EXPECT_CALL(observer, OnFrameCanceled(_)).Times(0);
+ EXPECT_CALL(observer, OnPictureLost()).Times(0);
+ EncodedFrameWithBuffer& nonkey_frame = frames[3];
+ PopulateFrameWithDefaults(FrameId::first() + 3,
+ FakeClock::now() - kCaptureDelay, 0, 24 /* bytes */,
+ &nonkey_frame);
+ ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(nonkey_frame));
+ SimulateExecution(kFrameDuration);
+ Mock::VerifyAndClearExpectations(&observer);
+ EXPECT_TRUE(sender()->NeedsKeyFrame());
+
+ // Now send a key frame, and expect NeedsKeyFrame() returns false. Note that
+ // the Receiver hasn't cleared the PLI condition, but the Sender knows more
+ // key frames won't be needed.
+ EXPECT_CALL(observer, OnFrameCanceled(_)).Times(0);
+ EXPECT_CALL(observer, OnPictureLost()).Times(0);
+ EncodedFrameWithBuffer& recovery_frame = frames[4];
+ PopulateFrameWithDefaults(FrameId::first() + 4,
+ FakeClock::now() - kCaptureDelay, 0, 24 /* bytes */,
+ &recovery_frame);
+ recovery_frame.dependency = EncodedFrame::KEY_FRAME;
+ recovery_frame.referenced_frame_id = recovery_frame.frame_id;
+ ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(recovery_frame));
+ SimulateExecution(kFrameDuration);
+ Mock::VerifyAndClearExpectations(&observer);
+ EXPECT_FALSE(sender()->NeedsKeyFrame());
+
+ // Let's say the Receiver hasn't received the key frame yet, and it reports
+ // its picture loss again to the Sender. Observer::OnPictureLost() should not
+ // be called, and NeedsKeyFrame() should NOT return true, because the Sender
+ // knows the Receiver hasn't acknowledged the key frame (just sent) yet.
+ EXPECT_CALL(observer, OnFrameCanceled(nonkey_frame.frame_id)).Times(1);
+ EXPECT_CALL(observer, OnPictureLost()).Times(0);
+ receiver()->SetCheckpointFrame(nonkey_frame.frame_id);
+ receiver()->SetPictureLossIndicator(true);
+ receiver()->TransmitRtcpFeedbackPacket();
+ SimulateExecution(); // RTCP transmitted to Sender.
+ Mock::VerifyAndClearExpectations(&observer);
+ EXPECT_FALSE(sender()->NeedsKeyFrame());
+
+ // Now, simulate the Receiver getting the key frame, but NOT recovering. This
+ // should cause Observer::OnPictureLost() to be called, and cause
+ // NeedsKeyFrame() to return true again.
+ EXPECT_CALL(observer, OnFrameCanceled(recovery_frame.frame_id)).Times(1);
+ EXPECT_CALL(observer, OnPictureLost()).Times(1);
+ receiver()->SetCheckpointFrame(recovery_frame.frame_id);
+ receiver()->SetPictureLossIndicator(true);
+ receiver()->TransmitRtcpFeedbackPacket();
+ SimulateExecution(); // RTCP transmitted to Sender.
+ Mock::VerifyAndClearExpectations(&observer);
+ EXPECT_TRUE(sender()->NeedsKeyFrame());
+
+ // Send another key frame, and expect NeedsKeyFrame() returns false.
+ EXPECT_CALL(observer, OnFrameCanceled(_)).Times(0);
+ EXPECT_CALL(observer, OnPictureLost()).Times(0);
+ EncodedFrameWithBuffer& another_recovery_frame = frames[5];
+ PopulateFrameWithDefaults(FrameId::first() + 5,
+ FakeClock::now() - kCaptureDelay, 0, 24 /* bytes */,
+ &another_recovery_frame);
+ another_recovery_frame.dependency = EncodedFrame::KEY_FRAME;
+ another_recovery_frame.referenced_frame_id = another_recovery_frame.frame_id;
+ ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(another_recovery_frame));
+ SimulateExecution(kFrameDuration);
+ Mock::VerifyAndClearExpectations(&observer);
+ EXPECT_FALSE(sender()->NeedsKeyFrame());
+
+ // Now, simulate the Receiver recovering. It will report this to the Sender,
+ // and NeedsKeyFrame() will still return false.
+ EXPECT_CALL(observer, OnFrameCanceled(another_recovery_frame.frame_id))
+ .Times(1);
+ EXPECT_CALL(observer, OnPictureLost()).Times(0);
+ receiver()->SetCheckpointFrame(another_recovery_frame.frame_id);
+ receiver()->SetPictureLossIndicator(false);
+ receiver()->TransmitRtcpFeedbackPacket();
+ SimulateExecution(); // RTCP transmitted to Sender.
+ Mock::VerifyAndClearExpectations(&observer);
+ EXPECT_FALSE(sender()->NeedsKeyFrame());
+
+ ExpectFramesReceivedCorrectly(frames, receiver()->TakeCompleteFrames());
+}
+
+// Tests that the Receiver should get a Sender Report just before the first RTP
+// packet, and at regular intervals thereafter. The Sender Report contains the
+// lip-sync information necessary for play-out timing.
+TEST_F(SenderTest, ProvidesSenderReports) {
+ std::vector<SenderReportParser::SenderReportWithId> sender_reports;
+ Sequence packet_sequence;
+ EXPECT_CALL(*receiver(), OnSenderReport(_))
+ .InSequence(packet_sequence)
+ .WillOnce(
+ Invoke([&](const SenderReportParser::SenderReportWithId& report) {
+ sender_reports.push_back(report);
+ }))
+ .RetiresOnSaturation();
+ EXPECT_CALL(*receiver(), OnRtpPacket(_)).Times(1).InSequence(packet_sequence);
+ EXPECT_CALL(*receiver(), OnSenderReport(_))
+ .Times(3)
+ .InSequence(packet_sequence)
+ .WillRepeatedly(
+ Invoke([&](const SenderReportParser::SenderReportWithId& report) {
+ sender_reports.push_back(report);
+ }));
+
+ EncodedFrameWithBuffer frame;
+ constexpr int kFrameDataSize = 250;
+ PopulateFrameWithDefaults(FrameId::first(), FakeClock::now(), 0,
+ kFrameDataSize, &frame);
+ ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frame));
+ SimulateExecution(); // Should send one Sender Report + one RTP packet.
+ EXPECT_EQ(size_t{1}, sender_reports.size());
+
+ // Have the Receiver ACK the frame to prevent retransmitting the RTP packet.
+ receiver()->SetCheckpointFrame(FrameId::first());
+ receiver()->TransmitRtcpFeedbackPacket();
+ SimulateExecution(); // RTCP transmitted to Sender.
+
+ // Advance through three more reporting intervals. One Sender Report should be
+ // sent each interval, making a total of 4 reports sent.
+ constexpr auto kThreeReportIntervals = 3 * kRtcpReportInterval;
+ SimulateExecution(kThreeReportIntervals); // Three more Sender Reports.
+ ASSERT_EQ(size_t{4}, sender_reports.size());
+
+ // The first report should contain the same timestamps as the frame because
+ // the Clock did not advance. Also, its packet count and octet count fields
+ // should be zero since the report was sent before the RTP packet.
+ EXPECT_EQ(frame.reference_time, sender_reports.front().reference_time);
+ EXPECT_EQ(frame.rtp_timestamp, sender_reports.front().rtp_timestamp);
+ EXPECT_EQ(uint32_t{0}, sender_reports.front().send_packet_count);
+ EXPECT_EQ(uint32_t{0}, sender_reports.front().send_octet_count);
+
+ // The last report should contain the timestamps extrapolated into the future
+ // because the Clock did move forward. Also, the packet count and octet fields
+ // should now be non-zero because the report was sent after the RTP packet.
+ EXPECT_EQ(frame.reference_time + kThreeReportIntervals,
+ sender_reports.back().reference_time);
+ EXPECT_EQ(frame.rtp_timestamp +
+ RtpTimeDelta::FromDuration(kThreeReportIntervals, kRtpTimebase),
+ sender_reports.back().rtp_timestamp);
+ EXPECT_EQ(uint32_t{1}, sender_reports.back().send_packet_count);
+ EXPECT_EQ(uint32_t{kFrameDataSize}, sender_reports.back().send_octet_count);
+}
+
+// Tests that the Sender provides Kickstart packets whenever the Receiver may
+// not know about new frames.
+TEST_F(SenderTest, ProvidesKickstartPacketsIfReceiverDoesNotACK) {
+ // Have the Receiver move the checkpoint forward only for the first frame, and
+ // none of the later frames. This will force the Sender to eventually send a
+ // Kickstart packet.
+ ON_CALL(*receiver(), OnFrameComplete(_))
+ .WillByDefault(Invoke([&](FrameId frame_id) {
+ if (frame_id == FrameId::first()) {
+ receiver()->SetCheckpointFrame(FrameId::first());
+ receiver()->TransmitRtcpFeedbackPacket();
+ }
+ }));
+
+ // Send three frames, paced to the media.
+ EncodedFrameWithBuffer frames[3];
+ for (int i = 0; i < 3; ++i) {
+ PopulateFrameWithDefaults(FrameId::first() + i,
+ FakeClock::now() - kCaptureDelay, i,
+ 48 /* bytes */, &frames[i]);
+ ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frames[i]));
+ SimulateExecution(kFrameDuration);
+ }
+
+ // Now, do nothing for a while. Because the Receiver isn't moving the
+ // checkpoint forward, the Sender will have sent all the RTP packets at least
+ // once, and then will start sending just Kickstart packets.
+ SimulateExecution(kTargetPlayoutDelay);
+
+ // Keep doing nothing for a while, and confirm the Sender is just sending the
+ // same Kickstart packet over and over. The Kickstart packet is supposed to be
+ // the last packet of the latest frame.
+ std::set<std::pair<FrameId, FramePacketId>> unique_received_packet_ids;
+ EXPECT_CALL(*receiver(), OnRtpPacket(_))
+ .WillRepeatedly(
+ Invoke([&](const RtpPacketParser::ParseResult& parsed_packet) {
+ unique_received_packet_ids.emplace(parsed_packet.frame_id,
+ parsed_packet.packet_id);
+ }));
+ SimulateExecution(kTargetPlayoutDelay);
+ Mock::VerifyAndClearExpectations(receiver());
+ EXPECT_EQ(size_t{1}, unique_received_packet_ids.size());
+ EXPECT_EQ(frames[2].frame_id, unique_received_packet_ids.begin()->first);
+
+ // Now, simulate the Receiver ACKing all the frames.
+ receiver()->SetCheckpointFrame(frames[2].frame_id);
+ receiver()->TransmitRtcpFeedbackPacket();
+ SimulateExecution(); // RTCP transmitted to Sender.
+
+ // With all the frames sent, the Sender should not be transmitting anything.
+ EXPECT_CALL(*receiver(), OnRtpPacket(_)).Times(0);
+ SimulateExecution(10 * kTargetPlayoutDelay);
+
+ ExpectFramesReceivedCorrectly(frames, receiver()->TakeCompleteFrames());
+}
+
+// Tests that the Sender only retransmits packets specifically NACK'ed by the
+// Receiver.
+TEST_F(SenderTest, ResendsIndividuallyNackedPackets) {
+ // Populate the frame data in each frame with enough bytes to force at least
+ // three RTP packets per frame.
+ constexpr int kFrameDataSize = 3 * kMaxRtpPacketSizeForIpv6UdpOnEthernet;
+
+ // Use a 1ms network delay in each direction to make the sequence of events
+ // clearer in this test.
+ constexpr milliseconds kOneWayNetworkDelay{1};
+ SetSenderToReceiverNetworkDelay(kOneWayNetworkDelay);
+ SetReceiverToSenderNetworkDelay(kOneWayNetworkDelay);
+
+ // Simulate that three specific packets will be dropped by the network, one
+ // from each frame (about to be sent).
+ const std::vector<PacketNack> dropped_packets{
+ {FrameId::first(), FramePacketId{2}},
+ {FrameId::first() + 1, FramePacketId{1}},
+ {FrameId::first() + 2, FramePacketId{0}},
+ };
+ receiver()->SetIgnoreList(dropped_packets);
+
+ // Send three frames, paced to the media. The Receiver won't completely
+ // receive any of these frames due to dropped packets.
+ EXPECT_CALL(*receiver(), OnFrameComplete(_)).Times(0);
+ EncodedFrameWithBuffer frames[3];
+ for (int i = 0; i < 3; ++i) {
+ PopulateFrameWithDefaults(FrameId::first() + i,
+ FakeClock::now() - kCaptureDelay, i,
+ kFrameDataSize, &frames[i]);
+ ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frames[i]));
+ SimulateExecution(kFrameDuration);
+ }
+ SimulateExecution(kTargetPlayoutDelay);
+ Mock::VerifyAndClearExpectations(receiver());
+ EXPECT_EQ(3, sender()->GetInFlightFrameCount());
+
+ // The Receiver NACKs the three dropped packets...
+ receiver()->SetNacksAndAcks(dropped_packets, {});
+ receiver()->TransmitRtcpFeedbackPacket();
+
+ // In the meantime, the network recovers (i.e., no more dropped packets)...
+ receiver()->SetIgnoreList({});
+
+ // The NACKs reach the Sender, and it acts on them by retransmitting.
+ SimulateExecution(kOneWayNetworkDelay);
+
+ // As each retransmitted packet arrives at the Receiver, advance the
+ // checkpoint forward to notify the Sender of frames that are now completely
+ // received. Also, confirm that only the three specifically-NACK'ed packets
+ // were retransmitted.
+ EXPECT_CALL(*receiver(), OnFrameComplete(_))
+ .Times(3)
+ .WillRepeatedly(InvokeWithoutArgs([&] {
+ if (receiver()->AutoAdvanceCheckpoint()) {
+ receiver()->TransmitRtcpFeedbackPacket();
+ }
+ }));
+ EXPECT_CALL(*receiver(), OnRtpPacket(_))
+ .Times(3)
+ .WillRepeatedly(Invoke([&](const RtpPacketParser::ParseResult& packet) {
+ EXPECT_FALSE(std::find(dropped_packets.begin(), dropped_packets.end(),
+ PacketNack{packet.frame_id, packet.packet_id}) ==
+ dropped_packets.end());
+ }));
+ SimulateExecution(kOneWayNetworkDelay);
+ Mock::VerifyAndClearExpectations(receiver());
+
+ // The Receiver checkpoint feedback(s) travel back to the Sender, and there
+ // should no longer be any frames in-flight.
+ SimulateExecution(kOneWayNetworkDelay);
+ EXPECT_EQ(0, sender()->GetInFlightFrameCount());
+
+ // The Sender should not be transmitting anything from now on since all frames
+ // are known to have been completely received.
+ EXPECT_CALL(*receiver(), OnRtpPacket(_)).Times(0);
+ SimulateExecution(10 * kTargetPlayoutDelay);
+
+ ExpectFramesReceivedCorrectly(frames, receiver()->TakeCompleteFrames());
+}
+
+// Tests that the Sender retransmits an entire frame if the Receiver requests it
+// (i.e., a full frame NACK), but does not retransmit any packets for frames
+// (before or after) that have been acknowledged.
+TEST_F(SenderTest, ResendsMissingFrames) {
+ // Populate the frame data in each frame with enough bytes to force at least
+ // three RTP packets per frame.
+ constexpr int kFrameDataSize = 3 * kMaxRtpPacketSizeForIpv6UdpOnEthernet;
+
+ // Use a 1ms network delay in each direction to make the sequence of events
+ // clearer in this test.
+ constexpr milliseconds kOneWayNetworkDelay{1};
+ SetSenderToReceiverNetworkDelay(kOneWayNetworkDelay);
+ SetReceiverToSenderNetworkDelay(kOneWayNetworkDelay);
+
+ // Simulate that all of the packets for the second frame will be dropped by
+ // the network, but only the packets for that frame.
+ const std::vector<PacketNack> dropped_packets{
+ {FrameId::first() + 1, kAllPacketsLost},
+ };
+ receiver()->SetIgnoreList(dropped_packets);
+
+ NiceMock<MockObserver> observer;
+ sender()->SetObserver(&observer);
+
+ // The expectations below track the story and execute simulated Receiver
+ // responses. The Sender will have three frames enqueued by its client, and
+ // then...
+ //
+ // The first frame is received and the Receiver ACKs it by moving the
+ // checkpoint forward.
+ Sequence completion_sequence;
+ EXPECT_CALL(*receiver(), OnFrameComplete(FrameId::first()))
+ .InSequence(completion_sequence)
+ .WillOnce(InvokeWithoutArgs([&] {
+ receiver()->SetCheckpointFrame(FrameId::first());
+ receiver()->TransmitRtcpFeedbackPacket();
+ }));
+ // Since all of the packets for the second frame are being dropped, the third
+ // frame will finish next. The Receiver responds by NACKing the second frame
+ // and ACKing the third frame. The checkpoint does not move forward because
+ // the second frame has not been received yet.
+ //
+ // NETWORK CHANGE: After the third frame is received, stop dropping packets.
+ EXPECT_CALL(*receiver(), OnFrameComplete(FrameId::first() + 2))
+ .InSequence(completion_sequence)
+ .WillOnce(InvokeWithoutArgs([&] {
+ receiver()->SetNacksAndAcks(dropped_packets,
+ std::vector<FrameId>{FrameId::first() + 2});
+ receiver()->TransmitRtcpFeedbackPacket();
+ receiver()->SetIgnoreList({});
+ }));
+ // Finally, the Sender should respond to the whole-frame NACK by re-sending
+ // all of the packets for the second frame, and so the Receiver should
+ // completely receive the frame.
+ EXPECT_CALL(*receiver(), OnFrameComplete(FrameId::first() + 1))
+ .InSequence(completion_sequence)
+ .WillOnce(InvokeWithoutArgs([&] {
+ receiver()->SetCheckpointFrame(FrameId::first() + 2);
+ receiver()->TransmitRtcpFeedbackPacket();
+ }));
+
+ // From the Sender's perspective, the Receiver will ACK the first frame, then
+ // the third frame, then the second frame.
+ Sequence cancel_sequence;
+ EXPECT_CALL(observer, OnFrameCanceled(FrameId::first()))
+ .Times(1)
+ .InSequence(cancel_sequence);
+ EXPECT_CALL(observer, OnFrameCanceled(FrameId::first() + 2))
+ .Times(1)
+ .InSequence(cancel_sequence);
+ EXPECT_CALL(observer, OnFrameCanceled(FrameId::first() + 1))
+ .Times(1)
+ .InSequence(cancel_sequence);
+
+ // With all the expectations/sequences in-place, let 'er rip!
+ EncodedFrameWithBuffer frames[3];
+ for (int i = 0; i < 3; ++i) {
+ PopulateFrameWithDefaults(FrameId::first() + i,
+ FakeClock::now() - kCaptureDelay, i,
+ kFrameDataSize, &frames[i]);
+ ASSERT_EQ(Sender::OK, sender()->EnqueueFrame(frames[i]));
+ SimulateExecution(kFrameDuration);
+ }
+ SimulateExecution(kTargetPlayoutDelay);
+ Mock::VerifyAndClearExpectations(receiver());
+ EXPECT_EQ(0, sender()->GetInFlightFrameCount());
+
+ // The Sender should not be transmitting anything from now on since all frames
+ // are known to have been completely received.
+ EXPECT_CALL(*receiver(), OnRtpPacket(_)).Times(0);
+ SimulateExecution(10 * kTargetPlayoutDelay);
+
+ ExpectFramesReceivedCorrectly(frames, receiver()->TakeCompleteFrames());
+}
+
+} // namespace
+} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/session_config.cc b/chromium/third_party/openscreen/src/cast/streaming/session_config.cc
index a6b09f003a5..651170294c1 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/session_config.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/session_config.cc
@@ -4,21 +4,23 @@
#include "cast/streaming/session_config.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
SessionConfig::SessionConfig(Ssrc sender_ssrc,
Ssrc receiver_ssrc,
int rtp_timebase,
int channels,
+ std::chrono::milliseconds target_playout_delay,
std::array<uint8_t, 16> aes_secret_key,
std::array<uint8_t, 16> aes_iv_mask)
: sender_ssrc(sender_ssrc),
receiver_ssrc(receiver_ssrc),
rtp_timebase(rtp_timebase),
channels(channels),
+ target_playout_delay(target_playout_delay),
aes_secret_key(aes_secret_key),
aes_iv_mask(aes_iv_mask) {}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/session_config.h b/chromium/third_party/openscreen/src/cast/streaming/session_config.h
index d61efc11b35..4d611b65e2e 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/session_config.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/session_config.h
@@ -6,12 +6,13 @@
#define CAST_STREAMING_SESSION_CONFIG_H_
#include <array>
+#include <chrono> // NOLINT
#include <cstdint>
#include "cast/streaming/ssrc.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
// Common streaming configuration, established from the OFFER/ANSWER exchange,
// that the Sender and Receiver are both assuming.
@@ -21,6 +22,7 @@ struct SessionConfig final {
Ssrc receiver_ssrc,
int rtp_timebase,
int channels,
+ std::chrono::milliseconds target_playout_delay,
std::array<uint8_t, 16> aes_secret_key,
std::array<uint8_t, 16> aes_iv_mask);
SessionConfig(const SessionConfig&) = default;
@@ -42,12 +44,15 @@ struct SessionConfig final {
// Number of channels. Must be 1 for video, for audio typically 2.
int channels = 1;
+ // Initial target playout delay.
+ std::chrono::milliseconds target_playout_delay;
+
// The AES-128 crypto key and initialization vector.
std::array<uint8_t, 16> aes_secret_key{};
std::array<uint8_t, 16> aes_iv_mask{};
};
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_SESSION_CONFIG_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/ssrc.cc b/chromium/third_party/openscreen/src/cast/streaming/ssrc.cc
index d3f2446906e..d71b806932e 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/ssrc.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/ssrc.cc
@@ -8,8 +8,8 @@
#include "platform/api/time.h"
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
@@ -28,7 +28,7 @@ Ssrc GenerateSsrc(bool higher_priority) {
// it is light-weight and does not need to produce unguessable (nor
// crypto-secure) values.
static std::minstd_rand generator(static_cast<std::minstd_rand::result_type>(
- openscreen::platform::Clock::now().time_since_epoch().count()));
+ Clock::now().time_since_epoch().count()));
std::uniform_int_distribution<int> distribution(
higher_priority ? kHigherPriorityMin : kNormalPriorityMin,
@@ -40,5 +40,5 @@ int ComparePriority(Ssrc ssrc_a, Ssrc ssrc_b) {
return static_cast<int>(ssrc_a) - static_cast<int>(ssrc_b);
}
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/streaming/ssrc.h b/chromium/third_party/openscreen/src/cast/streaming/ssrc.h
index 0cd235665d7..b84d7ac00d7 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/ssrc.h
+++ b/chromium/third_party/openscreen/src/cast/streaming/ssrc.h
@@ -7,8 +7,8 @@
#include <stdint.h>
+namespace openscreen {
namespace cast {
-namespace streaming {
// A Synchronization Source is a 32-bit opaque identifier used in RTP packets
// for identifying the source (or recipient) of a logical sequence of encoded
@@ -33,7 +33,7 @@ Ssrc GenerateSsrc(bool higher_priority);
// ret > 0: Stream |ssrc_b| has higher priority.
int ComparePriority(Ssrc ssrc_a, Ssrc ssrc_b);
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
#endif // CAST_STREAMING_SSRC_H_
diff --git a/chromium/third_party/openscreen/src/cast/streaming/ssrc_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/ssrc_unittest.cc
index 29741409c64..aa9e50edfba 100644
--- a/chromium/third_party/openscreen/src/cast/streaming/ssrc_unittest.cc
+++ b/chromium/third_party/openscreen/src/cast/streaming/ssrc_unittest.cc
@@ -9,10 +9,8 @@
#include "gtest/gtest.h"
#include "util/std_util.h"
-using openscreen::SortAndDedupeElements;
-
+namespace openscreen {
namespace cast {
-namespace streaming {
namespace {
TEST(SsrcTest, GeneratesUniqueAndPrioritizedSsrcs) {
@@ -53,5 +51,5 @@ TEST(SsrcTest, GeneratesUniqueAndPrioritizedSsrcs) {
}
} // namespace
-} // namespace streaming
} // namespace cast
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/cast/test/BUILD.gn b/chromium/third_party/openscreen/src/cast/test/BUILD.gn
new file mode 100644
index 00000000000..83950d2fd4a
--- /dev/null
+++ b/chromium/third_party/openscreen/src/cast/test/BUILD.gn
@@ -0,0 +1,61 @@
+# Copyright 2019 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+import("//build_overrides/build.gni")
+
+source_set("unittests") {
+ testonly = true
+ sources = [
+ "device_auth_test.cc",
+ ]
+
+ deps = [
+ "../../testing/util",
+ "../../third_party/googletest:gmock",
+ "../../third_party/googletest:gtest",
+ "../common:channel",
+ "../common:test_helpers",
+ "../common/channel/proto:channel_proto",
+ "../receiver:channel",
+ "../receiver:test_helpers",
+ "../sender:channel",
+ ]
+}
+
+if (is_posix && !build_with_chromium) {
+ source_set("e2e_tests") {
+ testonly = true
+ sources = [
+ "cast_socket_e2e_test.cc",
+ ]
+
+ deps = [
+ "../../platform",
+ "../../third_party/abseil",
+ "../../third_party/boringssl",
+ "../../third_party/googletest:gtest",
+ "../../util",
+ "../common:certificate",
+ "../common:channel",
+ "../common:test_helpers",
+ "../receiver:channel",
+ "../receiver:test_helpers",
+ "../sender:channel",
+ ]
+ }
+
+ executable("make_crl_tests") {
+ testonly = true
+ sources = [
+ "make_crl_tests.cc",
+ ]
+
+ deps = [
+ "../../third_party/boringssl",
+ "../../util",
+ "../common:test_helpers",
+ "../common/certificate/proto:certificate_proto",
+ ]
+ }
+}
diff --git a/chromium/third_party/openscreen/src/discovery/BUILD.gn b/chromium/third_party/openscreen/src/discovery/BUILD.gn
index a7c1f3ef47f..27b5ef682e2 100644
--- a/chromium/third_party/openscreen/src/discovery/BUILD.gn
+++ b/chromium/third_party/openscreen/src/discovery/BUILD.gn
@@ -3,9 +3,31 @@
# found in the LICENSE file.
import("//build_overrides/build.gni")
+import("../testing/libfuzzer/fuzzer_test.gni")
+
+source_set("common") {
+ sources = [
+ "common/config.h",
+ "common/reporting_client.h",
+ ]
+
+ deps = [
+ "../util",
+ ]
+
+ public_deps = [
+ "../platform",
+ "../third_party/abseil",
+ ]
+}
source_set("mdns") {
sources = [
+ "mdns/mdns_domain_confirmed_provider.h",
+ "mdns/mdns_probe.cc",
+ "mdns/mdns_probe.h",
+ "mdns/mdns_probe_manager.cc",
+ "mdns/mdns_probe_manager.h",
"mdns/mdns_publisher.cc",
"mdns/mdns_publisher.h",
"mdns/mdns_querier.cc",
@@ -28,6 +50,7 @@ source_set("mdns") {
"mdns/mdns_writer.cc",
"mdns/mdns_writer.h",
"mdns/public/mdns_constants.h",
+ "mdns/public/mdns_service.cc",
"mdns/public/mdns_service.h",
]
@@ -36,14 +59,13 @@ source_set("mdns") {
]
public_deps = [
+ ":common",
"../platform",
"../third_party/abseil",
]
}
source_set("dnssd") {
- defines = []
-
sources = [
"dnssd/impl/conversion_layer.cc",
"dnssd/impl/conversion_layer.h",
@@ -69,14 +91,29 @@ source_set("dnssd") {
]
public_deps = [
+ ":common",
":mdns",
]
}
+source_set("public") {
+ sources = [
+ "public/dns_sd_service_factory.h",
+ "public/dns_sd_service_publisher.h",
+ "public/dns_sd_service_watcher.h",
+ ]
+
+ public_deps = [
+ ":common",
+ ":dnssd",
+ ]
+}
+
source_set("testing") {
testonly = true
sources = [
+ "common/testing/mock_reporting_client.h",
"dnssd/testing/fake_dns_record_factory.cc",
"mdns/testing/mdns_test_util.cc",
"mdns/testing/mdns_test_util.h",
@@ -101,14 +138,19 @@ source_set("unittests") {
"dnssd/impl/service_key_unittest.cc",
"dnssd/public/dns_sd_instance_record_unittest.cc",
"dnssd/public/dns_sd_txt_record_unittest.cc",
+ "mdns/mdns_probe_manager_unittest.cc",
+ "mdns/mdns_probe_unittest.cc",
+ "mdns/mdns_publisher_unittest.cc",
"mdns/mdns_querier_unittest.cc",
"mdns/mdns_random_unittest.cc",
"mdns/mdns_reader_unittest.cc",
"mdns/mdns_receiver_unittest.cc",
"mdns/mdns_records_unittest.cc",
+ "mdns/mdns_responder_unittest.cc",
"mdns/mdns_sender_unittest.cc",
"mdns/mdns_trackers_unittest.cc",
"mdns/mdns_writer_unittest.cc",
+ "public/dns_sd_service_watcher_unittest.cc",
]
deps = [
@@ -120,3 +162,18 @@ source_set("unittests") {
"../util",
]
}
+
+openscreen_fuzzer_test("mdns_fuzzer") {
+ sources = [
+ "mdns/mdns_reader_fuzztest.cc",
+ ]
+
+ deps = [
+ ":mdns",
+ ]
+
+ seed_corpus = "mdns/fuzzer_seeds"
+
+ # Note: 512 is the maximum size for a serialized mDNS packet.
+ libfuzzer_options = [ "max_len=512" ]
+}
diff --git a/chromium/third_party/openscreen/src/discovery/DEPS b/chromium/third_party/openscreen/src/discovery/DEPS
index b9b5cb06605..de7afcec06c 100644
--- a/chromium/third_party/openscreen/src/discovery/DEPS
+++ b/chromium/third_party/openscreen/src/discovery/DEPS
@@ -1,9 +1,9 @@
# -*- Mode: Python; -*-
include_rules = [
- '+cast',
+ # Intra-discovery dependencies must be explicit.
+ '-discovery',
- # All libcast code can use platform and cast/third_party.
- '+cast/third_party',
- '+platform'
+ # All discovery code can use discovery/common
+ '+discovery/common',
]
diff --git a/chromium/third_party/openscreen/src/discovery/common/config.h b/chromium/third_party/openscreen/src/discovery/common/config.h
new file mode 100644
index 00000000000..4b0e13e76d8
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/common/config.h
@@ -0,0 +1,56 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef DISCOVERY_COMMON_CONFIG_H_
+#define DISCOVERY_COMMON_CONFIG_H_
+
+#include "platform/base/interface_info.h"
+
+namespace openscreen {
+namespace discovery {
+
+// This struct provides parameters needed to initialize the discovery pipeline.
+struct Config {
+ /*****************************************
+ * Networking Settings
+ *****************************************/
+
+ // Network Interface on which mDNS should be run.
+ InterfaceInfo interface;
+
+ /*****************************************
+ * Publisher Settings
+ *****************************************/
+
+ // Determines whether publishing of services is enabled.
+ bool enable_publication = true;
+
+ // Number of times new mDNS records should be announced, using an exponential
+ // back off. See RFC 6762 section 8.3 for further details. Per RFC, this value
+ // is expected to be in the range of 2 to 8.
+ int new_record_announcement_count = 8;
+
+ /*****************************************
+ * Querier Settings
+ *****************************************/
+
+ // Determines whether querying is enabled.
+ bool enable_querying = true;
+
+ // Number of times new mDNS records should be announced, using an exponential
+ // back off. -1 signifies that there should be no maximum.
+ // NOTE: This is expected to be -1 in all production scenarios and only be a
+ // different value during testing.
+ int new_query_announcement_count = -1;
+
+ // Limit on the size to which the mDNS Querier Cache may grow. This is used to
+ // prevent a malicious or misbehaving mDNS client from causing the memory
+ // used by mDNS to grow in an unbounded fashion.
+ int querier_max_records_cached = 1024;
+};
+
+} // namespace discovery
+} // namespace openscreen
+
+#endif // DISCOVERY_COMMON_CONFIG_H_
diff --git a/chromium/third_party/openscreen/src/discovery/common/reporting_client.h b/chromium/third_party/openscreen/src/discovery/common/reporting_client.h
new file mode 100644
index 00000000000..68011af5ba1
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/common/reporting_client.h
@@ -0,0 +1,37 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef DISCOVERY_COMMON_REPORTING_CLIENT_H_
+#define DISCOVERY_COMMON_REPORTING_CLIENT_H_
+
+#include "platform/base/error.h"
+
+namespace openscreen {
+namespace discovery {
+
+// This class is implemented by the embedder who wishes to use discovery. The
+// discovery implementation will use this API to report back errors and metrics.
+// NOTE: All methods in the reporting client will be called from the task runner
+// thread.
+// TODO(rwkeane): Report state changes back to the caller.
+class ReportingClient {
+ public:
+ virtual ~ReportingClient() = default;
+
+ // This method is called when an error is detected by the underlying
+ // infrastructure from which recovery cannot be initiated. For example, an
+ // error binding a multicast socket.
+ virtual void OnFatalError(Error error) = 0;
+
+ // This method is called when an error is detected by the underlying
+ // infrastructure which does not prevent further functionality of the runtime.
+ // For example, a conversion failure between DnsSdInstanceRecord and the
+ // externally supplied class.
+ virtual void OnRecoverableError(Error error) = 0;
+};
+
+} // namespace discovery
+} // namespace openscreen
+
+#endif // DISCOVERY_COMMON_REPORTING_CLIENT_H_
diff --git a/chromium/third_party/openscreen/src/discovery/common/testing/mock_reporting_client.h b/chromium/third_party/openscreen/src/discovery/common/testing/mock_reporting_client.h
new file mode 100644
index 00000000000..4e3063c5d3a
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/common/testing/mock_reporting_client.h
@@ -0,0 +1,22 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef DISCOVERY_COMMON_TESTING_MOCK_REPORTING_CLIENT_H_
+#define DISCOVERY_COMMON_TESTING_MOCK_REPORTING_CLIENT_H_
+
+#include "discovery/common/reporting_client.h"
+#include "gmock/gmock.h"
+
+namespace openscreen {
+namespace discovery {
+
+class MockReportingClient : public ReportingClient {
+ MOCK_METHOD1(OnFatalError, void(Error error));
+ MOCK_METHOD1(OnRecoverableError, void(Error error));
+};
+
+} // namespace discovery
+} // namespace openscreen
+
+#endif // DISCOVERY_COMMON_TESTING_MOCK_REPORTING_CLIENT_H_
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/DEPS b/chromium/third_party/openscreen/src/discovery/dnssd/impl/DEPS
new file mode 100644
index 00000000000..0c34a54b3d9
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/DEPS
@@ -0,0 +1,6 @@
+# -*- Mode: Python; -*-
+
+include_rules = [
+ '+discovery/dnssd/public',
+ '+discovery/mdns',
+]
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.cc
index ab64062197c..e15674e7b76 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.cc
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.cc
@@ -6,11 +6,14 @@
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "discovery/dnssd/impl/constants.h"
#include "discovery/dnssd/impl/instance_key.h"
#include "discovery/dnssd/impl/service_key.h"
#include "discovery/dnssd/public/dns_sd_instance_record.h"
#include "discovery/mdns/mdns_records.h"
+#include "discovery/mdns/public/mdns_constants.h"
namespace openscreen {
namespace discovery {
@@ -42,62 +45,54 @@ DomainName GetInstanceDomainName(const std::string& instance,
return DomainName{std::move(labels)};
}
+inline DomainName GetInstanceDomainName(const InstanceKey& key) {
+ return GetInstanceDomainName(key.instance_id(), key.service_id(),
+ key.domain_id());
+}
+
MdnsRecord CreatePtrRecord(const DnsSdInstanceRecord& record,
const DomainName& domain) {
PtrRecordRdata data(domain);
-
- // TTL specified by RFC 6762 section 10.
- constexpr std::chrono::seconds ttl(120);
auto outer_domain = GetPtrDomainName(record.service_id(), record.domain_id());
return MdnsRecord(std::move(outer_domain), DnsType::kPTR, DnsClass::kIN,
- RecordType::kShared, ttl, std::move(data));
+ RecordType::kShared, kPtrRecordTtl, std::move(data));
}
MdnsRecord CreateSrvRecord(const DnsSdInstanceRecord& record,
const DomainName& domain) {
uint16_t port = record.port();
-
- // TTL specified by RFC 6762 section 10.
- constexpr std::chrono::seconds ttl(120);
SrvRecordRdata data(0, 0, port, domain);
return MdnsRecord(domain, DnsType::kSRV, DnsClass::kIN, RecordType::kUnique,
- ttl, std::move(data));
+ kSrvRecordTtl, std::move(data));
}
absl::optional<MdnsRecord> CreateARecord(const DnsSdInstanceRecord& record,
const DomainName& domain) {
- if (!record.address_v4().has_value()) {
+ if (!record.address_v4()) {
return absl::nullopt;
}
- // TTL specified by RFC 6762 section 10.
- constexpr std::chrono::seconds ttl(120);
- ARecordRdata data(record.address_v4().value().address);
+ ARecordRdata data(record.address_v4().address);
return MdnsRecord(domain, DnsType::kA, DnsClass::kIN, RecordType::kUnique,
- ttl, std::move(data));
+ kARecordTtl, std::move(data));
}
absl::optional<MdnsRecord> CreateAAAARecord(const DnsSdInstanceRecord& record,
const DomainName& domain) {
- if (!record.address_v6().has_value()) {
+ if (!record.address_v6()) {
return absl::nullopt;
}
- // TTL specified by RFC 6762 section 10.
- constexpr std::chrono::seconds ttl(120);
- AAAARecordRdata data(record.address_v6().value().address);
+ AAAARecordRdata data(record.address_v6().address);
return MdnsRecord(domain, DnsType::kAAAA, DnsClass::kIN, RecordType::kUnique,
- ttl, std::move(data));
+ kAAAARecordTtl, std::move(data));
}
MdnsRecord CreateTxtRecord(const DnsSdInstanceRecord& record,
const DomainName& domain) {
TxtRecordRdata data(record.txt().GetData());
-
- // TTL specified by RFC 6762 section 10.
- constexpr std::chrono::seconds ttl(75 * 60);
return MdnsRecord(domain, DnsType::kTXT, DnsClass::kIN, RecordType::kUnique,
- ttl, std::move(data));
+ kTXTRecordTtl, std::move(data));
}
} // namespace
@@ -121,8 +116,9 @@ ErrorOr<DnsSdTxtRecord> CreateFromDnsTxt(const TxtRecordRdata& txt_data) {
std::string key = text.substr(0, index_of_eq);
std::string value = text.substr(index_of_eq + 1);
absl::Span<const uint8_t> data(
- reinterpret_cast<const uint8_t*>(value.c_str()), value.size());
- const auto set_result = txt.SetValue(key, data);
+ reinterpret_cast<const uint8_t*>(value.data()), value.size());
+ const auto set_result =
+ txt.SetValue(key, std::vector<uint8_t>(data.begin(), data.end()));
if (!set_result.ok()) {
return set_result;
}
@@ -137,10 +133,19 @@ ErrorOr<DnsSdTxtRecord> CreateFromDnsTxt(const TxtRecordRdata& txt_data) {
return txt;
}
+DomainName GetDomainName(const InstanceKey& key) {
+ return GetInstanceDomainName(key.instance_id(), key.service_id(),
+ key.domain_id());
+}
+
+DomainName GetDomainName(const MdnsRecord& record) {
+ return IsPtrRecord(record)
+ ? absl::get<PtrRecordRdata>(record.rdata()).ptr_domain()
+ : record.name();
+}
+
DnsQueryInfo GetInstanceQueryInfo(const InstanceKey& key) {
- auto domain = GetInstanceDomainName(key.instance_id(), key.service_id(),
- key.domain_id());
- return {std::move(domain), DnsType::kANY, DnsClass::kANY};
+ return {GetDomainName(key), DnsType::kANY, DnsClass::kANY};
}
DnsQueryInfo GetPtrQueryInfo(const ServiceKey& key) {
@@ -149,7 +154,12 @@ DnsQueryInfo GetPtrQueryInfo(const ServiceKey& key) {
}
bool HasValidDnsRecordAddress(const MdnsRecord& record) {
- return InstanceKey::CreateFromRecord(record).is_value();
+ return HasValidDnsRecordAddress(GetDomainName(record));
+}
+
+bool HasValidDnsRecordAddress(const DomainName& domain) {
+ return InstanceKey::TryCreate(domain).is_value() &&
+ IsInstanceValid(domain.labels()[0]);
}
bool IsPtrRecord(const MdnsRecord& record) {
@@ -157,8 +167,7 @@ bool IsPtrRecord(const MdnsRecord& record) {
}
std::vector<MdnsRecord> GetDnsRecords(const DnsSdInstanceRecord& record) {
- auto domain = GetInstanceDomainName(record.instance_id(), record.service_id(),
- record.domain_id());
+ auto domain = GetInstanceDomainName(InstanceKey(record));
std::vector<MdnsRecord> records{CreatePtrRecord(record, domain),
CreateSrvRecord(record, domain),
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.h b/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.h
index d1c34c6d3d0..08eaed116f4 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.h
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.h
@@ -30,14 +30,22 @@ ErrorOr<DnsSdTxtRecord> CreateFromDnsTxt(const TxtRecordRdata& txt);
bool IsPtrRecord(const MdnsRecord& record);
-// Checks that the instance, service, and domain ids in this MdnsRecord are
-// valid.
+// Checks that the instance, service, and domain ids in this instance are valid.
bool HasValidDnsRecordAddress(const MdnsRecord& record);
+bool HasValidDnsRecordAddress(const DomainName& domain);
//*** Conversions to DNS entities from DNS-SD Entities ***
+// Returns the Domain Name associated with this InstanceKey.
+DomainName GetDomainName(const InstanceKey& key);
+
+// Returns the domain name associated with this MdnsRecord. In the case of a PTR
+// record, this is the target domain, and it is the named domain in all other
+// cases.
+DomainName GetDomainName(const MdnsRecord& record);
+
// Returns the query required to get all instance information about the service
-// instances described by the provided ServiceKey.
+// instances described by the provided InstanceKey.
DnsQueryInfo GetInstanceQueryInfo(const InstanceKey& key);
// Returns the query required to get all service information that matches the
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer_unittest.cc
index adc0abe70a2..b1f0fb245c4 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer_unittest.cc
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer_unittest.cc
@@ -41,9 +41,11 @@ TEST(DnsSdConversionLayerTest, TestCreateTxtValidKeyValue) {
// EXPECT_STREQ is causing memory leaks
std::string expected = "value";
- ASSERT_EQ(record.value().GetValue("name").value().size(), expected.size());
+ ASSERT_TRUE(record.value().GetValue("name").is_value());
+ const std::vector<uint8_t>& value = record.value().GetValue("name").value();
+ ASSERT_EQ(value.size(), expected.size());
for (size_t i = 0; i < expected.size(); i++) {
- EXPECT_EQ(expected[i], record.value().GetValue("name").value()[i]);
+ EXPECT_EQ(expected[i], value[i]);
}
}
@@ -96,8 +98,12 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsPtr) {
DnsSdTxtRecord txt;
DnsSdInstanceRecord instance_record(
FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName,
- FakeDnsRecordFactory::kDomainName, FakeDnsRecordFactory::kV4Endpoint,
- FakeDnsRecordFactory::kV6Endpoint, txt);
+ FakeDnsRecordFactory::kDomainName,
+ IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets),
+ FakeDnsRecordFactory::kPortNum},
+ IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets),
+ FakeDnsRecordFactory::kPortNum},
+ txt);
std::vector<MdnsRecord> records = GetDnsRecords(instance_record);
auto it = std::find_if(records.begin(), records.end(),
[](const MdnsRecord& record) {
@@ -129,8 +135,12 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsSrv) {
DnsSdTxtRecord txt;
DnsSdInstanceRecord instance_record(
FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName,
- FakeDnsRecordFactory::kDomainName, FakeDnsRecordFactory::kV4Endpoint,
- FakeDnsRecordFactory::kV6Endpoint, txt);
+ FakeDnsRecordFactory::kDomainName,
+ IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets),
+ FakeDnsRecordFactory::kPortNum},
+ IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets),
+ FakeDnsRecordFactory::kPortNum},
+ txt);
std::vector<MdnsRecord> records = GetDnsRecords(instance_record);
auto it = std::find_if(records.begin(), records.end(),
[](const MdnsRecord& record) {
@@ -151,15 +161,19 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsSrv) {
const auto& rdata = absl::get<SrvRecordRdata>(it->rdata());
EXPECT_EQ(rdata.priority(), 0);
EXPECT_EQ(rdata.weight(), 0);
- EXPECT_EQ(rdata.port(), FakeDnsRecordFactory::kV4Endpoint.port);
+ EXPECT_EQ(rdata.port(), FakeDnsRecordFactory::kPortNum);
}
TEST(DnsSdConversionLayerTest, GetDnsRecordsAPresent) {
DnsSdTxtRecord txt;
DnsSdInstanceRecord instance_record(
FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName,
- FakeDnsRecordFactory::kDomainName, FakeDnsRecordFactory::kV4Endpoint,
- FakeDnsRecordFactory::kV6Endpoint, txt);
+ FakeDnsRecordFactory::kDomainName,
+ IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets),
+ FakeDnsRecordFactory::kPortNum},
+ IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets),
+ FakeDnsRecordFactory::kPortNum},
+ txt);
std::vector<MdnsRecord> records = GetDnsRecords(instance_record);
auto it = std::find_if(records.begin(), records.end(),
[](const MdnsRecord& record) {
@@ -178,15 +192,18 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsAPresent) {
EXPECT_EQ(it->name().labels()[3], FakeDnsRecordFactory::kDomainName);
const auto& rdata = absl::get<ARecordRdata>(it->rdata());
- EXPECT_EQ(rdata.ipv4_address(), FakeDnsRecordFactory::kV4Address);
+ EXPECT_EQ(rdata.ipv4_address(),
+ IPAddress(FakeDnsRecordFactory::kV4AddressOctets));
}
TEST(DnsSdConversionLayerTest, GetDnsRecordsANotPresent) {
DnsSdTxtRecord txt;
- DnsSdInstanceRecord instance_record(FakeDnsRecordFactory::kInstanceName,
- FakeDnsRecordFactory::kServiceName,
- FakeDnsRecordFactory::kDomainName,
- FakeDnsRecordFactory::kV6Endpoint, txt);
+ DnsSdInstanceRecord instance_record(
+ FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName,
+ FakeDnsRecordFactory::kDomainName,
+ IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets),
+ FakeDnsRecordFactory::kPortNum},
+ txt);
std::vector<MdnsRecord> records = GetDnsRecords(instance_record);
auto it = std::find_if(records.begin(), records.end(),
[](const MdnsRecord& record) {
@@ -199,8 +216,12 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsAAAAPresent) {
DnsSdTxtRecord txt;
DnsSdInstanceRecord instance_record(
FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName,
- FakeDnsRecordFactory::kDomainName, FakeDnsRecordFactory::kV4Endpoint,
- FakeDnsRecordFactory::kV6Endpoint, txt);
+ FakeDnsRecordFactory::kDomainName,
+ IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets),
+ FakeDnsRecordFactory::kPortNum},
+ IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets),
+ FakeDnsRecordFactory::kPortNum},
+ txt);
std::vector<MdnsRecord> records = GetDnsRecords(instance_record);
auto it = std::find_if(records.begin(), records.end(),
[](const MdnsRecord& record) {
@@ -219,15 +240,18 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsAAAAPresent) {
EXPECT_EQ(it->name().labels()[3], FakeDnsRecordFactory::kDomainName);
const auto& rdata = absl::get<AAAARecordRdata>(it->rdata());
- EXPECT_EQ(rdata.ipv6_address(), FakeDnsRecordFactory::kV6Address);
+ EXPECT_EQ(rdata.ipv6_address(),
+ IPAddress(FakeDnsRecordFactory::kV6AddressHextets));
}
TEST(DnsSdConversionLayerTest, GetDnsRecordsAAAANotPresent) {
DnsSdTxtRecord txt;
- DnsSdInstanceRecord instance_record(FakeDnsRecordFactory::kInstanceName,
- FakeDnsRecordFactory::kServiceName,
- FakeDnsRecordFactory::kDomainName,
- FakeDnsRecordFactory::kV4Endpoint, txt);
+ DnsSdInstanceRecord instance_record(
+ FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName,
+ FakeDnsRecordFactory::kDomainName,
+ IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets),
+ FakeDnsRecordFactory::kPortNum},
+ txt);
std::vector<MdnsRecord> records = GetDnsRecords(instance_record);
auto it = std::find_if(records.begin(), records.end(),
[](const MdnsRecord& record) {
@@ -238,14 +262,17 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsAAAANotPresent) {
TEST(DnsSdConversionLayerTest, GetDnsRecordsTxt) {
DnsSdTxtRecord txt;
- auto value =
- absl::Span<const uint8_t>(reinterpret_cast<const uint8_t*>("value"), 5);
+ std::vector<uint8_t> value{'v', 'a', 'l', 'u', 'e'};
txt.SetValue("name", value);
txt.SetFlag("boolean", true);
DnsSdInstanceRecord instance_record(
FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName,
- FakeDnsRecordFactory::kDomainName, FakeDnsRecordFactory::kV4Endpoint,
- FakeDnsRecordFactory::kV6Endpoint, txt);
+ FakeDnsRecordFactory::kDomainName,
+ IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets),
+ FakeDnsRecordFactory::kPortNum},
+ IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets),
+ FakeDnsRecordFactory::kPortNum},
+ txt);
std::vector<MdnsRecord> records = GetDnsRecords(instance_record);
auto it = std::find_if(records.begin(), records.end(),
[](const MdnsRecord& record) {
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_unittest.cc
index 400fe9e7275..fb1543b6b61 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_unittest.cc
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_unittest.cc
@@ -4,7 +4,7 @@
#include "discovery/dnssd/impl/dns_data.h"
-#include <chrono>
+#include <chrono> // NOLINT
#include "discovery/mdns/testing/mdns_test_util.h"
#include "gmock/gmock.h"
@@ -66,24 +66,26 @@ class DnsDataTesting : public DnsData {
}
};
-static const IPAddress v4_address =
- IPAddress(std::array<uint8_t, 4>{{192, 168, 0, 0}});
-static const IPAddress v6_address = IPAddress(std::array<uint8_t, 16>{
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}});
-static const std::string instance_name = "instance";
-static const std::string service_name = "_srv-name._udp";
-static const std::string domain_name = "local";
-static const InstanceKey key = {instance_name, service_name, domain_name};
-static constexpr uint16_t port_num = uint16_t{80};
+namespace {
+
+const uint8_t kV4AddressOctets[4] = {192, 168, 0, 0};
+const uint16_t kV6AddressHextets[8] = {0x0102, 0x0304, 0x0506, 0x0708,
+ 0x090a, 0x0b0c, 0x0d0e, 0x0f10};
+const char kInstanceName[] = "instance";
+const char kServiceName[] = "_srv-name._udp";
+const char kDomainName[] = "local";
+constexpr uint16_t kServicePort = uint16_t{80};
+
+} // namespace
DnsDataTesting CreateFullyPopulatedData() {
- InstanceKey instance{instance_name, service_name, domain_name};
+ InstanceKey instance{kInstanceName, kServiceName, kDomainName};
DnsDataTesting data(instance);
- DomainName target{instance_name, "_srv-name", "_udp", domain_name};
- SrvRecordRdata srv(0, 0, port_num, target);
+ DomainName target{kInstanceName, "_srv-name", "_udp", kDomainName};
+ SrvRecordRdata srv(0, 0, kServicePort, target);
TxtRecordRdata txt = MakeTxtRecord({"name=value", "boolValue"});
- ARecordRdata a(v4_address);
- AAAARecordRdata aaaa(v6_address);
+ ARecordRdata a{IPAddress(kV4AddressOctets)};
+ AAAARecordRdata aaaa{IPAddress(kV6AddressHextets)};
data.set_srv(srv);
data.set_txt(txt);
@@ -93,8 +95,8 @@ DnsDataTesting CreateFullyPopulatedData() {
return data;
}
-MdnsRecord CreateFullyPopulatedRecord(uint16_t port = port_num) {
- DomainName target{instance_name, "_srv-name", "_udp", domain_name};
+MdnsRecord CreateFullyPopulatedRecord(uint16_t port = kServicePort) {
+ DomainName target{kInstanceName, "_srv-name", "_udp", kDomainName};
auto type = DnsType::kSRV;
auto clazz = DnsClass::kIN;
auto record_type = RecordType::kShared;
@@ -110,15 +112,15 @@ TEST(DnsSdDnsDataTests, TestConvertDnsDataCorrectly) {
ASSERT_TRUE(result.is_value());
DnsSdInstanceRecord record = result.value();
- ASSERT_TRUE(record.address_v4().has_value());
- ASSERT_TRUE(record.address_v6().has_value());
- EXPECT_EQ(record.instance_id(), instance_name);
- EXPECT_EQ(record.service_id(), service_name);
- EXPECT_EQ(record.domain_id(), domain_name);
- EXPECT_EQ(record.address_v4().value().port, port_num);
- EXPECT_EQ(record.address_v4().value().address, v4_address);
- EXPECT_EQ(record.address_v6().value().port, port_num);
- EXPECT_EQ(record.address_v6().value().address, v6_address);
+ ASSERT_TRUE(record.address_v4());
+ ASSERT_TRUE(record.address_v6());
+ EXPECT_EQ(record.instance_id(), kInstanceName);
+ EXPECT_EQ(record.service_id(), kServiceName);
+ EXPECT_EQ(record.domain_id(), kDomainName);
+ EXPECT_EQ(record.address_v4().port, kServicePort);
+ EXPECT_EQ(record.address_v4().address, IPAddress(kV4AddressOctets));
+ EXPECT_EQ(record.address_v6().port, kServicePort);
+ EXPECT_EQ(record.address_v6().address, IPAddress(kV6AddressHextets));
EXPECT_FALSE(record.txt().IsEmpty());
}
@@ -156,10 +158,10 @@ TEST(DnsSdDnsDataTests, TestConvertDnsDataOneAddress) {
ASSERT_TRUE(result.is_value());
DnsSdInstanceRecord record = result.value();
- EXPECT_FALSE(record.address_v6().has_value());
- ASSERT_TRUE(record.address_v4().has_value());
- EXPECT_EQ(record.address_v4().value().port, port_num);
- EXPECT_EQ(record.address_v4().value().address, v4_address);
+ EXPECT_FALSE(record.address_v6());
+ ASSERT_TRUE(record.address_v4());
+ EXPECT_EQ(record.address_v4().port, kServicePort);
+ EXPECT_EQ(record.address_v4().address, IPAddress(kV4AddressOctets));
// Address v6.
data = CreateFullyPopulatedData();
@@ -168,10 +170,10 @@ TEST(DnsSdDnsDataTests, TestConvertDnsDataOneAddress) {
ASSERT_TRUE(result.is_value());
record = result.value();
- EXPECT_FALSE(record.address_v4().has_value());
- ASSERT_TRUE(record.address_v6().has_value());
- EXPECT_EQ(record.address_v6().value().port, port_num);
- EXPECT_EQ(record.address_v6().value().address, v6_address);
+ EXPECT_FALSE(record.address_v4());
+ ASSERT_TRUE(record.address_v6());
+ EXPECT_EQ(record.address_v6().port, kServicePort);
+ EXPECT_EQ(record.address_v6().address, IPAddress(kV6AddressHextets));
}
TEST(DnsSdDnsDataTests, TestConvertDnsDataBadTxt) {
@@ -183,19 +185,19 @@ TEST(DnsSdDnsDataTests, TestConvertDnsDataBadTxt) {
// ApplyDataRecordChange tests.
TEST(DnsSdDnsDataTests, TestApplyRecordChanges) {
- MdnsRecord record = CreateFullyPopulatedRecord(port_num);
- InstanceKey instance{instance_name, service_name, domain_name};
+ MdnsRecord record = CreateFullyPopulatedRecord(kServicePort);
+ InstanceKey instance{kInstanceName, kServiceName, kDomainName};
DnsDataTesting data(instance);
EXPECT_TRUE(
data.ApplyDataRecordChange(record, RecordChangedEvent::kCreated).ok());
ASSERT_TRUE(data.srv().has_value());
- EXPECT_EQ(data.srv().value().port(), port_num);
+ EXPECT_EQ(data.srv().value().port(), kServicePort);
record = CreateFullyPopulatedRecord(234);
EXPECT_FALSE(
data.ApplyDataRecordChange(record, RecordChangedEvent::kCreated).ok());
ASSERT_TRUE(data.srv().has_value());
- EXPECT_EQ(data.srv().value().port(), port_num);
+ EXPECT_EQ(data.srv().value().port(), kServicePort);
record = CreateFullyPopulatedRecord(345);
EXPECT_TRUE(
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.cc
index 791faaaf7eb..5995768b2a1 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.cc
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.cc
@@ -8,25 +8,31 @@
#include "absl/strings/str_split.h"
#include "discovery/dnssd/impl/conversion_layer.h"
#include "discovery/dnssd/impl/service_key.h"
+#include "discovery/dnssd/public/dns_sd_instance_record.h"
#include "discovery/mdns/mdns_records.h"
#include "discovery/mdns/public/mdns_constants.h"
namespace openscreen {
namespace discovery {
-InstanceKey::InstanceKey(const MdnsRecord& record) {
- ErrorOr<InstanceKey> key = CreateFromRecord(record);
- OSP_DCHECK(key.is_value());
- *this = std::move(key.value());
+InstanceKey::InstanceKey(const MdnsRecord& record)
+ : InstanceKey(GetDomainName(record)) {}
+
+InstanceKey::InstanceKey(const DomainName& domain)
+ : ServiceKey(domain), instance_id_(domain.labels()[0]) {
+ OSP_DCHECK(IsInstanceValid(instance_id_));
}
+InstanceKey::InstanceKey(const DnsSdInstanceRecord& record)
+ : InstanceKey(record.instance_id(),
+ record.service_id(),
+ record.domain_id()) {}
+
InstanceKey::InstanceKey(absl::string_view instance,
absl::string_view service,
absl::string_view domain)
- : instance_id_(instance), service_id_(service), domain_id_(domain) {
+ : ServiceKey(service, domain), instance_id_(instance) {
OSP_DCHECK(IsInstanceValid(instance_id_));
- OSP_DCHECK(IsServiceValid(service_id_));
- OSP_DCHECK(IsDomainValid(domain_id_));
}
InstanceKey::InstanceKey(const InstanceKey& other) = default;
@@ -35,41 +41,5 @@ InstanceKey::InstanceKey(InstanceKey&& other) = default;
InstanceKey& InstanceKey::operator=(const InstanceKey& rhs) = default;
InstanceKey& InstanceKey::operator=(InstanceKey&& rhs) = default;
-bool InstanceKey::IsInstanceOf(const ServiceKey& service_key) const {
- return service_id_ == service_key.service_id() &&
- domain_id_ == service_key.domain_id();
-}
-
-ErrorOr<InstanceKey> InstanceKey::CreateFromRecord(const MdnsRecord& record) {
- const DomainName& names =
- IsPtrRecord(record)
- ? absl::get<PtrRecordRdata>(record.rdata()).ptr_domain()
- : record.name();
-
- if (names.labels().size() < 4) {
- return Error::Code::kParameterInvalid;
- }
-
- auto it = names.labels().begin();
- std::string instance_id = *it++;
- if (!IsInstanceValid(instance_id)) {
- return Error::Code::kParameterInvalid;
- }
-
- std::string service_name = *it++;
- std::string protocol = *it++;
- std::string service_id = service_name.append(".").append(protocol);
- if (!IsServiceValid(service_id)) {
- return Error::Code::kParameterInvalid;
- }
-
- std::string domain_id = absl::StrJoin(it, names.labels().end(), ".");
- if (!IsDomainValid(domain_id)) {
- return Error::Code::kParameterInvalid;
- }
-
- return InstanceKey(instance_id, service_id, domain_id);
-}
-
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.h b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.h
index 36f4731285f..311b6f72e8a 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.h
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.h
@@ -9,27 +9,32 @@
#include <utility>
#include "absl/strings/string_view.h"
-#include "platform/base/error.h"
+#include "discovery/dnssd/impl/service_key.h"
namespace openscreen {
namespace discovery {
+class DnsSdInstanceRecord;
+class DomainName;
class MdnsRecord;
class ServiceKey;
// This class is intended to be used as the key of a std::unordered_map or
// std::map when referencing data related to a specific service instance.
-class InstanceKey {
+class InstanceKey : public ServiceKey {
public:
// NOTE: The record provided must have valid instance, service, and domain
// labels.
explicit InstanceKey(const MdnsRecord& record);
+ explicit InstanceKey(const DomainName& domain);
+ explicit InstanceKey(const DnsSdInstanceRecord& record);
// NOTE: The provided parameters must be valid instance, service and domain
// ids.
InstanceKey(absl::string_view instance,
absl::string_view service,
absl::string_view domain);
+
InstanceKey(const InstanceKey& other);
InstanceKey(InstanceKey&& other);
@@ -37,43 +42,29 @@ class InstanceKey {
InstanceKey& operator=(InstanceKey&& rhs);
const std::string& instance_id() const { return instance_id_; }
- const std::string& service_id() const { return service_id_; }
- const std::string& domain_id() const { return domain_id_; }
-
- // Represents whether this InstanceKey is an instance of the service provided.
- bool IsInstanceOf(const ServiceKey& service_key) const;
private:
- static ErrorOr<InstanceKey> CreateFromRecord(const MdnsRecord& record);
-
std::string instance_id_;
- std::string service_id_;
- std::string domain_id_;
template <typename H>
friend H AbslHashValue(H h, const InstanceKey& key);
friend bool operator<(const InstanceKey& lhs, const InstanceKey& rhs);
-
- // Validation method which needs the same code as CreateFromRecord(). Use a
- // friend declaration to avoid duplicating this code while still keeping the
- // factory private.
- friend bool HasValidDnsRecordAddress(const MdnsRecord& record);
};
template <typename H>
H AbslHashValue(H h, const InstanceKey& key) {
- return H::combine(std::move(h), key.service_id_, key.domain_id_,
+ return H::combine(std::move(h), key.service_id(), key.domain_id(),
key.instance_id_);
}
inline bool operator<(const InstanceKey& lhs, const InstanceKey& rhs) {
- int comp = lhs.domain_id_.compare(rhs.domain_id_);
+ int comp = lhs.domain_id().compare(rhs.domain_id());
if (comp != 0) {
return comp < 0;
}
- comp = lhs.service_id_.compare(rhs.service_id_);
+ comp = lhs.service_id().compare(rhs.service_id());
if (comp != 0) {
return comp < 0;
}
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key_unittest.cc
index 25dd052bb4b..29e199817b1 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key_unittest.cc
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key_unittest.cc
@@ -45,20 +45,20 @@ TEST(DnsSdInstanceKeyTest, TestInstanceKeyEquals) {
TEST(DnsSdInstanceKeyTest, TestIsInstanceOf) {
ServiceKey ptr("_service._udp", "domain");
InstanceKey svc("instance", "_service._udp", "domain");
- EXPECT_TRUE(svc.IsInstanceOf(ptr));
+ EXPECT_EQ(svc, ptr);
svc = InstanceKey("other id", "_service._udp", "domain");
- EXPECT_TRUE(svc.IsInstanceOf(ptr));
+ EXPECT_EQ(svc, ptr);
svc = InstanceKey("instance", "_service._udp", "domain2");
- EXPECT_FALSE(svc.IsInstanceOf(ptr));
+ EXPECT_FALSE(svc == ptr);
ptr = ServiceKey("_service._udp", "domain2");
- EXPECT_TRUE(svc.IsInstanceOf(ptr));
+ EXPECT_EQ(svc, ptr);
svc = InstanceKey("instance", "_service2._udp", "domain");
- EXPECT_FALSE(svc.IsInstanceOf(ptr));
+ EXPECT_NE(svc, ptr);
ptr = ServiceKey("_service2._udp", "domain");
- EXPECT_TRUE(svc.IsInstanceOf(ptr));
+ EXPECT_EQ(svc, ptr);
}
TEST(DnsSdInstanceKeyTest, InstanceKeyInMap) {
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.cc
index c50c31d25d0..e8bdb9ea949 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.cc
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.cc
@@ -4,35 +4,219 @@
#include "discovery/dnssd/impl/publisher_impl.h"
+#include <map>
+#include <utility>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "discovery/common/reporting_client.h"
+#include "discovery/dnssd/impl/conversion_layer.h"
+#include "discovery/dnssd/impl/instance_key.h"
+#include "discovery/mdns/public/mdns_constants.h"
+#include "platform/api/task_runner.h"
#include "platform/base/error.h"
namespace openscreen {
namespace discovery {
+namespace {
-PublisherImpl::PublisherImpl(MdnsService* publisher)
- : mdns_publisher_(publisher) {}
+DnsSdInstanceRecord UpdateDomain(const DomainName& domain,
+ const DnsSdInstanceRecord& record) {
+ InstanceKey key(domain);
+ const IPEndpoint& v4 = record.address_v4();
+ const IPEndpoint& v6 = record.address_v6();
+ if (v4 && v6) {
+ return DnsSdInstanceRecord(key.instance_id(), key.service_id(),
+ key.domain_id(), v4, v6, record.txt());
+ } else {
+ const IPEndpoint& endpoint = v4 ? v4 : v6;
+ return DnsSdInstanceRecord(key.instance_id(), key.service_id(),
+ key.domain_id(), endpoint, record.txt());
+ }
+}
+
+template <typename T>
+inline typename std::map<DnsSdInstanceRecord, T>::iterator FindKey(
+ std::map<DnsSdInstanceRecord, T>* records,
+ const InstanceKey& key) {
+ return std::find_if(records->begin(), records->end(),
+ [&key](const std::pair<DnsSdInstanceRecord, T>& pair) {
+ return key == InstanceKey(pair.first);
+ });
+}
+
+template <typename T>
+int EraseRecordsWithServiceId(std::map<DnsSdInstanceRecord, T>* records,
+ const std::string& service_id) {
+ int removed_count = 0;
+ for (auto it = records->begin(); it != records->end();) {
+ if (it->first.service_id() == service_id) {
+ removed_count++;
+ it = records->erase(it);
+ } else {
+ it++;
+ }
+ }
+
+ return removed_count;
+}
+
+} // namespace
+
+PublisherImpl::PublisherImpl(MdnsService* publisher,
+ ReportingClient* reporting_client,
+ TaskRunner* task_runner)
+ : mdns_publisher_(publisher),
+ reporting_client_(reporting_client),
+ task_runner_(task_runner) {
+ OSP_DCHECK(mdns_publisher_);
+ OSP_DCHECK(reporting_client_);
+ OSP_DCHECK(task_runner_);
+}
PublisherImpl::~PublisherImpl() = default;
-Error PublisherImpl::Register(const DnsSdInstanceRecord& record) {
- if (std::find(published_records_.begin(), published_records_.end(), record) !=
- published_records_.end()) {
+Error PublisherImpl::Register(const DnsSdInstanceRecord& record,
+ Client* client) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+ OSP_DCHECK(client != nullptr);
+
+ if (published_records_.find(record) != published_records_.end()) {
return Error::Code::kItemAlreadyExists;
+ } else if (pending_records_.find(record) != pending_records_.end()) {
+ return Error::Code::kOperationInProgress;
}
- published_records_.push_back(record);
- for (const auto& mdns_record : GetDnsRecords(record)) {
- mdns_publisher_->RegisterRecord(mdns_record);
+ InstanceKey key(record);
+ IPEndpoint endpoint =
+ record.address_v4() ? record.address_v4() : record.address_v6();
+ pending_records_.emplace(record, client);
+
+ OSP_DVLOG << "Registering instance '" << record.instance_id() << "'";
+
+ return mdns_publisher_->StartProbe(this, GetDomainName(key),
+ endpoint.address);
+}
+
+Error PublisherImpl::UpdateRegistration(const DnsSdInstanceRecord& record) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ // Check if the record is still pending publication.
+ auto it = FindKey(&pending_records_, InstanceKey(record));
+
+ OSP_DVLOG << "Updating instance '" << record.instance_id() << "'";
+
+ // If it is a pending record, update it. Else, try to update a published
+ // record.
+ if (it != pending_records_.end()) {
+ // The instance, service, and domain ids have not changed, so only the
+ // remaining data needs to change. The ongoing probe does not need to be
+ // modified.
+ Client* const client = it->second;
+ pending_records_.erase(it);
+ pending_records_.emplace(record, client);
+ return Error::None();
+ } else {
+ return UpdatePublishedRegistration(record);
}
- return Error::None();
}
-size_t PublisherImpl::DeregisterAll(absl::string_view service) {
- size_t removed_count = 0;
+Error PublisherImpl::UpdatePublishedRegistration(
+ const DnsSdInstanceRecord& record) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ auto published_record_it = FindKey(&published_records_, InstanceKey(record));
+
+ // Check preconditions called out in header. Specifically, the updated record
+ // must be making changes to an already published record.
+ if (published_record_it == published_records_.end() ||
+ published_record_it->first == record) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ // Get all records which have changed. By design, there an only be one record
+ // of each DnsType, so use that here to simplify this step.
+ // First in each pair is the old records, second is the new record.
+ std::map<DnsType,
+ std::pair<absl::optional<MdnsRecord>, absl::optional<MdnsRecord>>>
+ changed_records;
+ const std::vector<MdnsRecord> old_records =
+ GetDnsRecords(published_record_it->second);
+ const DnsSdInstanceRecord updated_record = UpdateDomain(
+ GetDomainName(InstanceKey(published_record_it->second)), record);
+ const std::vector<MdnsRecord> new_records = GetDnsRecords(updated_record);
+
+ // Populate the first part of each pair in |changed_records|.
+ for (size_t i = 0; i < old_records.size(); i++) {
+ const auto key = old_records[i].dns_type();
+ OSP_DCHECK(changed_records.find(key) == changed_records.end());
+ auto value = std::make_pair(std::move(old_records[i]), absl::nullopt);
+ changed_records.emplace(key, std::move(value));
+ }
+
+ // Populate the second part of each pair in |changed_records|.
+ for (size_t i = 0; i < new_records.size(); i++) {
+ const auto key = new_records[i].dns_type();
+ auto find_it = changed_records.find(key);
+ if (find_it == changed_records.end()) {
+ std::pair<absl::optional<MdnsRecord>, absl::optional<MdnsRecord>> value(
+ absl::nullopt, std::move(new_records[i]));
+ changed_records.emplace(key, std::move(value));
+ } else {
+ find_it->second.second = std::move(new_records[i]);
+ }
+ }
+
+ // Apply changes called out in |changed_records|.
+ // TODO(crbug.com/openscreen/114): Trace each below call so multiple errors
+ // can be seen.
+ Error total_result = Error::None();
+ for (const auto& pair : changed_records) {
+ OSP_DCHECK(pair.second.first != absl::nullopt ||
+ pair.second.second != absl::nullopt);
+ if (pair.second.first == absl::nullopt) {
+ auto error = mdns_publisher_->RegisterRecord(pair.second.second.value());
+ if (!error.ok()) {
+ total_result = error;
+ }
+ } else if (pair.second.second == absl::nullopt) {
+ auto error = mdns_publisher_->UnregisterRecord(pair.second.first.value());
+ if (!error.ok()) {
+ total_result = error;
+ }
+ } else if (pair.second.first.value() != pair.second.second.value()) {
+ auto error = mdns_publisher_->UpdateRegisteredRecord(
+ pair.second.first.value(), pair.second.second.value());
+ if (!error.ok()) {
+ total_result = error;
+ }
+ }
+ }
+
+ // Replace the old records with the new ones.
+ published_records_.erase(published_record_it);
+ published_records_.emplace(record, std::move(updated_record));
+
+ return total_result;
+}
+
+ErrorOr<int> PublisherImpl::DeregisterAll(const std::string& service) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ OSP_DVLOG << "Deregistering all instances";
+
+ int removed_count = 0;
+
+ // TODO(crbug.com/openscreen/114): Trace each below call so multiple errors
+ // can be seen.
+ Error error = Error::None();
for (auto it = published_records_.begin(); it != published_records_.end();) {
- if (it->service_id() == service) {
- for (const auto& mdns_record : GetDnsRecords(*it)) {
- mdns_publisher_->DeregisterRecord(mdns_record);
+ if (it->second.service_id() == service) {
+ for (const auto& mdns_record : GetDnsRecords(it->second)) {
+ auto publisher_error = mdns_publisher_->UnregisterRecord(mdns_record);
+ if (!publisher_error.ok()) {
+ error = publisher_error;
+ }
}
removed_count++;
it = published_records_.erase(it);
@@ -41,7 +225,54 @@ size_t PublisherImpl::DeregisterAll(absl::string_view service) {
}
}
- return removed_count;
+ removed_count += EraseRecordsWithServiceId(&pending_records_, service);
+
+ if (!error.ok()) {
+ return error;
+ } else {
+ return removed_count;
+ }
+}
+
+void PublisherImpl::OnDomainFound(const DomainName& requested_name,
+ const DomainName& confirmed_name) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ OSP_DVLOG << "Domain successfully claimed: '" << confirmed_name.ToString()
+ << "' based on requested name: '" << requested_name.ToString()
+ << "'";
+
+ auto it = FindKey(&pending_records_, InstanceKey(requested_name));
+
+ if (it == pending_records_.end()) {
+ // This will be hit if the record was deregistered before the probe phase
+ // was completed.
+ return;
+ }
+
+ DnsSdInstanceRecord requested_record = std::move(it->first);
+ DnsSdInstanceRecord publication = requested_record;
+ Client* const client = it->second;
+ pending_records_.erase(it);
+
+ InstanceKey requested_key(requested_record);
+
+ if (requested_name != confirmed_name) {
+ OSP_DCHECK(HasValidDnsRecordAddress(confirmed_name));
+ publication = UpdateDomain(confirmed_name, requested_record);
+ }
+
+ for (const auto& mdns_record : GetDnsRecords(publication)) {
+ Error result = mdns_publisher_->RegisterRecord(mdns_record);
+ if (!result.ok()) {
+ reporting_client_->OnRecoverableError(
+ Error(Error::Code::kRecordPublicationError, result.ToString()));
+ }
+ }
+
+ auto pair = published_records_.emplace(std::move(requested_record),
+ std::move(publication));
+ client->OnInstanceClaimed(pair.first->first, pair.first->second);
}
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.h b/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.h
index 65d3235db75..5383bdde434 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.h
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.h
@@ -9,24 +9,46 @@
#include "discovery/dnssd/impl/conversion_layer.h"
#include "discovery/dnssd/public/dns_sd_instance_record.h"
#include "discovery/dnssd/public/dns_sd_publisher.h"
+#include "discovery/mdns/mdns_domain_confirmed_provider.h"
#include "discovery/mdns/public/mdns_service.h"
namespace openscreen {
namespace discovery {
-class PublisherImpl : public DnsSdPublisher {
+class ReportingClient;
+
+class PublisherImpl : public DnsSdPublisher,
+ public MdnsDomainConfirmedProvider {
public:
- PublisherImpl(MdnsService* publisher);
+ PublisherImpl(MdnsService* publisher,
+ ReportingClient* reporting_client,
+ TaskRunner* task_runner);
~PublisherImpl() override;
// DnsSdPublisher overrides.
- Error Register(const DnsSdInstanceRecord& record) override;
- size_t DeregisterAll(absl::string_view service) override;
+ Error Register(const DnsSdInstanceRecord& record, Client* client) override;
+ Error UpdateRegistration(const DnsSdInstanceRecord& record) override;
+ ErrorOr<int> DeregisterAll(const std::string& service) override;
private:
- std::vector<DnsSdInstanceRecord> published_records_;
+ Error UpdatePublishedRegistration(const DnsSdInstanceRecord& record);
+
+ // MdnsDomainConfirmedProvider overrides.
+ void OnDomainFound(const DomainName& requested_name,
+ const DomainName& confirmed_name) override;
+
+ // The set of records which will be published once the mDNS Probe phase
+ // completes.
+ std::map<DnsSdInstanceRecord, Client* const> pending_records_;
+
+ // Maps from the requested record to the record which was published after
+ // the mDNS Probe phase was completed. The only difference between these
+ // records should be the instance name.
+ std::map<DnsSdInstanceRecord, DnsSdInstanceRecord> published_records_;
MdnsService* const mdns_publisher_;
+ ReportingClient* const reporting_client_;
+ TaskRunner* const task_runner_;
friend class PublisherTesting;
};
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl_unittest.cc
index fd30c567198..70c4d45425a 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl_unittest.cc
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl_unittest.cc
@@ -4,16 +4,28 @@
#include "discovery/dnssd/impl/publisher_impl.h"
+#include <utility>
#include <vector>
+#include "discovery/common/testing/mock_reporting_client.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "platform/test/fake_clock.h"
+#include "platform/test/fake_task_runner.h"
namespace openscreen {
namespace discovery {
+namespace {
using testing::_;
using testing::Return;
+using testing::StrictMock;
+
+class MockClient : public DnsSdPublisher::Client {
+ public:
+ MOCK_METHOD2(OnInstanceClaimed,
+ void(const DnsSdInstanceRecord&, const DnsSdInstanceRecord&));
+};
class MockMdnsService : public MdnsService {
public:
@@ -21,40 +33,71 @@ class MockMdnsService : public MdnsService {
DnsType dns_type,
DnsClass dns_class,
MdnsRecordChangedCallback* callback) override {
- OSP_UNIMPLEMENTED();
+ FAIL();
}
void StopQuery(const DomainName& name,
DnsType dns_type,
DnsClass dns_class,
MdnsRecordChangedCallback* callback) override {
- OSP_UNIMPLEMENTED();
+ FAIL();
}
- MOCK_METHOD1(RegisterRecord, void(const MdnsRecord& record));
- MOCK_METHOD1(DeregisterRecord, void(const MdnsRecord& record));
+ void ReinitializeQueries(const DomainName& name) override { FAIL(); }
+
+ MOCK_METHOD3(StartProbe,
+ Error(MdnsDomainConfirmedProvider*, DomainName, IPAddress));
+ MOCK_METHOD2(UpdateRegisteredRecord,
+ Error(const MdnsRecord&, const MdnsRecord&));
+ MOCK_METHOD1(RegisterRecord, Error(const MdnsRecord& record));
+ MOCK_METHOD1(UnregisterRecord, Error(const MdnsRecord& record));
};
-class PublisherTesting : public PublisherImpl {
+class PublisherImplTest : public testing::Test {
public:
- PublisherTesting() : PublisherImpl(&mock_service_) {}
+ PublisherImplTest()
+ : clock_(Clock::now()),
+ task_runner_(&clock_),
+ publisher_(&mock_service_, &reporting_client_, &task_runner_) {}
+
+ MockMdnsService* mdns_service() { return &mock_service_; }
+ TaskRunner* task_runner() { return &task_runner_; }
+ PublisherImpl* publisher() { return &publisher_; }
- MockMdnsService& mdns_service() { return mock_service_; }
+ // Calls PublisherImpl::OnDomainFound() through the public interface it
+ // implements.
+ void CallOnDomainFound(const DomainName& domain, const DomainName& domain2) {
+ static_cast<MdnsDomainConfirmedProvider&>(publisher_)
+ .OnDomainFound(domain, domain2);
+ }
private:
- MockMdnsService mock_service_;
+ FakeClock clock_;
+ FakeTaskRunner task_runner_;
+ StrictMock<MockMdnsService> mock_service_;
+ StrictMock<MockReportingClient> reporting_client_;
+ PublisherImpl publisher_;
};
-TEST(DnsSdPublisherImplTests, TestRegisterAndDeregister) {
- PublisherTesting publisher;
- IPAddress address = IPAddress(std::array<uint8_t, 4>{{192, 168, 0, 0}});
- DnsSdInstanceRecord record("instance", "_service._udp", "domain",
- {address, 80}, {});
+TEST_F(PublisherImplTest, TestRegistrationAndDegrestration) {
+ IPAddress address = IPAddress(192, 168, 0, 0);
+ const DomainName domain{"instance", "_service", "_udp", "domain"};
+ const DomainName domain2{"instance2", "_service", "_udp", "domain"};
+ const DnsSdInstanceRecord record("instance", "_service._udp", "domain",
+ {address, 80}, {});
+ const DnsSdInstanceRecord record2("instance2", "_service._udp", "domain",
+ {address, 80}, {});
+ MockClient client;
+
+ EXPECT_CALL(*mdns_service(), StartProbe(publisher(), domain, _)).Times(1);
+ publisher()->Register(record, &client);
+ testing::Mock::VerifyAndClearExpectations(mdns_service());
int seen = 0;
- EXPECT_CALL(publisher.mdns_service(), RegisterRecord(_))
+ EXPECT_CALL(*mdns_service(), RegisterRecord(_))
.Times(4)
- .WillRepeatedly([&seen, &address](const MdnsRecord& record) mutable {
+ .WillRepeatedly([&seen, &address,
+ &domain2](const MdnsRecord& record) mutable -> Error {
if (record.dns_type() == DnsType::kA) {
const ARecordRdata& data = absl::get<ARecordRdata>(record.rdata());
if (data.ipv4_address() == address) {
@@ -66,15 +109,24 @@ TEST(DnsSdPublisherImplTests, TestRegisterAndDeregister) {
if (data.port() == 80) {
seen++;
}
- };
+ }
+
+ if (record.dns_type() != DnsType::kPTR) {
+ EXPECT_EQ(record.name(), domain2);
+ }
+ return Error::None();
});
- publisher.Register(record);
+ EXPECT_CALL(client, OnInstanceClaimed(record, record2));
+ CallOnDomainFound(domain, domain2);
EXPECT_EQ(seen, 2);
+ testing::Mock::VerifyAndClearExpectations(mdns_service());
+ testing::Mock::VerifyAndClearExpectations(&client);
seen = 0;
- EXPECT_CALL(publisher.mdns_service(), DeregisterRecord(_))
+ EXPECT_CALL(*mdns_service(), UnregisterRecord(_))
.Times(4)
- .WillRepeatedly([&seen, &address](const MdnsRecord& record) mutable {
+ .WillRepeatedly([&seen,
+ &address](const MdnsRecord& record) mutable -> Error {
if (record.dns_type() == DnsType::kA) {
const ARecordRdata& data = absl::get<ARecordRdata>(record.rdata());
if (data.ipv4_address() == address) {
@@ -86,11 +138,75 @@ TEST(DnsSdPublisherImplTests, TestRegisterAndDeregister) {
if (data.port() == 80) {
seen++;
}
- };
+ }
+ return Error::None();
});
- publisher.DeregisterAll("_service._udp");
+ publisher()->DeregisterAll("_service._udp");
EXPECT_EQ(seen, 2);
}
+TEST_F(PublisherImplTest, TestUpdate) {
+ IPAddress address = IPAddress(192, 168, 0, 0);
+ DomainName domain{"instance", "_service", "_udp", "domain"};
+ DnsSdTxtRecord txt;
+ txt.SetFlag("id", true);
+ DnsSdInstanceRecord record("instance", "_service._udp", "domain",
+ {address, 80}, std::move(txt));
+ MockClient client;
+
+ // Update a non-existent record
+ EXPECT_FALSE(publisher()->UpdateRegistration(record).ok());
+
+ // Update a record during the probing phase
+ EXPECT_CALL(*mdns_service(), StartProbe(publisher(), domain, _)).Times(1);
+ EXPECT_EQ(publisher()->Register(record, &client), Error::None());
+ testing::Mock::VerifyAndClearExpectations(mdns_service());
+
+ IPAddress address2 = IPAddress(1, 2, 3, 4, 5, 6, 7, 8);
+ DnsSdTxtRecord txt2;
+ txt2.SetFlag("id2", true);
+ DnsSdInstanceRecord record2("instance", "_service._udp", "domain",
+ {address2, 80}, std::move(txt2));
+ EXPECT_EQ(publisher()->UpdateRegistration(record2), Error::None());
+
+ bool seen_v6 = false;
+ EXPECT_CALL(*mdns_service(), RegisterRecord(_))
+ .Times(4)
+ .WillRepeatedly([&seen_v6](const MdnsRecord& record) mutable -> Error {
+ EXPECT_NE(record.dns_type(), DnsType::kA);
+ if (record.dns_type() == DnsType::kAAAA) {
+ seen_v6 = true;
+ }
+ return Error::None();
+ });
+ EXPECT_CALL(client, OnInstanceClaimed(record2, record2));
+ CallOnDomainFound(domain, domain);
+ EXPECT_TRUE(seen_v6);
+ testing::Mock::VerifyAndClearExpectations(mdns_service());
+ testing::Mock::VerifyAndClearExpectations(&client);
+
+ // Update a record once it has been published.
+ EXPECT_CALL(*mdns_service(), RegisterRecord(_))
+ .WillOnce([](const MdnsRecord& record) -> Error {
+ EXPECT_EQ(record.dns_type(), DnsType::kA);
+ return Error::None();
+ });
+ EXPECT_CALL(*mdns_service(), UnregisterRecord(_))
+ .WillOnce([](const MdnsRecord& record) -> Error {
+ EXPECT_EQ(record.dns_type(), DnsType::kAAAA);
+ return Error::None();
+ });
+ EXPECT_CALL(*mdns_service(), UpdateRegisteredRecord(_, _))
+ .WillOnce(
+ [](const MdnsRecord& record, const MdnsRecord& record2) -> Error {
+ EXPECT_EQ(record.dns_type(), DnsType::kTXT);
+ EXPECT_EQ(record2.dns_type(), DnsType::kTXT);
+ return Error::None();
+ });
+ EXPECT_EQ(publisher()->UpdateRegistration(record), Error::None());
+ testing::Mock::VerifyAndClearExpectations(mdns_service());
+}
+
+} // namespace
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.cc
index 79f899ecfa4..37f4d59d0d7 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.cc
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.cc
@@ -7,6 +7,7 @@
#include <string>
#include <vector>
+#include "platform/api/task_runner.h"
#include "util/logging.h"
namespace openscreen {
@@ -17,47 +18,50 @@ static constexpr char kLocalDomain[] = "local";
} // namespace
-QuerierImpl::QuerierImpl(MdnsService* mdns_querier)
- : mdns_querier_(mdns_querier) {
+QuerierImpl::QuerierImpl(MdnsService* mdns_querier, TaskRunner* task_runner)
+ : mdns_querier_(mdns_querier), task_runner_(task_runner) {
OSP_DCHECK(mdns_querier_);
+ OSP_DCHECK(task_runner_);
}
QuerierImpl::~QuerierImpl() = default;
-void QuerierImpl::StartQuery(absl::string_view service, Callback* callback) {
+void QuerierImpl::StartQuery(const std::string& service, Callback* callback) {
OSP_DCHECK(callback);
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ OSP_DVLOG << "Starting query for service '" << service << "'";
ServiceKey key(service, kLocalDomain);
- DnsQueryInfo query = GetPtrQueryInfo(key);
if (!IsQueryRunning(key)) {
- callback_map_[key] = {};
- StartDnsQuery(query);
+ callback_map_[key] = {callback};
+ StartDnsQuery(std::move(key));
} else {
- const std::vector<InstanceKey> keys = GetMatchingInstances(key);
- for (const auto& key : keys) {
- auto it = received_records_.find(key);
- if (it == received_records_.end()) {
- continue;
- }
-
- ErrorOr<DnsSdInstanceRecord> record = it->second.CreateRecord();
- if (record.is_value()) {
- callback->OnInstanceCreated(record.value());
+ callback_map_[key].push_back(callback);
+
+ for (auto& kvp : received_records_) {
+ if (kvp.first == key) {
+ ErrorOr<DnsSdInstanceRecord> record = kvp.second.CreateRecord();
+ if (record.is_value()) {
+ callback->OnInstanceCreated(record.value());
+ }
}
}
}
- callback_map_[key].push_back(callback);
}
-bool QuerierImpl::IsQueryRunning(absl::string_view service) const {
+bool QuerierImpl::IsQueryRunning(const std::string& service) const {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
return IsQueryRunning(ServiceKey(service, kLocalDomain));
}
-void QuerierImpl::StopQuery(absl::string_view service, Callback* callback) {
+void QuerierImpl::StopQuery(const std::string& service, Callback* callback) {
OSP_DCHECK(callback);
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ OSP_DVLOG << "Stopping query for service '" << service << "'";
ServiceKey key(service, kLocalDomain);
- DnsQueryInfo query = GetPtrQueryInfo(key);
auto callback_it = callback_map_.find(key);
if (callback_it == callback_map_.end()) {
return;
@@ -68,15 +72,42 @@ void QuerierImpl::StopQuery(absl::string_view service, Callback* callback) {
if (it != callbacks->end()) {
callbacks->erase(it);
if (callbacks->empty()) {
- EraseInstancesOf(key);
callback_map_.erase(callback_it);
- StopDnsQuery(query);
+ StopDnsQuery(std::move(key));
}
}
}
+void QuerierImpl::ReinitializeQueries(const std::string& service) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ OSP_DVLOG << "Re-initializing query for service '" << service << "'";
+
+ const ServiceKey key(service, kLocalDomain);
+
+ // Stop instance-specific queries and erase all instance data received so far.
+ std::vector<InstanceKey> keys_to_remove;
+ for (const auto& pair : received_records_) {
+ if (key == pair.first) {
+ keys_to_remove.push_back(pair.first);
+ }
+ }
+ for (InstanceKey& ik : keys_to_remove) {
+ StopDnsQuery(std::move(ik), false);
+ }
+
+ // Restart top-level queries.
+ mdns_querier_->ReinitializeQueries(GetPtrQueryInfo(key).name);
+}
+
void QuerierImpl::OnRecordChanged(const MdnsRecord& record,
RecordChangedEvent event) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ OSP_DVLOG << "Record with name '" << record.name().ToString()
+ << "' and type '" << record.dns_type()
+ << "' has received change of type '" << event << "'";
+
IsPtrRecord(record) ? HandlePtrRecordChange(record, event)
: HandleNonPtrRecordChange(record, event);
}
@@ -88,15 +119,12 @@ Error QuerierImpl::HandlePtrRecordChange(const MdnsRecord& record,
return Error::Code::kParameterInvalid;
}
- InstanceKey key(record);
- DnsQueryInfo query = GetInstanceQueryInfo(key);
switch (event) {
case RecordChangedEvent::kCreated:
- StartDnsQuery(query);
+ StartDnsQuery(InstanceKey(record));
return Error::None();
case RecordChangedEvent::kExpired:
- StopDnsQuery(query);
- EraseInstancesOf(ServiceKey(key));
+ StopDnsQuery(InstanceKey(record));
return Error::None();
case RecordChangedEvent::kUpdated:
return Error::Code::kOperationInvalid;
@@ -165,54 +193,65 @@ void QuerierImpl::NotifyCallbacks(
}
}
-void QuerierImpl::EraseInstancesOf(const ServiceKey& key) {
- std::vector<InstanceKey> keys = GetMatchingInstances(key);
- const auto it = callback_map_.find(key);
- std::vector<Callback*> callbacks;
- if (it != callback_map_.end()) {
- callbacks = it->second;
+void QuerierImpl::StartDnsQuery(InstanceKey key) {
+ auto pair = received_records_.emplace(key, DnsData(key));
+ if (!pair.second) {
+ // This means that a query is already ongoing.
+ return;
}
- for (const auto& key : keys) {
- auto recieved_record = received_records_.find(key);
- if (recieved_record == received_records_.end()) {
- continue;
- }
+ DnsQueryInfo query = GetInstanceQueryInfo(key);
+ mdns_querier_->StartQuery(query.name, query.dns_type, query.dns_class, this);
+}
+
+void QuerierImpl::StopDnsQuery(InstanceKey key, bool should_inform_callbacks) {
+ // If the instance is not being queried for, return.
+ auto record_it = received_records_.find(key);
+ if (record_it == received_records_.end()) {
+ return;
+ }
- ErrorOr<DnsSdInstanceRecord> instance_record =
- recieved_record->second.CreateRecord();
- if (instance_record.is_value()) {
- for (Callback* callback : callbacks) {
+ // If the instance has enough associated data that an instance was provided to
+ // the higher layer, call the deleted callback for all associated callbacks.
+ ErrorOr<DnsSdInstanceRecord> instance_record =
+ record_it->second.CreateRecord();
+ if (should_inform_callbacks && instance_record.is_value()) {
+ const auto it = callback_map_.find(key);
+ if (it != callback_map_.end()) {
+ for (Callback* callback : it->second) {
callback->OnInstanceDeleted(instance_record.value());
}
}
-
- received_records_.erase(recieved_record);
}
-}
-std::vector<InstanceKey> QuerierImpl::GetMatchingInstances(
- const ServiceKey& key) {
- // Because only one or two PTR queries are expected at a time, expect >=1/2 of
- // the records to be associated with a given PTR. They can't be removed in
- // less than O(n) time, so just iterate across them all.
- std::vector<InstanceKey> keys;
- for (auto it = received_records_.begin(); it != received_records_.end();
- it++) {
- if (it->first.IsInstanceOf(key)) {
- keys.push_back(it->first);
- }
- }
+ // Erase the key to mark the instance as no longer being queried for.
+ received_records_.erase(record_it);
- return keys;
+ // Call to the mDNS layer to stop the query.
+ DnsQueryInfo query = GetInstanceQueryInfo(key);
+ mdns_querier_->StopQuery(query.name, query.dns_type, query.dns_class, this);
}
-void QuerierImpl::StartDnsQuery(const DnsQueryInfo& query) {
+void QuerierImpl::StartDnsQuery(ServiceKey key) {
+ DnsQueryInfo query = GetPtrQueryInfo(key);
mdns_querier_->StartQuery(query.name, query.dns_type, query.dns_class, this);
}
-void QuerierImpl::StopDnsQuery(const DnsQueryInfo& query) {
+void QuerierImpl::StopDnsQuery(ServiceKey key) {
+ DnsQueryInfo query = GetPtrQueryInfo(key);
mdns_querier_->StopQuery(query.name, query.dns_type, query.dns_class, this);
+
+ // Stop any ongoing instance-specific queries.
+ std::vector<InstanceKey> keys_to_remove;
+ for (const auto& pair : received_records_) {
+ const bool key_is_service_from_query = (key == pair.first);
+ if (key_is_service_from_query) {
+ keys_to_remove.push_back(pair.first);
+ }
+ }
+ for (auto it = keys_to_remove.begin(); it != keys_to_remove.end(); it++) {
+ StopDnsQuery(std::move(*it));
+ }
}
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.h b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.h
index a39c91b58e9..f64fc1ada58 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.h
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.h
@@ -27,19 +27,19 @@ namespace discovery {
class QuerierImpl : public DnsSdQuerier, public MdnsRecordChangedCallback {
public:
- // |querier| must outlive the QuerierImpl instance constructed.
- explicit QuerierImpl(MdnsService* querier);
+ // |querier| and |task_runner| must outlive the QuerierImpl instance
+ // constructed.
+ QuerierImpl(MdnsService* querier, TaskRunner* task_runner);
~QuerierImpl() override;
- bool IsQueryRunning(absl::string_view service) const;
+ bool IsQueryRunning(const std::string& service) const;
// DnsSdQuerier overrides.
- void StartQuery(absl::string_view service, Callback* callback) override;
- void StopQuery(absl::string_view service, Callback* callback) override;
+ void StartQuery(const std::string& service, Callback* callback) override;
+ void StopQuery(const std::string& service, Callback* callback) override;
+ void ReinitializeQueries(const std::string& service) override;
// MdnsRecordChangedCallback overrides.
- // TODO(rwkeane): Ensure this is run on the TaskRunner thread once the
- // underlying mDNS implementation can be overridden.
void OnRecordChanged(const MdnsRecord& record,
RecordChangedEvent event) override;
@@ -58,12 +58,10 @@ class QuerierImpl : public DnsSdQuerier, public MdnsRecordChangedCallback {
}
// Initiates or terminates queries on the mdns_querier_ object.
- void StartDnsQuery(const DnsQueryInfo& query);
- void StopDnsQuery(const DnsQueryInfo& query);
-
- // Erases all instance records describing services matching the provided key
- // and informs all callbacks associated with the given key of their deletion.
- void EraseInstancesOf(const ServiceKey& service);
+ void StartDnsQuery(InstanceKey key);
+ void StartDnsQuery(ServiceKey key);
+ void StopDnsQuery(InstanceKey key, bool should_inform_callbacks = true);
+ void StopDnsQuery(ServiceKey key);
// Calls the appropriate callback method based on the provided Instance Record
// values.
@@ -71,12 +69,12 @@ class QuerierImpl : public DnsSdQuerier, public MdnsRecordChangedCallback {
const ErrorOr<DnsSdInstanceRecord>& old_record,
const ErrorOr<DnsSdInstanceRecord>& new_record);
- // Returns all InstanceKeys received so far which represent instances of
- // the service described by the provided ServiceKey.
- std::vector<InstanceKey> GetMatchingInstances(const ServiceKey& key);
-
// Map from a specific service instance to the data received so far about
- // that instance.
+ // that instance. The keys in this map are the instances for which an
+ // associated PTR record has been received, and the values are the set of
+ // non-PTR records received which describe that service (if any). Note that,
+ // with this definition, it is possible for a InstanceKey to be mapped to an
+ // empty DnsData if the instance has no associated records yet.
std::unordered_map<InstanceKey, DnsData, absl::Hash<InstanceKey>>
received_records_;
@@ -85,6 +83,7 @@ class QuerierImpl : public DnsSdQuerier, public MdnsRecordChangedCallback {
std::map<ServiceKey, std::vector<Callback*>> callback_map_;
MdnsService* const mdns_querier_;
+ TaskRunner* const task_runner_;
friend class QuerierImplTesting;
};
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl_unittest.cc
index 327eeccc8f3..9da8a4543dc 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl_unittest.cc
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl_unittest.cc
@@ -5,6 +5,8 @@
#include "discovery/dnssd/impl/querier_impl.h"
#include <memory>
+#include <string>
+#include <utility>
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
@@ -13,6 +15,8 @@
#include "discovery/mdns/testing/mdns_test_util.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "platform/test/fake_clock.h"
+#include "platform/test/fake_task_runner.h"
#include "util/logging.h"
namespace openscreen {
@@ -36,9 +40,15 @@ class MockMdnsService : public MdnsService {
StopQuery,
void(const DomainName&, DnsType, DnsClass, MdnsRecordChangedCallback*));
- void RegisterRecord(const MdnsRecord& record) override { FAIL(); }
+ MOCK_METHOD1(ReinitializeQueries, void(const DomainName& name));
- void DeregisterRecord(const MdnsRecord& record) override { FAIL(); }
+ // Unused.
+ MOCK_METHOD3(StartProbe,
+ Error(MdnsDomainConfirmedProvider*, DomainName, IPAddress));
+ MOCK_METHOD1(RegisterRecord, Error(const MdnsRecord&));
+ MOCK_METHOD1(UnregisterRecord, Error(const MdnsRecord&));
+ MOCK_METHOD2(UpdateRegisteredRecord,
+ Error(const MdnsRecord&, const MdnsRecord&));
};
SrvRecordRdata CreateSrvRecord() {
@@ -52,8 +62,8 @@ ARecordRdata CreateARecord() {
}
AAAARecordRdata CreateAAAARecord() {
- return AAAARecordRdata(
- IPAddress{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+ return AAAARecordRdata(IPAddress(0x0102, 0x0304, 0x0506, 0x0708, 0x090a,
+ 0x0b0c, 0x0d0e, 0x0f10));
}
MdnsRecord CreatePtrRecord(const std::string& instance,
@@ -94,7 +104,7 @@ using testing::StrictMock;
class DnsDataAccessor {
public:
- DnsDataAccessor(DnsData* data) : data_(data) {}
+ explicit DnsDataAccessor(DnsData* data) : data_(data) {}
void set_srv(absl::optional<SrvRecordRdata> record) { data_->srv_ = record; }
void set_txt(absl::optional<TxtRecordRdata> record) { data_->txt_ = record; }
@@ -116,7 +126,10 @@ class DnsDataAccessor {
class QuerierImplTesting : public QuerierImpl {
public:
- QuerierImplTesting() : QuerierImpl(&mock_service_) {}
+ QuerierImplTesting()
+ : QuerierImpl(&mock_service_, &task_runner_),
+ clock_(Clock::now()),
+ task_runner_(&clock_) {}
MockMdnsService* service() { return &mock_service_; }
@@ -140,6 +153,8 @@ class QuerierImplTesting : public QuerierImpl {
}
private:
+ FakeClock clock_;
+ FakeTaskRunner task_runner_;
StrictMock<MockMdnsService> mock_service_;
};
@@ -153,6 +168,7 @@ class DnsSdQuerierImplTest : public testing::Test {
.Times(1);
querier.StartQuery(service, &callback);
EXPECT_TRUE(querier.IsQueryRunning(service));
+ testing::Mock::VerifyAndClearExpectations(querier.service());
EXPECT_CALL(*querier.service(),
StartQuery(_, DnsType::kPTR, DnsClass::kANY, _))
@@ -201,6 +217,9 @@ TEST_F(DnsSdQuerierImplTest, TestStopQueryClearsRecords) {
EXPECT_CALL(*querier.service(),
StopQuery(_, DnsType::kPTR, DnsClass::kANY, _))
.Times(1);
+ EXPECT_CALL(*querier.service(),
+ StopQuery(_, DnsType::kANY, DnsClass::kANY, _))
+ .Times(1);
querier.StopQuery(service, &callback);
EXPECT_FALSE(querier.GetDnsData(instance, service, domain).has_value());
}
@@ -221,6 +240,7 @@ TEST_F(DnsSdQuerierImplTest, TestCreateDeletePtrRecord) {
StartQuery(_, DnsType::kANY, DnsClass::kANY, _))
.Times(1);
querier.OnRecordChanged(ptr, RecordChangedEvent::kCreated);
+ testing::Mock::VerifyAndClearExpectations(querier.service());
EXPECT_CALL(*querier.service(),
StopQuery(_, DnsType::kANY, DnsClass::kANY, _))
@@ -230,6 +250,12 @@ TEST_F(DnsSdQuerierImplTest, TestCreateDeletePtrRecord) {
TEST_F(DnsSdQuerierImplTest, CallbackCalledWhenPtrDeleted) {
auto ptr = CreatePtrRecord(instance, service, domain);
+ EXPECT_CALL(*querier.service(),
+ StartQuery(_, DnsType::kANY, DnsClass::kANY, _))
+ .Times(1);
+ querier.OnRecordChanged(ptr, RecordChangedEvent::kCreated);
+ testing::Mock::VerifyAndClearExpectations(querier.service());
+
DnsDataAccessor dns_data = querier.CreateDnsData(instance, service, domain);
dns_data.set_srv(CreateSrvRecord());
dns_data.set_txt(MakeTxtRecord({}));
@@ -237,11 +263,6 @@ TEST_F(DnsSdQuerierImplTest, CallbackCalledWhenPtrDeleted) {
dns_data.set_aaaa(CreateAAAARecord());
ASSERT_TRUE(dns_data.CanCreateInstance());
- EXPECT_CALL(*querier.service(),
- StartQuery(_, DnsType::kANY, DnsClass::kANY, _))
- .Times(1);
- querier.OnRecordChanged(ptr, RecordChangedEvent::kCreated);
-
EXPECT_CALL(callback, OnInstanceDeleted(_)).Times(1);
EXPECT_CALL(*querier.service(),
StopQuery(_, DnsType::kANY, DnsClass::kANY, _))
@@ -276,9 +297,11 @@ TEST_F(DnsSdQuerierImplTest, BothNewAndOldValidRecords) {
EXPECT_CALL(callback, OnInstanceUpdated(_)).Times(1);
querier.OnRecordChanged(a_record, RecordChangedEvent::kCreated);
+ testing::Mock::VerifyAndClearExpectations(&callback);
EXPECT_CALL(callback, OnInstanceUpdated(_)).Times(1);
querier.OnRecordChanged(a_record, RecordChangedEvent::kUpdated);
+ testing::Mock::VerifyAndClearExpectations(&callback);
auto aaaa_rdata = CreateAAAARecord();
MdnsRecord aaaa_record(kDomainName, DnsType::kAAAA, DnsClass::kIN,
@@ -287,9 +310,11 @@ TEST_F(DnsSdQuerierImplTest, BothNewAndOldValidRecords) {
EXPECT_CALL(callback, OnInstanceUpdated(_)).Times(1);
querier.OnRecordChanged(aaaa_record, RecordChangedEvent::kUpdated);
+ testing::Mock::VerifyAndClearExpectations(&callback);
EXPECT_CALL(callback, OnInstanceUpdated(_)).Times(1);
querier.OnRecordChanged(a_record, RecordChangedEvent::kExpired);
+ testing::Mock::VerifyAndClearExpectations(&callback);
}
TEST_F(DnsSdQuerierImplTest, OnlyNewRecordValid) {
@@ -321,5 +346,35 @@ TEST_F(DnsSdQuerierImplTest, OnlyOldRecordValid) {
querier.OnRecordChanged(a_record, RecordChangedEvent::kExpired);
}
+TEST_F(DnsSdQuerierImplTest, HardRefresh) {
+ const std::string service2 = "_service2._udp";
+
+ DnsDataAccessor dns_data = querier.CreateDnsData(instance, service, domain);
+ dns_data.set_srv(CreateSrvRecord());
+ dns_data.set_txt(MakeTxtRecord({}));
+ dns_data.set_a(CreateARecord());
+ dns_data.set_aaaa(CreateAAAARecord());
+ DnsDataAccessor dns_data2 = querier.CreateDnsData(instance, service2, domain);
+ dns_data2.set_srv(CreateSrvRecord());
+
+ EXPECT_CALL(callback, OnInstanceCreated(_)).Times(1);
+ querier.StartQuery(service, &callback);
+ EXPECT_TRUE(querier.IsQueryRunning(service));
+
+ const DomainName ptr_domain{"_service", "_udp", "local"};
+ const DomainName instance_domain{"instance", "_service", "_udp", "local"};
+ EXPECT_CALL(*querier.service(), ReinitializeQueries(ptr_domain));
+ EXPECT_CALL(*querier.service(), StopQuery(instance_domain, _, _, _));
+ querier.ReinitializeQueries(service);
+ testing::Mock::VerifyAndClearExpectations(querier.service());
+
+ absl::optional<DnsDataAccessor> data =
+ querier.GetDnsData(instance, service, domain);
+ EXPECT_EQ(data, absl::nullopt);
+ data = querier.GetDnsData(instance, service2, domain);
+ EXPECT_NE(data, absl::nullopt);
+ EXPECT_TRUE(querier.IsQueryRunning(service));
+}
+
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.cc
index e1caaabb049..ba8ce1518db 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.cc
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.cc
@@ -6,22 +6,60 @@
#include <utility>
+#include "discovery/common/config.h"
#include "discovery/mdns/public/mdns_service.h"
+#include "platform/api/task_runner.h"
namespace openscreen {
namespace discovery {
+namespace {
+
+MdnsService::SupportedNetworkAddressFamily GetSupportedEndpointTypes(
+ const InterfaceInfo& interface) {
+ MdnsService::SupportedNetworkAddressFamily supported_types =
+ MdnsService::kNoAddressFamily;
+ if (interface.GetIpAddressV4()) {
+ supported_types = supported_types | MdnsService::kUseIpV4Multicast;
+ }
+ if (interface.GetIpAddressV6()) {
+ supported_types = supported_types | MdnsService::kUseIpV6Multicast;
+ }
+ return supported_types;
+}
+
+} // namespace
// static
-std::unique_ptr<DnsSdService> DnsSdService::Create(TaskRunner* task_runner) {
- return std::make_unique<ServiceImpl>(task_runner);
+SerialDeletePtr<DnsSdService> CreateDnsSdService(
+ TaskRunner* task_runner,
+ ReportingClient* reporting_client,
+ const Config& config) {
+ return SerialDeletePtr<DnsSdService>(
+ task_runner, new ServiceImpl(task_runner, reporting_client, config));
}
-ServiceImpl::ServiceImpl(TaskRunner* task_runner)
- : mdns_service_(MdnsService::Create(task_runner)),
- querier_(mdns_service_.get()),
- publisher_(mdns_service_.get()) {}
+ServiceImpl::ServiceImpl(TaskRunner* task_runner,
+ ReportingClient* reporting_client,
+ const Config& config)
+ : task_runner_(task_runner),
+ mdns_service_(
+ MdnsService::Create(task_runner,
+ reporting_client,
+ config,
+ config.interface.index,
+ GetSupportedEndpointTypes(config.interface))) {
+ if (config.enable_querying) {
+ querier_ = std::make_unique<QuerierImpl>(mdns_service_.get(), task_runner_);
+ }
+ if (config.enable_publication) {
+ publisher_ = std::make_unique<PublisherImpl>(
+ mdns_service_.get(), reporting_client, task_runner_);
+ }
+}
-ServiceImpl::~ServiceImpl() = default;
+ServiceImpl::~ServiceImpl() {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+}
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.h b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.h
index 829cceb9284..3640407e46d 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.h
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_impl.h
@@ -10,32 +10,34 @@
#include "discovery/dnssd/impl/publisher_impl.h"
#include "discovery/dnssd/impl/querier_impl.h"
#include "discovery/dnssd/public/dns_sd_service.h"
+#include "platform/base/interface_info.h"
namespace openscreen {
-namespace platform {
class TaskRunner;
-}
-
namespace discovery {
class MdnsService;
class ServiceImpl final : public DnsSdService {
public:
- explicit ServiceImpl(TaskRunner* task_runner);
+ ServiceImpl(TaskRunner* task_runner,
+ ReportingClient* reporting_client,
+ const Config& config);
~ServiceImpl() override;
// DnsSdService overrides.
- DnsSdQuerier* Querier() override { return &querier_; }
- DnsSdPublisher* Publisher() override { return &publisher_; }
+ DnsSdQuerier* GetQuerier() override { return querier_.get(); }
+ DnsSdPublisher* GetPublisher() override { return publisher_.get(); }
private:
+ TaskRunner* const task_runner_;
+
std::unique_ptr<MdnsService> mdns_service_;
- QuerierImpl querier_;
- PublisherImpl publisher_;
+ std::unique_ptr<QuerierImpl> querier_;
+ std::unique_ptr<PublisherImpl> publisher_;
};
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.cc
index 5172ec4a9f9..e5af9d3ba89 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.cc
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.cc
@@ -7,7 +7,6 @@
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "discovery/dnssd/impl/conversion_layer.h"
-#include "discovery/dnssd/impl/instance_key.h"
#include "discovery/mdns/mdns_records.h"
#include "discovery/mdns/public/mdns_constants.h"
@@ -17,15 +16,18 @@ namespace discovery {
// The InstanceKey ctor used below cares about the Instance ID of the
// MdnsRecord, while this class doesn't, so it's possible that the InstanceKey
// ctor will fail for this reason. That's not a problem though - A failure when
-// creating an InstanceKey would mean that the record we recieved was invalid,
+// creating an InstanceKey would mean that the record we received was invalid,
// so there is no reason to continue processing.
-ServiceKey::ServiceKey(const MdnsRecord& record)
- : ServiceKey(InstanceKey(record)) {}
+ServiceKey::ServiceKey(const MdnsRecord& record) {
+ ErrorOr<ServiceKey> key = TryCreate(record);
+ OSP_DCHECK(key.is_value());
+ *this = std::move(key.value());
+}
-ServiceKey::ServiceKey(const InstanceKey& key)
- : ServiceKey(key.service_id(), key.domain_id()) {
- OSP_DCHECK(IsServiceValid(service_id_));
- OSP_DCHECK(IsDomainValid(domain_id_));
+ServiceKey::ServiceKey(const DomainName& domain) {
+ ErrorOr<ServiceKey> key = TryCreate(domain);
+ OSP_DCHECK(key.is_value());
+ *this = std::move(key.value());
}
ServiceKey::ServiceKey(absl::string_view service, absl::string_view domain)
@@ -41,5 +43,36 @@ ServiceKey::ServiceKey(ServiceKey&& other) = default;
ServiceKey& ServiceKey::operator=(const ServiceKey& rhs) = default;
ServiceKey& ServiceKey::operator=(ServiceKey&& rhs) = default;
+// static
+ErrorOr<ServiceKey> ServiceKey::TryCreate(const MdnsRecord& record) {
+ return TryCreate(GetDomainName(record));
+}
+
+// static
+ErrorOr<ServiceKey> ServiceKey::TryCreate(const DomainName& names) {
+ // Size must be at least 4, because the minimum valid label is of the form
+ // <instance>.<service type>.<protocol>.<domain>
+ if (names.labels().size() < 4) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ // Skip the InstanceId.
+ auto it = ++names.labels().begin();
+
+ std::string service_name = *it++;
+ const std::string protocol = *it++;
+ const std::string service_id = service_name.append(".").append(protocol);
+ if (!IsServiceValid(service_id)) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ const std::string domain_id = absl::StrJoin(it, names.labels().end(), ".");
+ if (!IsDomainValid(domain_id)) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ return ServiceKey(service_id, domain_id);
+}
+
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.h b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.h
index bdb5e0ee003..4e6ff7fae8d 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.h
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.h
@@ -9,24 +9,24 @@
#include <utility>
#include "absl/strings/string_view.h"
+#include "platform/base/error.h"
namespace openscreen {
namespace discovery {
-class InstanceKey;
+class DomainName;
class MdnsRecord;
// This class is intended to be used as the key of a std::unordered_map or a
// std::map when referencing data related to a service type
class ServiceKey {
public:
- // NOTE: The record provided must have valid service, domain, and instance
- // labels.
+ // NOTE: The record provided must have valid service domain labels.
explicit ServiceKey(const MdnsRecord& record);
+ explicit ServiceKey(const DomainName& domain);
// NOTE: The provided service and domain labels must be valid.
ServiceKey(absl::string_view service, absl::string_view domain);
- explicit ServiceKey(const InstanceKey& key);
ServiceKey(const ServiceKey& other);
ServiceKey(ServiceKey&& other);
@@ -37,6 +37,9 @@ class ServiceKey {
const std::string& domain_id() const { return domain_id_; }
private:
+ static ErrorOr<ServiceKey> TryCreate(const MdnsRecord& record);
+ static ErrorOr<ServiceKey> TryCreate(const DomainName& domain);
+
std::string service_id_;
std::string domain_id_;
@@ -44,6 +47,12 @@ class ServiceKey {
friend H AbslHashValue(H h, const ServiceKey& key);
friend bool operator<(const ServiceKey& lhs, const ServiceKey& rhs);
+
+ // Validation method which needs the same code as CreateFromRecord(). Use a
+ // friend declaration to avoid duplicating this code while still keeping the
+ // factory private.
+ friend bool HasValidDnsRecordAddress(const MdnsRecord& record);
+ friend bool HasValidDnsRecordAddress(const DomainName& domain);
};
// Hashing functions to allow for using with absl::Hash<...>.
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.cc b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.cc
index e0aaea9b2d4..90506a6bf2d 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.cc
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.cc
@@ -4,14 +4,15 @@
#include "discovery/dnssd/public/dns_sd_instance_record.h"
-#include "absl/strings/ascii.h"
+#include <cctype>
+
#include "util/logging.h"
namespace openscreen {
namespace discovery {
namespace {
-bool IsValidUtf8(absl::string_view string) {
+bool IsValidUtf8(const std::string& string) {
for (size_t i = 0; i < string.size(); i++) {
if (string[i] >> 5 == 0x06) { // 110xxxxx 10xxxxxx
if (i + 1 >= string.size() || string[++i] >> 6 != 0x02) {
@@ -34,7 +35,7 @@ bool IsValidUtf8(absl::string_view string) {
return true;
}
-bool HasControlCharacters(absl::string_view string) {
+bool HasControlCharacters(const std::string& string) {
for (auto ch : string) {
if ((ch >= 0x0 && ch <= 0x1F /* Ascii control characters */) ||
ch == 0x7F /* DEL character */) {
@@ -55,6 +56,7 @@ DnsSdInstanceRecord::DnsSdInstanceRecord(std::string instance_id,
std::move(service_id),
std::move(domain_id),
std::move(txt)) {
+ OSP_DCHECK(endpoint);
if (endpoint.address.IsV4()) {
address_v4_ = std::move(endpoint);
} else if (endpoint.address.IsV6()) {
@@ -74,6 +76,8 @@ DnsSdInstanceRecord::DnsSdInstanceRecord(std::string instance_id,
std::move(service_id),
std::move(domain_id),
std::move(txt)) {
+ OSP_CHECK(ipv4_endpoint);
+ OSP_CHECK(ipv6_endpoint);
OSP_CHECK(ipv4_endpoint.address.IsV4());
OSP_CHECK(ipv6_endpoint.address.IsV6());
@@ -94,18 +98,11 @@ DnsSdInstanceRecord::DnsSdInstanceRecord(std::string instance_id,
OSP_DCHECK(IsDomainValid(domain_id_));
}
-bool DnsSdInstanceRecord::operator==(const DnsSdInstanceRecord& other) const {
- return instance_id_ == other.instance_id_ &&
- service_id_ == other.service_id_ && domain_id_ == other.domain_id_ &&
- address_v4_ == other.address_v4_ && address_v6_ == other.address_v6_ &&
- txt_ == other.txt_;
-}
-
uint16_t DnsSdInstanceRecord::port() const {
- if (address_v4_.has_value()) {
- return address_v4_.value().port;
- } else if (address_v6_.has_value()) {
- return address_v6_.value().port;
+ if (address_v4_) {
+ return address_v4_.port;
+ } else if (address_v6_) {
+ return address_v6_.port;
} else {
OSP_NOTREACHED();
return 0;
@@ -113,7 +110,7 @@ uint16_t DnsSdInstanceRecord::port() const {
}
// static
-bool IsInstanceValid(absl::string_view instance) {
+bool IsInstanceValid(const std::string& instance) {
// According to RFC6763, Instance names must:
// - Be encoded in Net-Unicode (which required UTF-8 formatting).
// - NOT contain ASCII control characters
@@ -124,7 +121,7 @@ bool IsInstanceValid(absl::string_view instance) {
}
// static
-bool IsServiceValid(absl::string_view service) {
+bool IsServiceValid(const std::string& service) {
// According to RFC6763, the service name "consists of a pair of DNS labels".
// "The first label of the pair is an underscore character followed by the
// Service Name" and "The second label is either '_tcp' [...] or '_udp'".
@@ -138,7 +135,7 @@ bool IsServiceValid(absl::string_view service) {
return false;
}
- const absl::string_view protocol = service.substr(service.size() - 5);
+ const std::string protocol = service.substr(service.size() - 5);
if (protocol != "._udp" && protocol != "._tcp") {
return false;
}
@@ -156,9 +153,9 @@ bool IsServiceValid(absl::string_view service) {
return false;
}
last_char_hyphen = true;
- } else if (absl::ascii_isalpha(service[i])) {
+ } else if (std::isalpha(service[i])) {
seen_letter = true;
- } else if (!absl::ascii_isdigit(service[i])) {
+ } else if (!std::isdigit(service[i])) {
return false;
}
}
@@ -167,7 +164,7 @@ bool IsServiceValid(absl::string_view service) {
}
// static
-bool IsDomainValid(absl::string_view domain) {
+bool IsDomainValid(const std::string& domain) {
// As RFC6763 Section 4.1.3 provides no validation requirements for the domain
// section, the following validations are used:
// - All labels must be no longer than 63 characters
@@ -191,5 +188,32 @@ bool IsDomainValid(absl::string_view domain) {
return !HasControlCharacters(domain) && IsValidUtf8(domain);
}
+bool operator<(const DnsSdInstanceRecord& lhs, const DnsSdInstanceRecord& rhs) {
+ int comp = lhs.instance_id_.compare(rhs.instance_id_);
+ if (comp != 0) {
+ return comp < 0;
+ }
+
+ comp = lhs.service_id_.compare(rhs.service_id_);
+ if (comp != 0) {
+ return comp < 0;
+ }
+
+ comp = lhs.domain_id_.compare(rhs.domain_id_);
+ if (comp != 0) {
+ return comp < 0;
+ }
+
+ if (lhs.address_v4_ != rhs.address_v4_) {
+ return lhs.address_v4_ < rhs.address_v4_;
+ }
+
+ if (lhs.address_v6_ != rhs.address_v6_) {
+ return lhs.address_v6_ < rhs.address_v6_;
+ }
+
+ return lhs.txt_ < rhs.txt_;
+}
+
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.h b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.h
index 5461165fdde..9dbb8ffa6ac 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.h
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_record.h
@@ -5,17 +5,15 @@
#ifndef DISCOVERY_DNSSD_PUBLIC_DNS_SD_INSTANCE_RECORD_H_
#define DISCOVERY_DNSSD_PUBLIC_DNS_SD_INSTANCE_RECORD_H_
-#include "absl/strings/string_view.h"
-#include "absl/types/optional.h"
#include "discovery/dnssd/public/dns_sd_txt_record.h"
#include "platform/base/ip_address.h"
namespace openscreen {
namespace discovery {
-bool IsInstanceValid(absl::string_view instance);
-bool IsServiceValid(absl::string_view service);
-bool IsDomainValid(absl::string_view domain);
+bool IsInstanceValid(const std::string& instance);
+bool IsServiceValid(const std::string& service);
+bool IsDomainValid(const std::string& domain);
// Represents the data stored in DNS records of types SRV, TXT, A, and AAAA
class DnsSdInstanceRecord {
@@ -45,10 +43,10 @@ class DnsSdInstanceRecord {
// Returns the domain id for this DNS-SD record.
const std::string& domain_id() const { return domain_id_; }
- // Returns the addess associated with this DNS-SD record. In any valid record,
- // at least one will be set.
- const absl::optional<IPEndpoint>& address_v4() const { return address_v4_; }
- const absl::optional<IPEndpoint>& address_v6() const { return address_v6_; }
+ // Returns the address associated with this DNS-SD record. In any valid
+ // record, at least one will be set.
+ const IPEndpoint& address_v4() const { return address_v4_; }
+ const IPEndpoint& address_v6() const { return address_v6_; }
// Returns the TXT record associated with this DNS-SD record
const DnsSdTxtRecord& txt() const { return txt_; }
@@ -56,12 +54,6 @@ class DnsSdInstanceRecord {
// Returns the port associated with this instance record.
uint16_t port() const;
- bool operator==(const DnsSdInstanceRecord& other) const;
-
- inline bool operator!=(const DnsSdInstanceRecord& other) const {
- return !(*this == other);
- }
-
private:
DnsSdInstanceRecord(std::string instance_id,
std::string service_id,
@@ -71,11 +63,41 @@ class DnsSdInstanceRecord {
std::string instance_id_;
std::string service_id_;
std::string domain_id_;
- absl::optional<IPEndpoint> address_v4_;
- absl::optional<IPEndpoint> address_v6_;
+ IPEndpoint address_v4_;
+ IPEndpoint address_v6_;
DnsSdTxtRecord txt_;
+
+ friend bool operator<(const DnsSdInstanceRecord& lhs,
+ const DnsSdInstanceRecord& rhs);
};
+bool operator<(const DnsSdInstanceRecord& lhs, const DnsSdInstanceRecord& rhs);
+
+inline bool operator>(const DnsSdInstanceRecord& lhs,
+ const DnsSdInstanceRecord& rhs) {
+ return rhs < lhs;
+}
+
+inline bool operator<=(const DnsSdInstanceRecord& lhs,
+ const DnsSdInstanceRecord& rhs) {
+ return !(rhs > lhs);
+}
+
+inline bool operator>=(const DnsSdInstanceRecord& lhs,
+ const DnsSdInstanceRecord& rhs) {
+ return !(rhs < lhs);
+}
+
+inline bool operator==(const DnsSdInstanceRecord& lhs,
+ const DnsSdInstanceRecord& rhs) {
+ return lhs <= rhs && lhs >= rhs;
+}
+
+inline bool operator!=(const DnsSdInstanceRecord& lhs,
+ const DnsSdInstanceRecord& rhs) {
+ return !(lhs == rhs);
+}
+
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_publisher.h b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_publisher.h
index 638a30f782d..d271b684c7c 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_publisher.h
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_publisher.h
@@ -5,7 +5,6 @@
#ifndef DISCOVERY_DNSSD_PUBLIC_DNS_SD_PUBLISHER_H_
#define DISCOVERY_DNSSD_PUBLIC_DNS_SD_PUBLISHER_H_
-#include "absl/strings/string_view.h"
#include "discovery/dnssd/public/dns_sd_instance_record.h"
#include "platform/base/error.h"
@@ -14,6 +13,20 @@ namespace discovery {
class DnsSdPublisher {
public:
+ class Client {
+ public:
+ virtual ~Client() = default;
+
+ // Callback called when a record is successfully claimed and published via
+ // the Register() method. These records are expected to only differ in
+ // the DnsSdInstanceRecord::instance_id() field, or to be equal. This
+ // callback is purely for informational purposes and the caller is not
+ // required to act on it.
+ virtual void OnInstanceClaimed(
+ const DnsSdInstanceRecord& requested_record,
+ const DnsSdInstanceRecord& claimed_record) = 0;
+ };
+
virtual ~DnsSdPublisher() = default;
// Publishes the PTR, SRV, TXT, A, and AAAA records provided in the
@@ -22,12 +35,21 @@ class DnsSdPublisher {
// NOTE: Some embedders may return errors on other conditions (for instance,
// android will return an error if the resulting TXT record has values not
// encodable with UTF8).
- virtual Error Register(const DnsSdInstanceRecord& record) = 0;
+ virtual Error Register(const DnsSdInstanceRecord& record, Client* client) = 0;
+
+ // Updates the TXT, A, and AAAA records associated with the provided record,
+ // if any changes have occurred. The instance and domain names must match
+ // those of a previously published record. If either this is not true, no
+ // changes have occurred, or additional embedder-specific requirements have
+ // been violated, an error is returned. Else, Error::None is returned.
+ virtual Error UpdateRegistration(const DnsSdInstanceRecord& record) = 0;
// Unpublishes any PTR, SRV, TXT, A, and AAAA records associated with this
- // service id. If no such records are published, this operation will be a
- // no-op. Returns the number of records which were removed.
- virtual size_t DeregisterAll(absl::string_view service) = 0;
+ // service id, where the service id is the second part of the
+ // <instance>.<service>.<domain> domain name as described in RFC 6763. If no
+ // such records are published, this operation will be a no-op. Returns the
+ // number of records which were removed, or an error code on error.
+ virtual ErrorOr<int> DeregisterAll(const std::string& service) = 0;
};
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_querier.h b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_querier.h
index 5a35f62d2d4..083d17d53cd 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_querier.h
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_querier.h
@@ -5,7 +5,6 @@
#ifndef DISCOVERY_DNSSD_PUBLIC_DNS_SD_QUERIER_H_
#define DISCOVERY_DNSSD_PUBLIC_DNS_SD_QUERIER_H_
-#include "absl/strings/string_view.h"
#include "discovery/dnssd/public/dns_sd_instance_record.h"
namespace openscreen {
@@ -39,12 +38,17 @@ class DnsSdQuerier {
// NOTE: The provided service value is expected to be valid, as defined by the
// IsServiceValid() method.
// NOTE: The callback must be called on the TaskRunner thread.
- virtual void StartQuery(absl::string_view service, Callback* cb) = 0;
+ virtual void StartQuery(const std::string& service, Callback* cb) = 0;
// Stops an already running query.
// NOTE: The provided service value is expected to be valid, as defined by the
// IsServiceValid() method.
- virtual void StopQuery(absl::string_view service, Callback* cb) = 0;
+ virtual void StopQuery(const std::string& service, Callback* cb) = 0;
+
+ // Re-initializes the process of service discovery for the provided service
+ // id. All ongoing queries for this domain are restarted and any previously
+ // received query results are discarded.
+ virtual void ReinitializeQueries(const std::string& service) = 0;
};
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_service.h b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_service.h
index 65cc83bef18..f0fe130bf88 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_service.h
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_service.h
@@ -5,8 +5,14 @@
#ifndef DISCOVERY_DNSSD_PUBLIC_DNS_SD_SERVICE_H_
#define DISCOVERY_DNSSD_PUBLIC_DNS_SD_SERVICE_H_
+#include <functional>
#include <memory>
+#include "platform/base/error.h"
+#include "platform/base/interface_info.h"
+#include "platform/base/ip_address.h"
+#include "util/serial_delete_ptr.h"
+
namespace openscreen {
struct IPEndpoint;
@@ -14,8 +20,10 @@ class TaskRunner;
namespace discovery {
+struct Config;
class DnsSdPublisher;
class DnsSdQuerier;
+class ReportingClient;
// This class provides a wrapper around DnsSdQuerier and DnsSdPublisher to
// allow for an embedder-overridable factory method below.
@@ -23,17 +31,13 @@ class DnsSdService {
public:
virtual ~DnsSdService() = default;
- // Creates a new DnsSdService instance, to be owned by the caller. On failure,
- // return nullptr.
- static std::unique_ptr<DnsSdService> Create(TaskRunner* task_runner);
-
// Returns the DnsSdQuerier owned by this DnsSdService. If queries are not
// supported, returns nullptr.
- virtual DnsSdQuerier* Querier() = 0;
+ virtual DnsSdQuerier* GetQuerier() = 0;
// Returns the DnsSdPublisher owned by this DnsSdService. If publishing is not
// supported, returns nullptr.
- virtual DnsSdPublisher* Publisher() = 0;
+ virtual DnsSdPublisher* GetPublisher() = 0;
};
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.cc b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.cc
index 546bfb7386a..c981a119537 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.cc
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.cc
@@ -4,22 +4,53 @@
#include "discovery/dnssd/public/dns_sd_txt_record.h"
-#include "absl/strings/ascii.h"
+#include <cctype>
namespace openscreen {
namespace discovery {
+// static
+bool DnsSdTxtRecord::IsValidTxtValue(const std::string& key,
+ const std::vector<uint8_t>& value) {
+ // The max length of any individual TXT record is 255 bytes.
+ if (key.size() + value.size() + 1 /* for equals */ > 255) {
+ return false;
+ }
+
+ return IsKeyValid(key);
+}
+
+// static
+bool DnsSdTxtRecord::IsValidTxtValue(const std::string& key, uint8_t value) {
+ return IsValidTxtValue(key, std::vector<uint8_t>{value});
+}
+
+// static
+bool DnsSdTxtRecord::IsValidTxtValue(const std::string& key,
+ const std::string& value) {
+ return IsValidTxtValue(key, std::vector<uint8_t>(value.begin(), value.end()));
+}
+
Error DnsSdTxtRecord::SetValue(const std::string& key,
- const absl::Span<const uint8_t>& value) {
- if (!IsKeyValuePairValid(key, value)) {
+ std::vector<uint8_t> value) {
+ if (!IsValidTxtValue(key, value)) {
return Error::Code::kParameterInvalid;
}
- key_value_txt_[key] = std::vector<uint8_t>(value.begin(), value.end());
+ key_value_txt_[key] = std::move(value);
ClearFlag(key);
return Error::None();
}
+Error DnsSdTxtRecord::SetValue(const std::string& key, uint8_t value) {
+ return SetValue(key, std::vector<uint8_t>{value});
+}
+
+Error DnsSdTxtRecord::SetValue(const std::string& key,
+ const std::string& value) {
+ return SetValue(key, std::vector<uint8_t>(value.begin(), value.end()));
+}
+
Error DnsSdTxtRecord::SetFlag(const std::string& key, bool value) {
if (!IsKeyValid(key)) {
return Error::Code::kParameterInvalid;
@@ -34,7 +65,7 @@ Error DnsSdTxtRecord::SetFlag(const std::string& key, bool value) {
return Error::None();
}
-ErrorOr<absl::Span<const uint8_t>> DnsSdTxtRecord::GetValue(
+ErrorOr<DnsSdTxtRecord::ValueRef> DnsSdTxtRecord::GetValue(
const std::string& key) const {
if (!IsKeyValid(key)) {
return Error::Code::kParameterInvalid;
@@ -42,7 +73,7 @@ ErrorOr<absl::Span<const uint8_t>> DnsSdTxtRecord::GetValue(
auto it = key_value_txt_.find(key);
if (it != key_value_txt_.end()) {
- return absl::Span<const uint8_t>(it->second.data(), it->second.size());
+ return std::cref(it->second);
}
return Error::Code::kItemNotFound;
@@ -74,7 +105,8 @@ Error DnsSdTxtRecord::ClearFlag(const std::string& key) {
return Error::None();
}
-bool DnsSdTxtRecord::IsKeyValid(const std::string& key) const {
+// static
+bool DnsSdTxtRecord::IsKeyValid(const std::string& key) {
// The max length of any individual TXT record is 255 bytes.
if (key.size() > 255) {
return false;
@@ -115,22 +147,23 @@ std::vector<std::vector<uint8_t>> DnsSdTxtRecord::GetData() const {
return data;
}
-bool DnsSdTxtRecord::IsKeyValuePairValid(
- const std::string& key,
- const absl::Span<const uint8_t>& value) const {
- // The max length of any individual TXT record is 255 bytes.
- if (key.size() + value.size() + 1 /* for equals */ > 255) {
- return false;
- }
-
- return IsKeyValid(key);
-}
-
bool DnsSdTxtRecord::CaseInsensitiveComparison::operator()(
const std::string& lhs,
const std::string& rhs) const {
- return std::less<std::string>()(absl::AsciiStrToLower(lhs),
- absl::AsciiStrToLower(rhs));
+ if (lhs.size() != rhs.size()) {
+ return lhs < rhs;
+ }
+
+ for (size_t i = 0; i < lhs.size(); i++) {
+ int lhs_char = tolower(lhs[i]);
+ int rhs_char = tolower(rhs[i]);
+
+ if (lhs_char != rhs_char) {
+ return lhs_char < rhs_char;
+ }
+ }
+
+ return false;
}
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.h b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.h
index 0fcfdf25878..b589d68bc35 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.h
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record.h
@@ -5,14 +5,12 @@
#ifndef DISCOVERY_DNSSD_PUBLIC_DNS_SD_TXT_RECORD_H_
#define DISCOVERY_DNSSD_PUBLIC_DNS_SD_TXT_RECORD_H_
+#include <functional>
#include <map>
#include <set>
#include <string>
#include <vector>
-#include "absl/strings/string_view.h"
-#include "absl/types/optional.h"
-#include "absl/types/span.h"
#include "platform/base/error.h"
namespace openscreen {
@@ -20,14 +18,23 @@ namespace discovery {
class DnsSdTxtRecord {
public:
+ using ValueRef = std::reference_wrapper<const std::vector<uint8_t>>;
+
+ // Returns whether the provided key value pair is valid for a TXT record.
+ static bool IsValidTxtValue(const std::string& key,
+ const std::vector<uint8_t>& value);
+ static bool IsValidTxtValue(const std::string& key, const std::string& value);
+ static bool IsValidTxtValue(const std::string& key, uint8_t value);
+
// Sets the value currently stored in this DNS-SD TXT record. Returns error
// if the provided key is already set or if either the key or value is
// invalid, and Error::None() otherwise. Keys are case-insensitive. Setting a
// value or flag which was already set will overwrite the previous one, and
// setting a value with a key which was previously associated with a flag
// erases the flag's value and vice versa.
- Error SetValue(const std::string& key,
- const absl::Span<const uint8_t>& value);
+ Error SetValue(const std::string& key, std::vector<uint8_t> value);
+ Error SetValue(const std::string& key, uint8_t value);
+ Error SetValue(const std::string& key, const std::string& value);
Error SetFlag(const std::string& key, bool value);
// Reads the value associated with the provided key, or an error if the key
@@ -36,7 +43,7 @@ class DnsSdTxtRecord {
// NOTE: If GetValue is called on a key assigned to a flag, an ItemNotFound
// error will be returned. If GetFlag is called on a key assigned to a value,
// 'false' will be returned.
- ErrorOr<absl::Span<const uint8_t>> GetValue(const std::string& key) const;
+ ErrorOr<ValueRef> GetValue(const std::string& key) const;
ErrorOr<bool> GetFlag(const std::string& key) const;
// Clears an existing TxtRecord value associated with the given key. If the
@@ -58,24 +65,13 @@ class DnsSdTxtRecord {
// quotes).
std::vector<std::vector<uint8_t>> GetData() const;
- inline bool operator==(const DnsSdTxtRecord& other) const {
- return key_value_txt_ == other.key_value_txt_ &&
- boolean_txt_ == other.boolean_txt_;
- }
-
- inline bool operator!=(const DnsSdTxtRecord& other) const {
- return !(*this == other);
- }
-
private:
struct CaseInsensitiveComparison {
bool operator()(const std::string& lhs, const std::string& rhs) const;
};
// Validations for keys and (key, value) pairs.
- bool IsKeyValid(const std::string& key) const;
- bool IsKeyValuePairValid(const std::string& key,
- const absl::Span<const uint8_t>& value) const;
+ static bool IsKeyValid(const std::string& key);
// Set of (key, value) pairs associated with this TXT record.
// NOTE: The same string name can only occur in one of key_value_txt_,
@@ -88,8 +84,38 @@ class DnsSdTxtRecord {
// NOTE: The same string name can only occur in one of key_value_txt_,
// boolean_txt_.
std::set<std::string, CaseInsensitiveComparison> boolean_txt_;
+
+ friend bool operator<(const DnsSdTxtRecord& lhs, const DnsSdTxtRecord& rhs);
};
+inline bool operator<(const DnsSdTxtRecord& lhs, const DnsSdTxtRecord& rhs) {
+ if (lhs.boolean_txt_ != rhs.boolean_txt_) {
+ return lhs.boolean_txt_ < rhs.boolean_txt_;
+ }
+
+ return lhs.key_value_txt_ < rhs.key_value_txt_;
+}
+
+inline bool operator>(const DnsSdTxtRecord& lhs, const DnsSdTxtRecord& rhs) {
+ return rhs < lhs;
+}
+
+inline bool operator<=(const DnsSdTxtRecord& lhs, const DnsSdTxtRecord& rhs) {
+ return !(rhs > lhs);
+}
+
+inline bool operator>=(const DnsSdTxtRecord& lhs, const DnsSdTxtRecord& rhs) {
+ return !(rhs < lhs);
+}
+
+inline bool operator==(const DnsSdTxtRecord& lhs, const DnsSdTxtRecord& rhs) {
+ return lhs <= rhs && lhs >= rhs;
+}
+
+inline bool operator!=(const DnsSdTxtRecord& lhs, const DnsSdTxtRecord& rhs) {
+ return !(lhs == rhs);
+}
+
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record_unittest.cc
index 4951bc19f42..f3239ca3d00 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record_unittest.cc
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_txt_record_unittest.cc
@@ -13,43 +13,64 @@ namespace dnssd {
TEST(TxtRecordTest, TestCaseInsensitivity) {
DnsSdTxtRecord txt;
- uint8_t data[]{'a', 'b', 'c'};
+ std::vector<uint8_t> data{'a', 'b', 'c'};
EXPECT_TRUE(txt.SetValue("key", data).ok());
EXPECT_TRUE(txt.GetValue("KEY").is_value());
EXPECT_TRUE(txt.SetFlag("KEY2", true).ok());
- EXPECT_TRUE(txt.GetFlag("key2").is_value());
+ ASSERT_TRUE(txt.GetFlag("key2").is_value());
EXPECT_TRUE(txt.GetFlag("key2").value());
}
TEST(TxtRecordTest, TestEmptyValue) {
DnsSdTxtRecord txt;
- EXPECT_TRUE(txt.SetValue("key", {}).ok());
- EXPECT_TRUE(txt.GetValue("key").is_value());
- EXPECT_EQ(txt.GetValue("key").value().size(), size_t{0});
+ EXPECT_TRUE(txt.SetValue("key", std::vector<uint8_t>{}).ok());
+ ASSERT_TRUE(txt.GetValue("key").is_value());
+ EXPECT_EQ(txt.GetValue("key").value().get().size(), size_t{0});
+
+ EXPECT_TRUE(txt.SetValue("key2", "").ok());
+ ASSERT_TRUE(txt.GetValue("key2").is_value());
+ EXPECT_EQ(txt.GetValue("key2").value().get().size(), size_t{0});
}
TEST(TxtRecordTest, TestSetAndGetValue) {
DnsSdTxtRecord txt;
- uint8_t data[]{'a', 'b', 'c'};
+ std::vector<uint8_t> data{'a', 'b', 'c'};
EXPECT_TRUE(txt.SetValue("key", data).ok());
- EXPECT_TRUE(txt.GetValue("key").is_value());
- EXPECT_EQ(txt.GetValue("key").value().size(), size_t{3});
- EXPECT_EQ(txt.GetValue("key").value()[0], 'a');
- EXPECT_EQ(txt.GetValue("key").value()[1], 'b');
- EXPECT_EQ(txt.GetValue("key").value()[2], 'c');
-
- uint8_t data2[]{'a', 'b'};
+ ASSERT_TRUE(txt.GetValue("key").is_value());
+ const std::vector<uint8_t>& value = txt.GetValue("key").value();
+ ASSERT_EQ(value.size(), size_t{3});
+ EXPECT_EQ(value[0], 'a');
+ EXPECT_EQ(value[1], 'b');
+ EXPECT_EQ(value[2], 'c');
+
+ std::vector<uint8_t> data2{'a', 'b'};
EXPECT_TRUE(txt.SetValue("key", data2).ok());
- EXPECT_TRUE(txt.GetValue("key").is_value());
- EXPECT_EQ(txt.GetValue("key").value().size(), size_t{2});
- EXPECT_EQ(txt.GetValue("key").value()[0], 'a');
- EXPECT_EQ(txt.GetValue("key").value()[1], 'b');
+ ASSERT_TRUE(txt.GetValue("key").is_value());
+ const std::vector<uint8_t>& value2 = txt.GetValue("key").value();
+ EXPECT_EQ(value2.size(), size_t{2});
+ EXPECT_EQ(value2[0], 'a');
+ EXPECT_EQ(value2[1], 'b');
+
+ EXPECT_TRUE(txt.SetValue("key", "abc").ok());
+ ASSERT_TRUE(txt.GetValue("key").is_value());
+ const std::vector<uint8_t>& value3 = txt.GetValue("key").value();
+ ASSERT_EQ(value.size(), size_t{3});
+ EXPECT_EQ(value3[0], 'a');
+ EXPECT_EQ(value3[1], 'b');
+ EXPECT_EQ(value3[2], 'c');
+
+ EXPECT_TRUE(txt.SetValue("key", "ab").ok());
+ ASSERT_TRUE(txt.GetValue("key").is_value());
+ const std::vector<uint8_t>& value4 = txt.GetValue("key").value();
+ EXPECT_EQ(value4.size(), size_t{2});
+ EXPECT_EQ(value4[0], 'a');
+ EXPECT_EQ(value4[1], 'b');
}
TEST(TxtRecordTest, TestClearValue) {
DnsSdTxtRecord txt;
- uint8_t data[]{'a', 'b', 'c'};
+ std::vector<uint8_t> data{'a', 'b', 'c'};
EXPECT_TRUE(txt.SetValue("key", data).ok());
txt.ClearValue("key");
@@ -59,11 +80,11 @@ TEST(TxtRecordTest, TestClearValue) {
TEST(TxtRecordTest, TestSetAndGetFlag) {
DnsSdTxtRecord txt;
EXPECT_TRUE(txt.SetFlag("key", true).ok());
- EXPECT_TRUE(txt.GetFlag("key").is_value());
+ ASSERT_TRUE(txt.GetFlag("key").is_value());
EXPECT_TRUE(txt.GetFlag("key").value());
EXPECT_TRUE(txt.SetFlag("key", false).ok());
- EXPECT_TRUE(txt.GetFlag("key").is_value());
+ ASSERT_TRUE(txt.GetFlag("key").is_value());
EXPECT_FALSE(txt.GetFlag("key").value());
}
@@ -72,12 +93,13 @@ TEST(TxtRecordTest, TestClearFlag) {
EXPECT_TRUE(txt.SetFlag("key", true).ok());
txt.ClearFlag("key");
+ ASSERT_TRUE(txt.GetFlag("key").is_value());
EXPECT_FALSE(txt.GetFlag("key").value());
}
TEST(TxtRecordTest, TestGettingWrongRecordTypeFails) {
DnsSdTxtRecord txt;
- uint8_t data[]{'a', 'b', 'c'};
+ std::vector<uint8_t> data{'a', 'b', 'c'};
EXPECT_TRUE(txt.SetValue("key", data).ok());
EXPECT_TRUE(txt.SetFlag("key2", true).ok());
EXPECT_FALSE(txt.GetValue("key2").is_value());
@@ -85,14 +107,14 @@ TEST(TxtRecordTest, TestGettingWrongRecordTypeFails) {
TEST(TxtRecordTest, TestClearWrongRecordTypeFails) {
DnsSdTxtRecord txt;
- uint8_t data[]{'a', 'b', 'c'};
+ std::vector<uint8_t> data{'a', 'b', 'c'};
EXPECT_TRUE(txt.SetValue("key", data).ok());
EXPECT_TRUE(txt.SetFlag("key2", true).ok());
}
TEST(TxtRecordTest, TestGetDataWorks) {
DnsSdTxtRecord txt;
- uint8_t data[]{'a', 'b', 'c'};
+ std::vector<uint8_t> data{'a', 'b', 'c'};
EXPECT_TRUE(txt.SetValue("key", data).ok());
EXPECT_TRUE(txt.SetFlag("bool", true).ok());
std::vector<std::vector<uint8_t>> results = txt.GetData();
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/testing/DEPS b/chromium/third_party/openscreen/src/discovery/dnssd/testing/DEPS
new file mode 100644
index 00000000000..62d4cc14c82
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/testing/DEPS
@@ -0,0 +1,7 @@
+# -*- Mode: Python; -*-
+
+include_rules = [
+ '+discovery/dnssd/public',
+ '+discovery/dnssd/impl',
+ '+discovery/mdns',
+]
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.cc b/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.cc
index 462f949f853..61069284184 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.cc
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.cc
@@ -4,6 +4,8 @@
#include "discovery/dnssd/testing/fake_dns_record_factory.h"
+#include <utility>
+
namespace openscreen {
namespace discovery {
@@ -19,38 +21,29 @@ MdnsRecord FakeDnsRecordFactory::CreateFullyPopulatedSrvRecord(uint16_t port) {
}
// static
-const IPAddress FakeDnsRecordFactory::kV4Address = IPAddress(192, 168, 0, 0);
-
-// static
-const IPAddress FakeDnsRecordFactory::kV6Address =
- IPAddress(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16);
-
-// static
-const IPEndpoint FakeDnsRecordFactory::kV4Endpoint{
- FakeDnsRecordFactory::kV4Address, FakeDnsRecordFactory::kPortNum};
+constexpr uint16_t FakeDnsRecordFactory::kPortNum;
// static
-const IPEndpoint FakeDnsRecordFactory::kV6Endpoint{
- FakeDnsRecordFactory::kV6Address, FakeDnsRecordFactory::kPortNum};
+const uint8_t FakeDnsRecordFactory::kV4AddressOctets[4] = {192, 168, 0, 0};
// static
-const std::string FakeDnsRecordFactory::kInstanceName = "instance";
+const uint16_t FakeDnsRecordFactory::kV6AddressHextets[8] = {
+ 0x0102, 0x0304, 0x0506, 0x0708, 0x090a, 0x0b0c, 0x0d0e, 0x0f10};
// static
-const std::string FakeDnsRecordFactory::kServiceName = "_srv-name._udp";
+const char FakeDnsRecordFactory::kInstanceName[] = "instance";
// static
-const std::string FakeDnsRecordFactory::kServiceNameProtocolPart = "_udp";
+const char FakeDnsRecordFactory::kServiceName[] = "_srv-name._udp";
// static
-const std::string FakeDnsRecordFactory::kServiceNameServicePart = "_srv-name";
+const char FakeDnsRecordFactory::kServiceNameProtocolPart[] = "_udp";
// static
-const std::string FakeDnsRecordFactory::kDomainName = "local";
+const char FakeDnsRecordFactory::kServiceNameServicePart[] = "_srv-name";
// static
-const InstanceKey FakeDnsRecordFactory::kKey =
- InstanceKey(kInstanceName, kServiceName, kDomainName);
+const char FakeDnsRecordFactory::kDomainName[] = "local";
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.h b/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.h
index f8bddcdb1ff..473ec68ebc1 100644
--- a/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.h
+++ b/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.h
@@ -5,11 +5,11 @@
#ifndef DISCOVERY_DNSSD_TESTING_FAKE_DNS_RECORD_FACTORY_H_
#define DISCOVERY_DNSSD_TESTING_FAKE_DNS_RECORD_FACTORY_H_
-#include <chrono>
-#include <string>
+#include <stdint.h>
+
+#include <chrono> // NOLINT
#include "discovery/dnssd/impl/constants.h"
-#include "discovery/dnssd/impl/instance_key.h"
#include "discovery/mdns/mdns_records.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -20,16 +20,13 @@ namespace discovery {
class FakeDnsRecordFactory {
public:
static constexpr uint16_t kPortNum = 80;
- static const IPAddress kV4Address;
- static const IPAddress kV6Address;
- static const IPEndpoint kV4Endpoint;
- static const IPEndpoint kV6Endpoint;
- static const std::string kInstanceName;
- static const std::string kServiceName;
- static const std::string kServiceNameProtocolPart;
- static const std::string kServiceNameServicePart;
- static const std::string kDomainName;
- static const InstanceKey kKey;
+ static const uint8_t kV4AddressOctets[4];
+ static const uint16_t kV6AddressHextets[8];
+ static const char kInstanceName[];
+ static const char kServiceName[];
+ static const char kServiceNameProtocolPart[];
+ static const char kServiceNameServicePart[];
+ static const char kDomainName[];
static MdnsRecord CreateFullyPopulatedSrvRecord(uint16_t port = kPortNum);
};
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/DEPS b/chromium/third_party/openscreen/src/discovery/mdns/DEPS
new file mode 100644
index 00000000000..309d03f4532
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/mdns/DEPS
@@ -0,0 +1,5 @@
+# -*- Mode: Python; -*-
+
+include_rules = [
+ '+discovery/mdns/public',
+]
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/multi_answer.bin b/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/multi_answer.bin
new file mode 100644
index 00000000000..24ae31b1c40
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/multi_answer.bin
@@ -0,0 +1,20 @@
+0010 0000 0000 0060 0000 0000 4047 5637
+4780 f537 5627 6796 3656 40f5 4736 0750
+c6f6 3616 c600 00f2 0810 0000 0050 0052
+5047 5637 4723 90f5 3756 2767 9636 5623
+40f5 4736 0760 c6f6 3616 c623 0000 6000
+8000 0004 1080 4756 3737 1646 6647 a0f5
+3756 2767 9636 1646 560c a100 1208 1000
+0000 5000 3200 1000 2000 3030 e656 77b0
+f5e6 5677 3756 2767 9636 5640 f557 4607
+60c6 f636 16c6 4300 7047 5667 6637 3747
+a0f5 3756 2767 9636 5616 370c a100 0108
+1000 0000 5000 b140 4756 3747 9026 2756
+1646 d3e6 f677 b086 56c6 c6f6 e277 f627
+c646 7047 5637 4666 7647 a0f5 3756 2767
+1637 9636 560c a100 1008 ff00 0000 5000
+400c 8a10 100c 2d00 c108 1000 0000 5000
+0100 1000 2000 3000 4000 5000 6000 7000
+8070 4756 4666 7637 47b0 f537 5627 3637
+1667 9636 560c a100 c008 ff00 0000 5000
+200c b7
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/multi_question.bin b/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/multi_question.bin
new file mode 100644
index 00000000000..2fffbdf6bfc
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/multi_question.bin
@@ -0,0 +1,7 @@
+0010 0000 0030 0000 0000 0000 4047 5637
+4780 f537 5627 6796 3656 40f5 4736 0750
+c6f6 3616 c600 00ff 0010 5047 5637 4723
+90f5 3756 2767 9636 5623 40f5 4736 0760
+c6f6 3616 c623 0000 ff00 1050 4756 3747
+3390 f537 5627 6796 3656 3340 f557 4607
+60c6 f636 16c6 3300 00ff 0010
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/probe.bin b/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/probe.bin
new file mode 100644
index 00000000000..5792536a579
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/probe.bin
@@ -0,0 +1,10 @@
+0010 0000 0010 0000 0050 0000 4047 5637
+4780 f537 5627 6796 3656 40f5 4736 0750
+c6f6 3616 c600 00ff 08ff 0cc0 0012 0810
+0000 0050 0032 0010 0020 0030 30e6 5677
+b0f5 e656 7737 5627 6796 3656 40f5 5746
+0760 c6f6 3616 c643 000c c000 0108 1000
+0000 5000 1000 0cc0 0010 08ff 0000 0050
+0040 0c8a 1010 0cc0 00c1 0810 0000 0050
+0001 0010 0020 0030 0040 0050 0060 0070
+0080 0cc0 00c0 08ff 0000 0050 0020 0cc3
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/ptr_response.bin b/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/ptr_response.bin
new file mode 100644
index 00000000000..19a5727536b
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/mdns/fuzzer_seeds/ptr_response.bin
@@ -0,0 +1,8 @@
+0010 0000 0000 0010 0000 0040 80f5 3756
+2767 9636 5640 f547 3607 50c6 f636 16c6
+0000 c008 ff00 0000 5000 7040 4756 3747
+0cc0 0cb2 0012 0810 0000 0050 0080 0010
+0020 0030 0cb2 0cb2 0001 0810 0000 0050
+0010 000c b200 1008 ff00 0000 5000 400c
+8a10 100c b200 c108 1000 0000 5000 0100
+1000 2000 3000 4000 5000 6000 7000 80
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_domain_confirmed_provider.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_domain_confirmed_provider.h
new file mode 100644
index 00000000000..4585ff65a18
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_domain_confirmed_provider.h
@@ -0,0 +1,28 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef DISCOVERY_MDNS_MDNS_DOMAIN_CONFIRMED_PROVIDER_H_
+#define DISCOVERY_MDNS_MDNS_DOMAIN_CONFIRMED_PROVIDER_H_
+
+#include "discovery/mdns/mdns_records.h"
+
+namespace openscreen {
+namespace discovery {
+
+class MdnsDomainConfirmedProvider {
+ public:
+ virtual ~MdnsDomainConfirmedProvider() = default;
+
+ // Called once the probing phase has been completed, and a DomainName has
+ // been confirmed. The callee is expected to register records for the
+ // newly confirmed name in this callback. Note that the requested name and
+ // the confirmed name may differ if conflict resolution has occurred.
+ virtual void OnDomainFound(const DomainName& requested_name,
+ const DomainName& confirmed_name) = 0;
+};
+
+} // namespace discovery
+} // namespace openscreen
+
+#endif // DISCOVERY_MDNS_MDNS_DOMAIN_CONFIRMED_PROVIDER_H_
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.cc
new file mode 100644
index 00000000000..dc911043b0c
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.cc
@@ -0,0 +1,113 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "discovery/mdns/mdns_probe.h"
+
+#include <utility>
+
+#include "discovery/mdns/mdns_random.h"
+#include "discovery/mdns/mdns_sender.h"
+#include "discovery/mdns/public/mdns_constants.h"
+#include "platform/api/task_runner.h"
+#include "platform/api/time.h"
+
+namespace openscreen {
+namespace discovery {
+
+MdnsProbe::MdnsProbe(DomainName target_name, IPAddress address)
+ : target_name_(std::move(target_name)),
+ address_(std::move(address)),
+ address_record_(CreateAddressRecord(target_name_, address_)) {}
+
+MdnsProbe::~MdnsProbe() = default;
+
+MdnsProbeImpl::Observer::~Observer() = default;
+
+MdnsProbeImpl::MdnsProbeImpl(MdnsSender* sender,
+ MdnsReceiver* receiver,
+ MdnsRandom* random_delay,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ Observer* observer,
+ DomainName target_name,
+ IPAddress address)
+ : MdnsProbe(std::move(target_name), std::move(address)),
+ random_delay_(random_delay),
+ task_runner_(task_runner),
+ now_function_(now_function),
+ alarm_(now_function_, task_runner_),
+ sender_(sender),
+ receiver_(receiver),
+ observer_(observer) {
+ OSP_DCHECK(sender_);
+ OSP_DCHECK(receiver_);
+ OSP_DCHECK(random_delay_);
+ OSP_DCHECK(task_runner_);
+ OSP_DCHECK(observer_);
+
+ receiver_->AddResponseCallback(this);
+ alarm_.ScheduleFromNow([this]() { ProbeOnce(); },
+ random_delay_->GetInitialProbeDelay());
+}
+
+MdnsProbeImpl::~MdnsProbeImpl() {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ Stop();
+}
+
+void MdnsProbeImpl::ProbeOnce() {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ if (successful_probe_queries_++ < kProbeIterationCountBeforeSuccess) {
+ // MdnsQuerier cannot be used, because probe queries cannot use the cache,
+ // so instead directly send the query through the MdnsSender.
+ MdnsMessage probe_query(CreateMessageId(), MessageType::Query);
+ MdnsQuestion probe_question(target_name(), DnsType::kANY, DnsClass::kIN,
+ ResponseType::kUnicast);
+ probe_query.AddQuestion(probe_question);
+ probe_query.AddAuthorityRecord(address_record());
+ sender_->SendMulticast(probe_query);
+
+ alarm_.ScheduleFromNow([this]() { ProbeOnce(); },
+ kDelayBetweenProbeQueries);
+ } else {
+ Stop();
+ observer_->OnProbeSuccess(this);
+ }
+}
+
+void MdnsProbeImpl::Stop() {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ if (is_running_) {
+ alarm_.Cancel();
+ receiver_->RemoveResponseCallback(this);
+ is_running_ = false;
+ }
+}
+
+void MdnsProbeImpl::Postpone(std::chrono::seconds delay) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ successful_probe_queries_ = 0;
+ alarm_.Cancel();
+ alarm_.ScheduleFromNow([this]() { ProbeOnce(); },
+ std::chrono::duration_cast<Clock::duration>(delay));
+}
+
+void MdnsProbeImpl::OnMessageReceived(const MdnsMessage& message) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+ OSP_DCHECK(message.type() == MessageType::Response);
+
+ for (const auto& record : message.answers()) {
+ if (record.name() == target_name()) {
+ Stop();
+ observer_->OnProbeFailure(this);
+ }
+ }
+}
+
+} // namespace discovery
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.h
new file mode 100644
index 00000000000..d918266d10d
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.h
@@ -0,0 +1,129 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef DISCOVERY_MDNS_MDNS_PROBE_H_
+#define DISCOVERY_MDNS_MDNS_PROBE_H_
+
+#include <vector>
+
+#include "discovery/mdns/mdns_receiver.h"
+#include "discovery/mdns/mdns_records.h"
+#include "platform/api/time.h"
+#include "platform/base/ip_address.h"
+#include "util/alarm.h"
+
+namespace openscreen {
+
+class TaskRunner;
+
+namespace discovery {
+
+class MdnsQuerier;
+class MdnsRandom;
+class MdnsSender;
+
+// Implements the probing method as described in RFC 6762 section 8.1 to claim a
+// provided domain name. In place of the MdnsRecord(s) that will be published, a
+// 'fake' mDNS record of type A or AAAA will be generated from provided endpoint
+// variable with TTL 2 seconds. 0 or 1 seconds are not used because these
+// constants are used as part of goodbye records, so poorly written receivers
+// may handle these cases in unexpected ways. Caching of probe queries is not
+// supported for mDNS probes (else, in a probe which failed, invalid records
+// would be cached). If for some reason this did occur, though, it should be a
+// non-issue because the probe record will expire after 2 seconds.
+//
+// During probe query conflict resolution, these fake records will be compared
+// with the records provided by another mDNS endpoint. As 2 different mDNS
+// endpoints of the same service type cannot have the same endpoint, these
+// fake mDNS records should never match the real or fake records provided by
+// the other mDNS endpoint, so lexicographic comparison as described in RFC
+// 6762 section 8.2.1 can proceed as described.
+class MdnsProbe : public MdnsReceiver::ResponseClient {
+ public:
+ // The observer class is responsible for returning the result of an ongoing
+ // probe query to the caller.
+ class Observer {
+ public:
+ virtual ~Observer();
+
+ // Called once the probing phase has been completed successfully. |probe| is
+ // expected to be stopped at the time of this call.
+ virtual void OnProbeSuccess(MdnsProbe* probe) = 0;
+
+ // Called once the probing phase fails. |probe| is expected to be stopped at
+ // the time of this call.
+ virtual void OnProbeFailure(MdnsProbe* probe) = 0;
+ };
+
+ MdnsProbe(DomainName target_name, IPAddress address);
+ virtual ~MdnsProbe();
+
+ // Postpones the current probe operation by |delay|, after which the probing
+ // process is re-initialized.
+ virtual void Postpone(std::chrono::seconds delay) = 0;
+
+ const DomainName& target_name() const { return target_name_; }
+ const IPAddress& address() const { return address_; }
+ const MdnsRecord address_record() const { return address_record_; }
+
+ private:
+ const DomainName target_name_;
+ const IPAddress address_;
+ const MdnsRecord address_record_;
+};
+
+class MdnsProbeImpl : public MdnsProbe {
+ public:
+ // |sender|, |receiver|, |random_delay|, |task_runner|, and |observer| must
+ // all persist for the duration of this object's lifetime.
+ MdnsProbeImpl(MdnsSender* sender,
+ MdnsReceiver* receiver,
+ MdnsRandom* random_delay,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ Observer* observer,
+ DomainName target_name,
+ IPAddress address);
+ MdnsProbeImpl(const MdnsProbeImpl& other) = delete;
+ MdnsProbeImpl(MdnsProbeImpl&& other) = delete;
+ ~MdnsProbeImpl() override;
+
+ MdnsProbeImpl& operator=(const MdnsProbeImpl& other) = delete;
+ MdnsProbeImpl& operator=(MdnsProbeImpl&& other) = delete;
+
+ // MdnsProbe overrides.
+ void Postpone(std::chrono::seconds delay) override;
+
+ private:
+ friend class MdnsProbeTests;
+
+ // Performs the probe query as described in the class-level comment.
+ void ProbeOnce();
+
+ // Stops this probe.
+ void Stop();
+
+ // MdnsReceiver::ResponseClient overrides.
+ void OnMessageReceived(const MdnsMessage& message) override;
+
+ MdnsRandom* const random_delay_;
+ TaskRunner* const task_runner_;
+ ClockNowFunctionPtr now_function_;
+
+ Alarm alarm_;
+
+ // NOTE: Access to all below variables should only be done from the task
+ // runner thread.
+ MdnsSender* const sender_;
+ MdnsReceiver* const receiver_;
+ Observer* const observer_;
+
+ int successful_probe_queries_ = 0;
+ bool is_running_ = true;
+};
+
+} // namespace discovery
+} // namespace openscreen
+
+#endif // DISCOVERY_MDNS_MDNS_PROBE_H_
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.cc
new file mode 100644
index 00000000000..7ac93c72900
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.cc
@@ -0,0 +1,255 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "discovery/mdns/mdns_probe_manager.h"
+
+#include <string>
+#include <utility>
+
+#include "discovery/mdns/mdns_sender.h"
+#include "platform/api/task_runner.h"
+
+namespace openscreen {
+namespace discovery {
+namespace {
+
+// The timespan by which to delay subsequent mDNS Probe queries for the same
+// domain name when a simultaneous query from another host is detected, as
+// described in RFC 6762 section 8.2
+constexpr std::chrono::seconds kSimultaneousProbeDelay =
+ std::chrono::seconds(1);
+
+DomainName CreateRetryDomainName(const DomainName& name, int attempt) {
+ OSP_DCHECK(name.labels().size());
+ std::vector<std::string> labels = name.labels();
+ std::string& label = labels[0];
+ std::string attempts_str = std::to_string(attempt);
+ if (label.size() + attempts_str.size() >= kMaxLabelLength) {
+ label = label.substr(0, kMaxLabelLength - attempts_str.size());
+ }
+ label.append(attempts_str);
+
+ return DomainName(std::move(labels));
+}
+
+} // namespace
+
+MdnsProbeManager::~MdnsProbeManager() = default;
+
+MdnsProbeManagerImpl::MdnsProbeManagerImpl(MdnsSender* sender,
+ MdnsReceiver* receiver,
+ MdnsRandom* random_delay,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function)
+ : sender_(sender),
+ receiver_(receiver),
+ random_delay_(random_delay),
+ task_runner_(task_runner),
+ now_function_(now_function) {
+ OSP_DCHECK(sender_);
+ OSP_DCHECK(receiver_);
+ OSP_DCHECK(task_runner_);
+ OSP_DCHECK(random_delay_);
+}
+
+MdnsProbeManagerImpl::~MdnsProbeManagerImpl() = default;
+
+Error MdnsProbeManagerImpl::StartProbe(MdnsDomainConfirmedProvider* callback,
+ DomainName requested_name,
+ IPAddress address) {
+ // Check if |requested_name| is already being queried for.
+ if (FindOngoingProbe(requested_name) != ongoing_probes_.end()) {
+ return Error::Code::kOperationInProgress;
+ }
+
+ // Check if |requested_name| is already claimed.
+ if (IsDomainClaimed(requested_name)) {
+ return Error::Code::kItemAlreadyExists;
+ }
+
+ OSP_DVLOG << "Starting new mDNS Probe for domain '"
+ << requested_name.ToString() << "'";
+
+ // Begin a new probe.
+ auto probe = CreateProbe(requested_name, std::move(address));
+ ongoing_probes_.emplace_back(std::move(probe), std::move(requested_name),
+ callback);
+ return Error::None();
+}
+
+Error MdnsProbeManagerImpl::StopProbe(const DomainName& requested_name) {
+ auto it = FindOngoingProbe(requested_name);
+ if (it == ongoing_probes_.end()) {
+ return Error::Code::kItemNotFound;
+ }
+
+ ongoing_probes_.erase(it);
+ return Error::None();
+}
+
+bool MdnsProbeManagerImpl::IsDomainClaimed(const DomainName& domain) const {
+ return FindCompletedProbe(domain) != completed_probes_.end();
+}
+
+void MdnsProbeManagerImpl::RespondToProbeQuery(const MdnsMessage& message,
+ const IPEndpoint& src) {
+ OSP_DCHECK(!message.questions().empty());
+
+ const std::vector<MdnsQuestion>& questions = message.questions();
+ MdnsMessage send_message(CreateMessageId(), MessageType::Response);
+
+ // Iterate across all questions asked and all completed probes and add A or
+ // AAAA records associated with the endpoints for which the names match.
+ // |questions| is expected to be of size 1 and |completed_probes| should be
+ // small (generally size 1), so this should be fast.
+ for (const auto& question : questions) {
+ for (auto it = completed_probes_.begin(); it != completed_probes_.end();
+ it++) {
+ if (question.name() == (*it)->target_name()) {
+ send_message.AddAnswer((*it)->address_record());
+ break;
+ }
+ }
+ }
+
+ if (!send_message.answers().empty()) {
+ sender_->SendMessage(send_message, src);
+ } else {
+ // If the name isn't already claimed, check to see if a probe is ongoing. If
+ // so, compare the address record for that probe with the one in the
+ // received message and resolve as specified in RFC 6762 section 8.2.
+ TiebreakSimultaneousProbes(message);
+ }
+}
+
+void MdnsProbeManagerImpl::TiebreakSimultaneousProbes(
+ const MdnsMessage& message) {
+ OSP_DCHECK(!message.questions().empty());
+ OSP_DCHECK(!message.authority_records().empty());
+
+ for (const auto& question : message.questions()) {
+ for (auto it = ongoing_probes_.begin(); it != ongoing_probes_.end(); it++) {
+ if (it->probe->target_name() == question.name()) {
+ // When a host is probing for a set of records with the same name, or a
+ // message is received containing multiple tiebreaker records answering
+ // a given probe question in the Question Section, the host's records
+ // and the tiebreaker records from the message are each sorted into
+ // order, and then compared pairwise, using the same comparison
+ // technique described above, until a difference is found. Because the
+ // probe object is guaranteed to only have the address record, only the
+ // lowest authority record is needed.
+ auto lowest_record_it =
+ std::min_element(message.authority_records().begin(),
+ message.authority_records().end());
+
+ // If this host finds that its own data is lexicographically later, it
+ // simply ignores the other host's probe. The other host will have
+ // receive this host's probe simultaneously, and will reject its own
+ // probe through this same calculation.
+ const MdnsRecord& probe_record = it->probe->address_record();
+ if (probe_record > *lowest_record_it) {
+ break;
+ }
+
+ // If the probe query is only of size one and the record received is
+ // equal to this record, then the received query is the same as what
+ // this probe is sending out. In this case, nothing needs to be done.
+ if (message.authority_records().size() == 1 &&
+ !(probe_record < *lowest_record_it)) {
+ break;
+ }
+
+ // At this point, one of the following must be true:
+ // - The query's lowest record is greater than this probe's record
+ // - The query's lowest record equals this probe's record but it also
+ // has additional records.
+ // In either case, the query must take priority over this probe. This
+ // host defers to the winning host by waiting one second, and then
+ // begins probing for this record again. See RFC 6762 section 8.2 for
+ // the logic behind waiting one second.
+ it->probe->Postpone(kSimultaneousProbeDelay);
+ break;
+ }
+ }
+ }
+}
+
+void MdnsProbeManagerImpl::OnProbeSuccess(MdnsProbe* probe) {
+ auto it = FindOngoingProbe(probe);
+ if (it != ongoing_probes_.end()) {
+ DomainName target_name = it->probe->target_name();
+ completed_probes_.push_back(std::move(it->probe));
+ DomainName requested = std::move(it->requested_name);
+ MdnsDomainConfirmedProvider* callback = it->callback;
+ ongoing_probes_.erase(it);
+ callback->OnDomainFound(std::move(requested), std::move(target_name));
+ }
+}
+
+void MdnsProbeManagerImpl::OnProbeFailure(MdnsProbe* probe) {
+ auto ongoing_it = FindOngoingProbe(probe);
+ if (ongoing_it == ongoing_probes_.end()) {
+ // This means that the probe was canceled.
+ return;
+ }
+
+ OSP_DVLOG << "Probe for domain '"
+ << CreateRetryDomainName(ongoing_it->requested_name,
+ ongoing_it->num_probes_failed)
+ .ToString()
+ << "' failed. Trying new domain...";
+
+ // Create a new probe with a modified domain name.
+ DomainName new_name = CreateRetryDomainName(ongoing_it->requested_name,
+ ++ongoing_it->num_probes_failed);
+
+ // If this domain has already been claimed, skip ahead to knowing it's
+ // claimed.
+ auto completed_it = FindCompletedProbe(new_name);
+ if (completed_it != completed_probes_.end()) {
+ DomainName requested_name = std::move(ongoing_it->requested_name);
+ MdnsDomainConfirmedProvider* callback = ongoing_it->callback;
+ ongoing_probes_.erase(ongoing_it);
+ callback->OnDomainFound(requested_name, (*completed_it)->target_name());
+ } else {
+ std::unique_ptr<MdnsProbe> new_probe =
+ CreateProbe(std::move(new_name), ongoing_it->probe->address());
+ ongoing_it->probe = std::move(new_probe);
+ }
+}
+
+std::vector<std::unique_ptr<MdnsProbe>>::const_iterator
+MdnsProbeManagerImpl::FindCompletedProbe(const DomainName& name) const {
+ return std::find_if(completed_probes_.begin(), completed_probes_.end(),
+ [&name](const std::unique_ptr<MdnsProbe>& completed) {
+ return completed->target_name() == name;
+ });
+}
+
+std::vector<MdnsProbeManagerImpl::OngoingProbe>::iterator
+MdnsProbeManagerImpl::FindOngoingProbe(const DomainName& name) {
+ return std::find_if(ongoing_probes_.begin(), ongoing_probes_.end(),
+ [&name](const OngoingProbe& ongoing) {
+ return ongoing.requested_name == name;
+ });
+}
+
+std::vector<MdnsProbeManagerImpl::OngoingProbe>::iterator
+MdnsProbeManagerImpl::FindOngoingProbe(MdnsProbe* probe) {
+ return std::find_if(ongoing_probes_.begin(), ongoing_probes_.end(),
+ [&probe](const OngoingProbe& ongoing) {
+ return ongoing.probe.get() == probe;
+ });
+}
+
+MdnsProbeManagerImpl::OngoingProbe::OngoingProbe(
+ std::unique_ptr<MdnsProbe> probe,
+ DomainName name,
+ MdnsDomainConfirmedProvider* callback)
+ : probe(std::move(probe)),
+ requested_name(std::move(name)),
+ callback(callback) {}
+
+} // namespace discovery
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.h
new file mode 100644
index 00000000000..6e8584ff7e9
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.h
@@ -0,0 +1,150 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef DISCOVERY_MDNS_MDNS_PROBE_MANAGER_H_
+#define DISCOVERY_MDNS_MDNS_PROBE_MANAGER_H_
+
+#include <memory>
+#include <vector>
+
+#include "discovery/mdns/mdns_domain_confirmed_provider.h"
+#include "discovery/mdns/mdns_probe.h"
+#include "discovery/mdns/mdns_records.h"
+#include "platform/base/error.h"
+#include "platform/base/ip_address.h"
+
+namespace openscreen {
+
+class TaskRunner;
+
+namespace discovery {
+
+class MdnsQuerier;
+class MdnsRandom;
+class MdnsSender;
+
+// Interface for maintaining ownership of mDNS Domains.
+class MdnsProbeManager {
+ public:
+ virtual ~MdnsProbeManager();
+
+ // Returns whether the provided domain name has been claimed as owned by this
+ // mDNS Probe Manager.
+ virtual bool IsDomainClaimed(const DomainName& domain) const = 0;
+
+ // |message| is a message received from another host which contains a query
+ // from some domain. It is a considered a probe query for a specific domain if
+ // it contains a query for a specific domain which is answered by mDNS Records
+ // in the 'authority records' section of |message|. If a probe for the
+ // provided domain name is ongoing, an MdnsMessage is sent to the provided
+ // endpoint as described in RFC 6762 section 8.2 to allow for conflict
+ // resolution. If the requested name has already been claimed, a message to
+ // specify this will be sent as described in RFC 6762 section 8.1. The |src|
+ // argument is the address from which the message was originally sent, so that
+ // the response message may be sent as a unicast response.
+ virtual void RespondToProbeQuery(const MdnsMessage& message,
+ const IPEndpoint& src) = 0;
+};
+
+// This class is responsible for managing all ongoing probes for claiming domain
+// names, as described in RFC 6762 Section 8.1's probing phase. If one such
+// probe fails due to a conflict detection, this class will modify the domain
+// name as described in RFC 6762 section 9 and re-initiate probing for the new
+// name.
+class MdnsProbeManagerImpl : public MdnsProbe::Observer,
+ public MdnsProbeManager {
+ public:
+ // |sender|, |receiver|, |random_delay|, and |task_runner|, must all persist
+ // for the duration of this object's lifetime.
+ MdnsProbeManagerImpl(MdnsSender* sender,
+ MdnsReceiver* receiver,
+ MdnsRandom* random_delay,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function);
+ MdnsProbeManagerImpl(const MdnsProbeManager& other) = delete;
+ MdnsProbeManagerImpl(MdnsProbeManager&& other) = delete;
+ ~MdnsProbeManagerImpl() override;
+
+ MdnsProbeManagerImpl& operator=(const MdnsProbeManagerImpl& other) = delete;
+ MdnsProbeManagerImpl& operator=(MdnsProbeManagerImpl&& other) = delete;
+
+ // Starts probing for a valid domain name based on the given one. This may
+ // only be called once per MdnsProbe instance. |observer| must persist until
+ // a valid domain is discovered and the observer's OnDomainFound method is
+ // called.
+ // NOTE: |address| is used to generate a 'fake' address record to use for the
+ // probe query. See MdnsProbe::PerformProbeIteration() for further details.
+ Error StartProbe(MdnsDomainConfirmedProvider* callback,
+ DomainName requested_name,
+ IPAddress address);
+
+ // Stops probing for the requested domain name.
+ Error StopProbe(const DomainName& requested_name);
+
+ // MdnsDomainOwnershipManager overrides.
+ bool IsDomainClaimed(const DomainName& domain) const override;
+ void RespondToProbeQuery(const MdnsMessage& message,
+ const IPEndpoint& src) override;
+
+ private:
+ friend class TestMdnsProbeManager;
+
+ // Resolves simultaneous probe queries as described in RFC 6762 section 8.2.
+ void TiebreakSimultaneousProbes(const MdnsMessage& message);
+
+ virtual std::unique_ptr<MdnsProbe> CreateProbe(DomainName name,
+ IPAddress address) {
+ return std::make_unique<MdnsProbeImpl>(sender_, receiver_, random_delay_,
+ task_runner_, now_function_, this,
+ std::move(name), std::move(address));
+ }
+
+ // Owns an in-progress MdnsProbe. When the probe starts, an instance of this
+ // struct is created. Upon successful completion of the probe, this instance
+ // is deleted and the owned |probe| instance is moved to |completed_probes|.
+ // Upon failure, the instance is updated with a new MdnsProbe object and this
+ // process is repeated.
+ struct OngoingProbe {
+ OngoingProbe(std::unique_ptr<MdnsProbe> probe,
+ DomainName name,
+ MdnsDomainConfirmedProvider* callback);
+
+ // NOTE: unique_ptr objects are used to avoid issues when the container
+ // holding this object is resized.
+ std::unique_ptr<MdnsProbe> probe;
+ DomainName requested_name;
+ MdnsDomainConfirmedProvider* callback;
+ int num_probes_failed = 0;
+ };
+
+ // MdnsProbe::Observer overrides.
+ void OnProbeSuccess(MdnsProbe* probe) override;
+ void OnProbeFailure(MdnsProbe* probe) override;
+
+ // Helpers to find ongoing and completed probes.
+ std::vector<std::unique_ptr<MdnsProbe>>::const_iterator FindCompletedProbe(
+ const DomainName& name) const;
+ std::vector<OngoingProbe>::iterator FindOngoingProbe(const DomainName& name);
+ std::vector<OngoingProbe>::iterator FindOngoingProbe(MdnsProbe* probe);
+
+ MdnsSender* const sender_;
+ MdnsReceiver* const receiver_;
+ MdnsRandom* const random_delay_;
+ TaskRunner* const task_runner_;
+ ClockNowFunctionPtr now_function_;
+
+ // The set of all probes which have completed successfully. This set is
+ // expected to remain small. unique_ptrs are used for storing the probes to
+ // avoid issues when the vector is resized.
+ std::vector<std::unique_ptr<MdnsProbe>> completed_probes_;
+
+ // The set of all currently ongoing probes. This set is expected to remain
+ // small.
+ std::vector<OngoingProbe> ongoing_probes_;
+};
+
+} // namespace discovery
+} // namespace openscreen
+
+#endif // DISCOVERY_MDNS_MDNS_PROBE_MANAGER_H_
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager_unittest.cc
new file mode 100644
index 00000000000..1420e550c23
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager_unittest.cc
@@ -0,0 +1,355 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "discovery/mdns/mdns_probe_manager.h"
+
+#include <utility>
+
+#include "discovery/mdns/mdns_probe.h"
+#include "discovery/mdns/mdns_querier.h"
+#include "discovery/mdns/mdns_random.h"
+#include "discovery/mdns/mdns_receiver.h"
+#include "discovery/mdns/mdns_sender.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "platform/test/fake_clock.h"
+#include "platform/test/fake_task_runner.h"
+#include "platform/test/fake_udp_socket.h"
+
+using testing::_;
+using testing::Invoke;
+using testing::Return;
+using testing::StrictMock;
+
+namespace openscreen {
+namespace discovery {
+
+class MockDomainConfirmedProvider : public MdnsDomainConfirmedProvider {
+ public:
+ MOCK_METHOD2(OnDomainFound, void(const DomainName&, const DomainName&));
+};
+
+class MockMdnsSender : public MdnsSender {
+ public:
+ explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {}
+
+ MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message));
+ MOCK_METHOD2(SendMessage,
+ Error(const MdnsMessage& message, const IPEndpoint& endpoint));
+};
+
+class MockMdnsProbe : public MdnsProbe {
+ public:
+ MockMdnsProbe(DomainName target_name, IPAddress address)
+ : MdnsProbe(std::move(target_name), std::move(address)) {}
+
+ MOCK_METHOD1(Postpone, void(std::chrono::seconds));
+ MOCK_METHOD1(OnMessageReceived, void(const MdnsMessage&));
+};
+
+class TestMdnsProbeManager : public MdnsProbeManagerImpl {
+ public:
+ using MdnsProbeManagerImpl::MdnsProbeManagerImpl;
+
+ using MdnsProbeManagerImpl::OnProbeFailure;
+ using MdnsProbeManagerImpl::OnProbeSuccess;
+
+ std::unique_ptr<MdnsProbe> CreateProbe(DomainName name,
+ IPAddress address) override {
+ return std::make_unique<StrictMock<MockMdnsProbe>>(std::move(name),
+ std::move(address));
+ }
+
+ StrictMock<MockMdnsProbe>* GetOngoingMockProbeByTarget(
+ const DomainName& target) {
+ const auto it =
+ std::find_if(ongoing_probes_.begin(), ongoing_probes_.end(),
+ [&target](const OngoingProbe& ongoing) {
+ return ongoing.probe->target_name() == target;
+ });
+ if (it != ongoing_probes_.end()) {
+ return static_cast<StrictMock<MockMdnsProbe>*>(it->probe.get());
+ }
+ return nullptr;
+ }
+
+ StrictMock<MockMdnsProbe>* GetCompletedMockProbe(const DomainName& target) {
+ const auto it = FindCompletedProbe(target);
+ if (it != completed_probes_.end()) {
+ return static_cast<StrictMock<MockMdnsProbe>*>(it->get());
+ }
+ return nullptr;
+ }
+
+ bool HasOngoingProbe(const DomainName& target) {
+ return GetOngoingMockProbeByTarget(target) != nullptr;
+ }
+
+ bool HasCompletedProbe(const DomainName& target) {
+ return GetCompletedMockProbe(target) != nullptr;
+ }
+
+ size_t GetOngoingProbeCount() { return ongoing_probes_.size(); }
+
+ size_t GetCompletedProbeCount() { return completed_probes_.size(); }
+};
+
+class MdnsProbeManagerTests : public testing::Test {
+ public:
+ MdnsProbeManagerTests()
+ : clock_(Clock::now()),
+ task_runner_(&clock_),
+ socket_(&task_runner_),
+ sender_(&socket_),
+ manager_(&sender_,
+ &receiver_,
+ &random_,
+ &task_runner_,
+ FakeClock::now) {
+ ExpectProbeStopped(name_);
+ ExpectProbeStopped(name2_);
+ ExpectProbeStopped(name_retry_);
+ }
+
+ protected:
+ MdnsMessage CreateProbeQueryMessage(DomainName domain,
+ const IPAddress& address) {
+ MdnsMessage message(CreateMessageId(), MessageType::Query);
+ MdnsQuestion question(domain, DnsType::kANY, DnsClass::kANY,
+ ResponseType::kUnicast);
+ MdnsRecord record = CreateAddressRecord(std::move(domain), address);
+ message.AddQuestion(std::move(question));
+ message.AddAuthorityRecord(std::move(record));
+ return message;
+ }
+
+ void ExpectProbeStopped(const DomainName& name) {
+ EXPECT_FALSE(manager_.HasOngoingProbe(name));
+ EXPECT_FALSE(manager_.HasCompletedProbe(name));
+ EXPECT_FALSE(manager_.IsDomainClaimed(name));
+ }
+
+ StrictMock<MockMdnsProbe>* ExpectProbeOngoing(const DomainName& name) {
+ // Get around limitations of using an assert in a function with a return
+ // value.
+ auto validate = [this, &name]() {
+ ASSERT_TRUE(manager_.HasOngoingProbe(name));
+ EXPECT_FALSE(manager_.HasCompletedProbe(name));
+ EXPECT_FALSE(manager_.IsDomainClaimed(name));
+ };
+ validate();
+
+ return manager_.GetOngoingMockProbeByTarget(name);
+ }
+
+ StrictMock<MockMdnsProbe>* ExpectProbeCompleted(const DomainName& name) {
+ // Get around limitations of using an assert in a function with a return
+ // value.
+ auto validate = [this, &name]() {
+ EXPECT_FALSE(manager_.HasOngoingProbe(name));
+ ASSERT_TRUE(manager_.HasCompletedProbe(name));
+ EXPECT_TRUE(manager_.IsDomainClaimed(name));
+ };
+ validate();
+
+ return manager_.GetCompletedMockProbe(name);
+ }
+
+ StrictMock<MockMdnsProbe>* SetUpCompletedProbe(const DomainName& name,
+ const IPAddress& address) {
+ EXPECT_TRUE(manager_.StartProbe(&callback_, name, address).ok());
+ EXPECT_CALL(callback_, OnDomainFound(name, name));
+ StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name);
+ manager_.OnProbeSuccess(ongoing_probe);
+ ExpectProbeCompleted(name);
+ testing::Mock::VerifyAndClearExpectations(ongoing_probe);
+
+ return ongoing_probe;
+ }
+
+ FakeClock clock_;
+ FakeTaskRunner task_runner_;
+ FakeUdpSocket socket_;
+ StrictMock<MockMdnsSender> sender_;
+ MdnsReceiver receiver_;
+ MdnsRandom random_;
+ StrictMock<TestMdnsProbeManager> manager_;
+ MockDomainConfirmedProvider callback_;
+
+ const DomainName name_{"test", "_googlecast", "_tcp", "local"};
+ const DomainName name_retry_{"test1", "_googlecast", "_tcp", "local"};
+ const DomainName name2_{"test2", "_googlecast", "_tcp", "local"};
+
+ // When used to create address records A, B, C, A > B because comparison of
+ // the rdata in each results in the comparison of endpoints, for which
+ // address_b_ < address_a_. A < C because A is DnsType kA with value 1 and
+ // C is DnsType kAAAA with value 28.
+ const IPAddress address_a_{192, 168, 0, 0};
+ const IPAddress address_b_{190, 160, 0, 0};
+ const IPAddress address_c_{0x0102, 0x0304, 0x0506, 0x0708,
+ 0x090a, 0x0b0c, 0x0d0e, 0x0f10};
+ const IPEndpoint endpoint_{{192, 168, 0, 0}, 80};
+};
+
+TEST_F(MdnsProbeManagerTests, StartProbeBeginsProbeWhenNoneExistsOnly) {
+ EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
+ ExpectProbeOngoing(name_);
+ EXPECT_FALSE(manager_.IsDomainClaimed(name2_));
+
+ EXPECT_FALSE(manager_.StartProbe(&callback_, name_, address_a_).ok());
+
+ EXPECT_CALL(callback_, OnDomainFound(name_, name_));
+ StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_);
+ manager_.OnProbeSuccess(ongoing_probe);
+ EXPECT_FALSE(manager_.IsDomainClaimed(name2_));
+ testing::Mock::VerifyAndClearExpectations(ongoing_probe);
+
+ EXPECT_FALSE(manager_.StartProbe(&callback_, name_, address_a_).ok());
+
+ StrictMock<MockMdnsProbe>* completed_probe = ExpectProbeCompleted(name_);
+ EXPECT_EQ(ongoing_probe, completed_probe);
+ EXPECT_FALSE(manager_.IsDomainClaimed(name2_));
+}
+
+TEST_F(MdnsProbeManagerTests, StopProbeChangesOngoingProbesOnly) {
+ EXPECT_FALSE(manager_.StopProbe(name_).ok());
+ ExpectProbeStopped(name_);
+
+ EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
+ ExpectProbeOngoing(name_);
+
+ EXPECT_TRUE(manager_.StopProbe(name_).ok());
+ ExpectProbeStopped(name_);
+
+ SetUpCompletedProbe(name_, address_a_);
+
+ EXPECT_FALSE(manager_.StopProbe(name_).ok());
+ ExpectProbeCompleted(name_);
+}
+
+TEST_F(MdnsProbeManagerTests, RespondToProbeQuerySendsNothingOnUnownedDomain) {
+ const MdnsMessage query = CreateProbeQueryMessage(name_, address_c_);
+ manager_.RespondToProbeQuery(query, endpoint_);
+}
+
+TEST_F(MdnsProbeManagerTests, RespondToProbeQueryWorksForCompletedProbes) {
+ SetUpCompletedProbe(name_, address_a_);
+
+ const MdnsMessage query = CreateProbeQueryMessage(name_, address_c_);
+ EXPECT_CALL(sender_, SendMessage(_, endpoint_))
+ .WillOnce([this](const MdnsMessage& message,
+ const IPEndpoint& endpoint) -> Error {
+ EXPECT_EQ(message.answers().size(), size_t{1});
+ EXPECT_EQ(message.answers()[0].dns_type(), DnsType::kA);
+ EXPECT_EQ(message.answers()[0].name(), this->name_);
+ return Error::None();
+ });
+ manager_.RespondToProbeQuery(query, endpoint_);
+}
+
+TEST_F(MdnsProbeManagerTests, TiebreakProbeQueryWorksForSingleRecordQueries) {
+ EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
+ StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_);
+
+ // If the probe message received matches the currently running probe, do
+ // nothing.
+ MdnsMessage query = CreateProbeQueryMessage(name_, address_a_);
+ manager_.RespondToProbeQuery(query, endpoint_);
+
+ // If the probe message received is less than the ongoing probe, ignore the
+ // incoming probe.
+ query = CreateProbeQueryMessage(name_, address_b_);
+ manager_.RespondToProbeQuery(query, endpoint_);
+
+ // If the probe message received is greater than the ongoing probe, postpone
+ // the currently running probe.
+ query = CreateProbeQueryMessage(name_, address_c_);
+ EXPECT_CALL(*ongoing_probe, Postpone(_)).Times(1);
+ manager_.RespondToProbeQuery(query, endpoint_);
+}
+
+TEST_F(MdnsProbeManagerTests, TiebreakProbeQueryWorksForMultiRecordQueries) {
+ EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
+ StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_);
+
+ // For the below tests, note that if records A, B, C are generated from
+ // addresses |address_a_|, |address_b_|, and |address_c_| respectively,
+ // then B < A < C.
+ //
+ // If the received records have one record less than the tested record, they
+ // are sorted and the lowest record is compared.
+ MdnsMessage query = CreateProbeQueryMessage(name_, address_b_);
+ query.AddAuthorityRecord(CreateAddressRecord(name_, address_c_));
+ manager_.RespondToProbeQuery(query, endpoint_);
+
+ query = CreateProbeQueryMessage(name_, address_c_);
+ query.AddAuthorityRecord(CreateAddressRecord(name_, address_b_));
+ manager_.RespondToProbeQuery(query, endpoint_);
+
+ query = CreateProbeQueryMessage(name_, address_a_);
+ query.AddAuthorityRecord(CreateAddressRecord(name_, address_b_));
+ query.AddAuthorityRecord(CreateAddressRecord(name_, address_c_));
+ manager_.RespondToProbeQuery(query, endpoint_);
+
+ // If the probe message received has the same first record as what's being
+ // compared and the query has more records, the query wins.
+ query = CreateProbeQueryMessage(name_, address_a_);
+ query.AddAuthorityRecord(CreateAddressRecord(name_, address_c_));
+ EXPECT_CALL(*ongoing_probe, Postpone(_)).Times(1);
+ manager_.RespondToProbeQuery(query, endpoint_);
+ testing::Mock::VerifyAndClearExpectations(ongoing_probe);
+
+ query = CreateProbeQueryMessage(name_, address_c_);
+ query.AddAuthorityRecord(CreateAddressRecord(name_, address_a_));
+ EXPECT_CALL(*ongoing_probe, Postpone(_)).Times(1);
+ manager_.RespondToProbeQuery(query, endpoint_);
+}
+
+TEST_F(MdnsProbeManagerTests, ProbeSuccessAfterProbeRemovalNoOp) {
+ EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
+ StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_);
+ EXPECT_TRUE(manager_.StopProbe(name_).ok());
+ manager_.OnProbeSuccess(ongoing_probe);
+}
+
+TEST_F(MdnsProbeManagerTests, ProbeFailureAfterProbeRemovalNoOp) {
+ EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
+ StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_);
+ EXPECT_TRUE(manager_.StopProbe(name_).ok());
+ manager_.OnProbeFailure(ongoing_probe);
+}
+
+TEST_F(MdnsProbeManagerTests, ProbeFailureCallsCallbackWhenAlreadyClaimed) {
+ // This test first starts a probe with domain |name_retry_| so that when
+ // probe with domain |name_| fails, the newly generated domain with equal
+ // |name_retry_|.
+ StrictMock<MockMdnsProbe>* ongoing_probe =
+ SetUpCompletedProbe(name_retry_, address_a_);
+
+ // Because |name_retry_| has already succeeded, the retry logic should skip
+ // over re-querying for |name_retry_| and jump right to success.
+ EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
+ ongoing_probe = ExpectProbeOngoing(name_);
+ EXPECT_CALL(callback_, OnDomainFound(name_, name_retry_));
+ manager_.OnProbeFailure(ongoing_probe);
+ ExpectProbeStopped(name_);
+ ExpectProbeCompleted(name_retry_);
+}
+
+TEST_F(MdnsProbeManagerTests, ProbeFailureCreatesNewProbeIfNameUnclaimed) {
+ EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
+ StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_);
+ manager_.OnProbeFailure(ongoing_probe);
+ ExpectProbeStopped(name_);
+ ongoing_probe = ExpectProbeOngoing(name_retry_);
+ EXPECT_EQ(ongoing_probe->target_name(), name_retry_);
+
+ EXPECT_CALL(callback_, OnDomainFound(name_, name_retry_));
+ manager_.OnProbeSuccess(ongoing_probe);
+ ExpectProbeCompleted(name_retry_);
+ ExpectProbeStopped(name_);
+}
+
+} // namespace discovery
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_unittest.cc
new file mode 100644
index 00000000000..d69e45b4b09
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_unittest.cc
@@ -0,0 +1,128 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+#include "discovery/mdns/mdns_probe.h"
+
+#include <memory>
+#include <utility>
+
+#include "discovery/mdns/mdns_probe_manager.h"
+#include "discovery/mdns/mdns_querier.h"
+#include "discovery/mdns/mdns_random.h"
+#include "discovery/mdns/mdns_receiver.h"
+#include "discovery/mdns/mdns_sender.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "platform/test/fake_clock.h"
+#include "platform/test/fake_task_runner.h"
+#include "platform/test/fake_udp_socket.h"
+
+using testing::_;
+using testing::Invoke;
+using testing::Return;
+using testing::StrictMock;
+
+namespace openscreen {
+namespace discovery {
+
+class MockMdnsSender : public MdnsSender {
+ public:
+ explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {}
+ MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message));
+ MOCK_METHOD2(SendMessage,
+ Error(const MdnsMessage& message, const IPEndpoint& endpoint));
+};
+
+class MockObserver : public MdnsProbeImpl::Observer {
+ public:
+ MOCK_METHOD1(OnProbeSuccess, void(MdnsProbe*));
+ MOCK_METHOD1(OnProbeFailure, void(MdnsProbe*));
+};
+
+class MdnsProbeTests : public testing::Test {
+ public:
+ MdnsProbeTests()
+ : clock_(Clock::now()),
+ task_runner_(&clock_),
+ socket_(&task_runner_),
+ sender_(&socket_) {
+ EXPECT_EQ(task_runner_.delayed_task_count(), 0);
+ probe_ = CreateProbe();
+ EXPECT_EQ(task_runner_.delayed_task_count(), 1);
+ }
+
+ protected:
+ std::unique_ptr<MdnsProbeImpl> CreateProbe() {
+ return std::make_unique<MdnsProbeImpl>(&sender_, &receiver_, &random_,
+ &task_runner_, FakeClock::now,
+ &observer_, name_, address_v4_);
+ }
+
+ MdnsMessage CreateMessage(const DomainName& domain) {
+ MdnsMessage message(0, MessageType::Response);
+ SrvRecordRdata rdata(0, 0, 80, domain);
+ MdnsRecord record(std::move(domain), DnsType::kSRV, DnsClass::kIN,
+ RecordType::kUnique, std::chrono::seconds(1),
+ std::move(rdata));
+ message.AddAnswer(record);
+ return message;
+ }
+
+ void OnMessageReceived(const MdnsMessage& message) {
+ probe_->OnMessageReceived(message);
+ }
+
+ FakeClock clock_;
+ FakeTaskRunner task_runner_;
+ FakeUdpSocket socket_;
+ StrictMock<MockMdnsSender> sender_;
+ MdnsReceiver receiver_;
+ MdnsRandom random_;
+ StrictMock<MockObserver> observer_;
+
+ std::unique_ptr<MdnsProbeImpl> probe_;
+
+ const DomainName name_{"test", "_googlecast", "_tcp", "local"};
+ const DomainName name2_{"test2", "_googlecast", "_tcp", "local"};
+
+ const IPAddress address_v4_{192, 168, 0, 0};
+ const IPEndpoint endpoint_v4_{address_v4_, 80};
+};
+
+TEST_F(MdnsProbeTests, TestNoCancelationFlow) {
+ EXPECT_CALL(sender_, SendMulticast(_));
+ clock_.Advance(kDelayBetweenProbeQueries);
+ EXPECT_EQ(task_runner_.delayed_task_count(), 1);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+
+ EXPECT_CALL(sender_, SendMulticast(_));
+ clock_.Advance(kDelayBetweenProbeQueries);
+ EXPECT_EQ(task_runner_.delayed_task_count(), 1);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+
+ EXPECT_CALL(sender_, SendMulticast(_));
+ clock_.Advance(kDelayBetweenProbeQueries);
+ EXPECT_EQ(task_runner_.delayed_task_count(), 1);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+
+ EXPECT_CALL(observer_, OnProbeSuccess(probe_.get())).Times(1);
+ clock_.Advance(kDelayBetweenProbeQueries);
+ EXPECT_EQ(task_runner_.delayed_task_count(), 0);
+}
+
+TEST_F(MdnsProbeTests, CancelationWhenMatchingMessageReceived) {
+ EXPECT_CALL(observer_, OnProbeFailure(probe_.get())).Times(1);
+ OnMessageReceived(CreateMessage(name_));
+}
+
+TEST_F(MdnsProbeTests, TestNoCancelationOnUnrelatedMessages) {
+ OnMessageReceived(CreateMessage(name2_));
+
+ EXPECT_CALL(sender_, SendMulticast(_));
+ clock_.Advance(kDelayBetweenProbeQueries);
+ EXPECT_EQ(task_runner_.delayed_task_count(), 1);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+}
+
+} // namespace discovery
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.cc
index e1ebde80f19..51aa30e7dfe 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.cc
@@ -4,61 +4,375 @@
#include "discovery/mdns/mdns_publisher.h"
+#include <chrono>
+#include <cmath>
+
+#include "discovery/common/config.h"
+#include "discovery/mdns/mdns_probe_manager.h"
+#include "discovery/mdns/mdns_records.h"
#include "discovery/mdns/mdns_sender.h"
#include "platform/api/task_runner.h"
+#include "platform/base/trivial_clock_traits.h"
namespace openscreen {
namespace discovery {
+namespace {
+
+// Minimum delay between announcements of a given record in seconds.
+constexpr std::chrono::seconds kMinAnnounceDelay{1};
+
+// Intervals between successive announcements must increase by at least a
+// factor of 2.
+constexpr int kIntervalIncreaseFactor = 2;
+
+// TTL for a goodbye record in seconds. This constant is called out in RFC 6762
+// section 10.1.
+constexpr std::chrono::seconds kGoodbyeTtl{0};
+
+// Timespan between sending batches of announcement and goodbye records, in
+// microseconds.
+constexpr Clock::duration kDelayBetweenBatchedRecords =
+ std::chrono::milliseconds(20);
+
+inline MdnsRecord CreateGoodbyeRecord(const MdnsRecord& record) {
+ if (record.ttl() == kGoodbyeTtl) {
+ return record;
+ }
+ return MdnsRecord(record.name(), record.dns_type(), record.dns_class(),
+ record.record_type(), kGoodbyeTtl, record.rdata());
+}
+
+inline void ValidateRecord(const MdnsRecord& record) {
+ OSP_DCHECK(record.dns_type() != DnsType::kANY);
+ OSP_DCHECK(record.dns_class() != DnsClass::kANY);
+}
+
+} // namespace
-MdnsPublisher::MdnsPublisher(MdnsQuerier* querier,
- MdnsSender* sender,
- platform::TaskRunner* task_runner,
- MdnsRandom* random_delay)
- : querier_(querier),
- sender_(sender),
+MdnsPublisher::MdnsPublisher(MdnsSender* sender,
+ MdnsProbeManager* ownership_manager,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ const Config& config)
+ : sender_(sender),
+ ownership_manager_(ownership_manager),
task_runner_(task_runner),
- random_delay_(random_delay) {
- OSP_DCHECK(querier_);
+ now_function_(now_function),
+ max_announcement_attempts_(config.new_record_announcement_count) {
+ OSP_DCHECK(ownership_manager_);
OSP_DCHECK(sender_);
OSP_DCHECK(task_runner_);
- OSP_DCHECK(random_delay_);
+ OSP_DCHECK_GE(max_announcement_attempts_, 0);
}
-MdnsPublisher::~MdnsPublisher() = default;
+MdnsPublisher::~MdnsPublisher() {
+ if (batch_records_alarm_.has_value()) {
+ batch_records_alarm_.value().Cancel();
+ ProcessRecordQueue();
+ }
+}
Error MdnsPublisher::RegisterRecord(const MdnsRecord& record) {
- // TODO(rwkeane): Implement this method.
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ if (record.dns_type() == DnsType::kNSEC) {
+ return Error::Code::kParameterInvalid;
+ }
+ ValidateRecord(record);
+
+ if (!IsRecordNameClaimed(record)) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ const DomainName& name = record.name();
+ auto it = records_.emplace(name, std::vector<RecordAnnouncerPtr>{}).first;
+ for (const RecordAnnouncerPtr& publisher : it->second) {
+ if (publisher->record() == record) {
+ return Error::Code::kItemAlreadyExists;
+ }
+ }
+
+ OSP_DVLOG << "Registering record of type '" << record.dns_type() << "'";
+
+ it->second.push_back(CreateAnnouncer(record));
+
return Error::None();
}
+Error MdnsPublisher::UnregisterRecord(const MdnsRecord& record) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ if (record.dns_type() == DnsType::kNSEC) {
+ return Error::Code::kParameterInvalid;
+ }
+ ValidateRecord(record);
+
+ OSP_DVLOG << "Unregistering record of type '" << record.dns_type() << "'";
+
+ return RemoveRecord(record, true);
+}
+
Error MdnsPublisher::UpdateRegisteredRecord(const MdnsRecord& old_record,
const MdnsRecord& new_record) {
- // TODO(rwkeane): Implement this method.
- return Error::None();
-}
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
-Error MdnsPublisher::UnregisterRecord(const MdnsRecord& record) {
- // TODO(rwkeane): Implement this method.
- return Error::None();
+ if (old_record.dns_type() == DnsType::kNSEC) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ if (old_record.dns_type() == DnsType::kPTR) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ // Check that the old record and new record are compatible.
+ if (old_record.name() != new_record.name() ||
+ old_record.dns_type() != new_record.dns_type() ||
+ old_record.dns_class() != new_record.dns_class() ||
+ old_record.record_type() != new_record.record_type()) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ OSP_DVLOG << "Updating record of type '" << new_record.dns_type() << "'";
+
+ // Remove the old record. Per RFC 6762 section 8.4, a goodbye message will not
+ // be sent, as all records which can be removed here are unique records, which
+ // will be overwritten during the announcement phase when the updated record
+ // is re-registered due to the cache-flush-bit's presence.
+ const Error remove_result = RemoveRecord(old_record, false);
+ if (!remove_result.ok()) {
+ return remove_result;
+ }
+
+ // Register the new record.
+ return RegisterRecord(new_record);
}
-bool MdnsPublisher::IsExclusiveOwner(const DomainName& name) {
- // TODO(rwkeane): Implement this method.
- return false;
+size_t MdnsPublisher::GetRecordCount() const {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ size_t count = 0;
+ for (const auto& pair : records_) {
+ count += pair.second.size();
+ }
+
+ return count;
}
bool MdnsPublisher::HasRecords(const DomainName& name,
DnsType type,
DnsClass clazz) {
- // TODO(rwkeane): Implement this method.
- return false;
+ return !GetRecords(name, type, clazz).empty();
}
+
std::vector<MdnsRecord::ConstRef> MdnsPublisher::GetRecords(
const DomainName& name,
DnsType type,
DnsClass clazz) {
- // TODO(rwkeane): Implement this method.
- return std::vector<MdnsRecord::ConstRef>();
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ std::vector<MdnsRecord::ConstRef> records;
+ auto it = records_.find(name);
+ if (it != records_.end()) {
+ for (const RecordAnnouncerPtr& announcer : it->second) {
+ OSP_DCHECK(announcer.get());
+ const DnsType record_dns_type = announcer->record().dns_type();
+ const DnsClass record_dns_class = announcer->record().dns_class();
+ if ((type == DnsType::kANY || type == record_dns_type) &&
+ (clazz == DnsClass::kANY || clazz == record_dns_class)) {
+ records.push_back(announcer->record());
+ }
+ }
+ }
+
+ return records;
+}
+
+std::vector<MdnsRecord::ConstRef> MdnsPublisher::GetPtrRecords(DnsClass clazz) {
+ std::vector<MdnsRecord::ConstRef> records;
+
+ // There should be few records associated with any given domain name, so it is
+ // simpler and less error prone to iterate across all records than to check
+ // the domain name against format '[^.]+\.(_tcp)|(_udp)\..*''
+ for (auto it = records_.begin(); it != records_.end(); it++) {
+ for (const RecordAnnouncerPtr& announcer : it->second) {
+ OSP_DCHECK(announcer.get());
+ const DnsType record_dns_type = announcer->record().dns_type();
+ if (record_dns_type != DnsType::kPTR) {
+ continue;
+ }
+
+ const DnsClass record_dns_class = announcer->record().dns_class();
+ if ((clazz == DnsClass::kANY || clazz == record_dns_class)) {
+ records.push_back(announcer->record());
+ }
+ }
+ }
+
+ return records;
+}
+
+Error MdnsPublisher::RemoveRecord(const MdnsRecord& record,
+ bool should_announce_deletion) {
+ const DomainName& name = record.name();
+
+ // Check for the domain and fail if it's not found.
+ const auto it = records_.find(name);
+ if (it == records_.end()) {
+ return Error::Code::kItemNotFound;
+ }
+
+ // Check for the record to be removed.
+ const auto records_it =
+ std::find_if(it->second.begin(), it->second.end(),
+ [&record](const RecordAnnouncerPtr& publisher) {
+ return publisher->record() == record;
+ });
+ if (records_it == it->second.end()) {
+ return Error::Code::kItemNotFound;
+ }
+ if (!should_announce_deletion) {
+ (*records_it)->DisableGoodbyeMessageTransmission();
+ }
+
+ it->second.erase(records_it);
+ if (it->second.empty()) {
+ records_.erase(it);
+ }
+
+ return Error::None();
+}
+
+bool MdnsPublisher::IsRecordNameClaimed(const MdnsRecord& record) const {
+ const DomainName& name =
+ record.dns_type() == DnsType::kPTR
+ ? absl::get<PtrRecordRdata>(record.rdata()).ptr_domain()
+ : record.name();
+ return ownership_manager_->IsDomainClaimed(name);
+}
+
+MdnsPublisher::RecordAnnouncer::RecordAnnouncer(
+ MdnsRecord record,
+ MdnsPublisher* publisher,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ int target_announcement_attempts)
+ : publisher_(publisher),
+ task_runner_(task_runner),
+ now_function_(now_function),
+ record_(std::move(record)),
+ alarm_(now_function_, task_runner_),
+ target_announcement_attempts_(target_announcement_attempts) {
+ OSP_DCHECK(publisher_);
+ OSP_DCHECK(task_runner_);
+ OSP_DCHECK(record_.ttl() != Clock::duration::zero());
+
+ QueueAnnouncement();
+}
+
+MdnsPublisher::RecordAnnouncer::~RecordAnnouncer() {
+ alarm_.Cancel();
+ if (should_send_goodbye_message_) {
+ QueueGoodbye();
+ }
+}
+
+void MdnsPublisher::RecordAnnouncer::QueueGoodbye() {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ publisher_->QueueRecord(CreateGoodbyeRecord(record_));
+}
+
+void MdnsPublisher::RecordAnnouncer::QueueAnnouncement() {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ if (attempts_ >= target_announcement_attempts_) {
+ return;
+ }
+
+ publisher_->QueueRecord(record_);
+
+ const Clock::duration new_delay = GetNextAnnounceDelay();
+ attempts_++;
+ alarm_.ScheduleFromNow([this]() { QueueAnnouncement(); }, new_delay);
+}
+
+void MdnsPublisher::QueueRecord(MdnsRecord record) {
+ if (!batch_records_alarm_.has_value()) {
+ OSP_DCHECK(records_to_send_.empty());
+ batch_records_alarm_.emplace(now_function_, task_runner_);
+ batch_records_alarm_.value().ScheduleFromNow(
+ [this]() { ProcessRecordQueue(); }, kDelayBetweenBatchedRecords);
+ }
+
+ // Check that we aren't announcing and goodbye'ing a record in the same batch.
+ // We expect to be sending no more than 5 records at a time, so don't worry
+ // about iterating across this vector for each insert.
+ auto goodbye = CreateGoodbyeRecord(record);
+ auto existing_record_it =
+ std::find_if(records_to_send_.begin(), records_to_send_.end(),
+ [&goodbye](const MdnsRecord& record) {
+ return goodbye == CreateGoodbyeRecord(record);
+ });
+
+ // If we didn't find it, simply add it to the queue. Else, only send the
+ // goodbye record.
+ if (existing_record_it == records_to_send_.end()) {
+ records_to_send_.push_back(std::move(record));
+ } else if (*existing_record_it == goodbye) {
+ // This means that the goodbye record is already queued to be sent. This
+ // means that there is no reason to also announce it, so exit early.
+ return;
+ } else if (record == goodbye) {
+ // This means that we are sending a goodbye record right as it would also
+ // be announced. Skip the announcement since the record is being
+ // unregistered.
+ *existing_record_it = std::move(record);
+ } else if (record == *existing_record_it) {
+ // This case shouldn't happen, but there is no work to do if it does. Log
+ // to surface that something weird is going on.
+ OSP_LOG_INFO << "Same record being announced multiple times.";
+ } else {
+ // This case should never occur. Support it just in case, but log to
+ // surface that something weird is happening.
+ OSP_LOG_INFO << "Updating the same record multiple times with multiple "
+ "TTL values.";
+ }
+}
+
+void MdnsPublisher::ProcessRecordQueue() {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ if (records_to_send_.empty()) {
+ return;
+ }
+
+ MdnsMessage message(CreateMessageId(), MessageType::Response);
+ for (auto it = records_to_send_.begin(); it != records_to_send_.end();) {
+ if (message.CanAddRecord(*it)) {
+ message.AddAnswer(std::move(*it++));
+ } else if (message.answers().empty()) {
+ // This case should never happen, because it means a record is too large
+ // to fit into its own message.
+ OSP_LOG << "Encountered unreasonably large message in cache. Skipping "
+ << "known answer in suppressions...";
+ it++;
+ } else {
+ sender_->SendMulticast(message);
+ message = MdnsMessage(CreateMessageId(), MessageType::Response);
+ }
+ }
+
+ if (!message.answers().empty()) {
+ sender_->SendMulticast(message);
+ }
+
+ batch_records_alarm_ = absl::nullopt;
+ records_to_send_.clear();
+}
+
+Clock::duration MdnsPublisher::RecordAnnouncer::GetNextAnnounceDelay() {
+ return std::chrono::duration_cast<Clock::duration>(
+ kMinAnnounceDelay * pow(kIntervalIncreaseFactor, attempts_));
}
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.h
index 022a94df90f..4b418312104 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.h
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.h
@@ -5,60 +5,193 @@
#ifndef DISCOVERY_MDNS_MDNS_PUBLISHER_H_
#define DISCOVERY_MDNS_MDNS_PUBLISHER_H_
+#include <map>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/types/optional.h"
#include "discovery/mdns/mdns_records.h"
#include "discovery/mdns/mdns_responder.h"
+#include "util/alarm.h"
namespace openscreen {
-namespace platform {
class TaskRunner;
-} // namespace platform
namespace discovery {
-// TODO(rwkeane): Add API for claiming a DomainName as described in RFC 6762
-// Section 8.1's probing phase.
+struct Config;
+class MdnsProbeManager;
+class MdnsRandom;
+class MdnsSender;
+class MdnsQuerier;
+
+// This class is responsible for both tracking what records have been registered
+// to mDNS as well as publishing new mDNS records to the network.
+// When a new record is published, it will be announced 8 times, starting at an
+// interval of 1 second, with the interval doubling each successive
+// announcement. This same announcement process is followed when an existing
+// record is updated. When it is removed, a Goodbye message must be sent if the
+// record is unique.
+//
+// Prior to publishing a record, the domain name for this service instance must
+// be claimed using the ClaimExclusiveOwnership() function. This function probes
+// the network to determine whether the chosen name exists, modifying the
+// chosen name as described in RFC 6762 if a collision is found.
+//
+// NOTE: All MdnsPublisher instances must be run on the same task runner thread,
+// due to the shared announce + goodbye message queue.
class MdnsPublisher : public MdnsResponder::RecordHandler {
public:
- // |querier|, |sender|, |task_runner|, and |random_delay| must all persist for
- // the duration of this object's lifetime
- MdnsPublisher(MdnsQuerier* querier,
- MdnsSender* sender,
- platform::TaskRunner* task_runner,
- MdnsRandom* random_delay);
- ~MdnsPublisher();
+ // |sender|, |ownership_manager|, and |task_runner| must all persist for the
+ // duration of this object's lifetime
+ MdnsPublisher(MdnsSender* sender,
+ MdnsProbeManager* ownership_manager,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ const Config& config);
+ ~MdnsPublisher() override;
// Registers a new mDNS record for advertisement by this service. For A, AAAA,
// SRV, and TXT records, the domain name must have already been claimed by the
// ClaimExclusiveOwnership() method and for PTR records the name being pointed
- // to must have been claimed in the same fashion.
+ // to must have been claimed in the same fashion, but the domain name in the
+ // top-level MdnsRecord entity does not.
+ // NOTE: NSEC records cannot be registered, and doing so will return an error.
Error RegisterRecord(const MdnsRecord& record);
// Updates the existing record with name matching the name of the new record.
+ // NOTE: This method is not valid for PTR records.
Error UpdateRegisteredRecord(const MdnsRecord& old_record,
const MdnsRecord& new_record);
- // Stops advertising the provided record. If no more records with the provided
- // name are bing advertised after this call's completion, then ownership of
- // the name is released.
+ // Stops advertising the provided record.
Error UnregisterRecord(const MdnsRecord& record);
+ // Returns the total number of records currently registered;
+ size_t GetRecordCount() const;
+
OSP_DISALLOW_COPY_AND_ASSIGN(MdnsPublisher);
private:
+ // Class responsible for sending announcement and goodbye messages for
+ // MdnsRecord instances when they are published, updated, or unpublished. The
+ // announcement messages will be sent |target_announcement_attempts| times,
+ // first at an interval of 1 second apart, and then with delay increasing by a
+ // factor of 2 with each successive announcement.
+ // NOTE: |publisher| must be the MdnsPublisher instance from which this
+ // instance was created.
+ class RecordAnnouncer {
+ public:
+ RecordAnnouncer(MdnsRecord record,
+ MdnsPublisher* publisher,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ int max_announcement_attempts);
+ RecordAnnouncer(const RecordAnnouncer& other) = delete;
+ RecordAnnouncer(RecordAnnouncer&& other) noexcept = delete;
+ ~RecordAnnouncer();
+
+ RecordAnnouncer& operator=(const RecordAnnouncer& other) = delete;
+ RecordAnnouncer& operator=(RecordAnnouncer&& other) noexcept = delete;
+
+ const MdnsRecord& record() const { return record_; }
+
+ // Specifies whether goodbye messages should not be sent when this announcer
+ // is destroyed. This should only be called as part of the 'Update' flow,
+ // for records which should not send this message.
+ void DisableGoodbyeMessageTransmission() {
+ should_send_goodbye_message_ = false;
+ }
+
+ private:
+ // Gets the delay required before the next announcement message is sent.
+ Clock::duration GetNextAnnounceDelay();
+
+ // When announce + goodbye messages are ready to be sent, they are queued
+ // up. Every 20ms, if there are any messages to send out, these records are
+ // batched up and sent out.
+ void QueueGoodbye();
+ void QueueAnnouncement();
+
+ MdnsPublisher* const publisher_;
+ TaskRunner* const task_runner_;
+ const ClockNowFunctionPtr now_function_;
+
+ // Whether or not goodbye messages should be sent.
+ bool should_send_goodbye_message_ = true;
+
+ // Record to send.
+ const MdnsRecord record_;
+
+ // Alarm used to cancel future resend attempts if this object is deleted.
+ Alarm alarm_;
+
+ // Number of attempts at sending this record which have occurred so far.
+ int attempts_ = 0;
+
+ // Number of times to announce a newly published record.
+ const int target_announcement_attempts_;
+ };
+
+ using RecordAnnouncerPtr = std::unique_ptr<RecordAnnouncer>;
+
+ friend class MdnsPublisherTesting;
+
+ // Creates a new published from the provided record.
+ RecordAnnouncerPtr CreateAnnouncer(MdnsRecord record) {
+ return std::make_unique<RecordAnnouncer>(std::move(record), this,
+ task_runner_, now_function_,
+ max_announcement_attempts_);
+ }
+
+ // Removes the given record from the |records_| map. A goodbye record is only
+ // sent for this removal if |should_announce_deletion| is true.
+ Error RemoveRecord(const MdnsRecord& record, bool should_announce_deletion);
+
+ // Returns whether the provided record has had its name claimed so far.
+ bool IsRecordNameClaimed(const MdnsRecord& record) const;
+
+ // Processes the |records_to_send_| queue, sending out the records together as
+ // a single MdnsMessage.
+ void ProcessRecordQueue();
+
+ // Adds a new record to the |records_to_send_| queue or ensures that the
+ // record with lower ttl is present if it differs from an existing record by
+ // only that one field.
+ void QueueRecord(MdnsRecord record);
+
// MdnsResponder::RecordHandler overrides.
- bool IsExclusiveOwner(const DomainName& name) override;
bool HasRecords(const DomainName& name,
DnsType type,
DnsClass clazz) override;
std::vector<MdnsRecord::ConstRef> GetRecords(const DomainName& name,
DnsType type,
DnsClass clazz) override;
+ std::vector<MdnsRecord::ConstRef> GetPtrRecords(DnsClass clazz) override;
- MdnsQuerier* const querier_;
MdnsSender* const sender_;
- platform::TaskRunner* const task_runner_;
- MdnsRandom* const random_delay_;
+ MdnsProbeManager* const ownership_manager_;
+ TaskRunner* const task_runner_;
+ ClockNowFunctionPtr now_function_;
+
+ // Alarm to cancel batching of records when this class is destroyed, and
+ // instead send them immediately. Variable is only set when it is in use.
+ absl::optional<Alarm> batch_records_alarm_;
+
+ // Number of times to announce a newly published record.
+ const int max_announcement_attempts_;
+
+ // The queue for announce and goodbye records to be sent periodically.
+ std::vector<MdnsRecord> records_to_send_;
+
+ // Stores mDNS records that have been published. The keys here are domain
+ // names for valid mDNS Records, and the values are the RecordAnnouncer
+ // entities associated with all published MdnsRecords for the keyed domain.
+ // These are responsible for publishing a specific MdnsRecord, announcing it
+ // when its created and sending a goodbye record when it's deleted.
+ std::map<DomainName, std::vector<RecordAnnouncerPtr>> records_;
};
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher_unittest.cc
new file mode 100644
index 00000000000..c05019f7468
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher_unittest.cc
@@ -0,0 +1,465 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "discovery/mdns/mdns_publisher.h"
+
+#include <chrono> // NOLINT
+#include <vector>
+
+#include "discovery/common/config.h"
+#include "discovery/mdns/mdns_probe_manager.h"
+#include "discovery/mdns/mdns_sender.h"
+#include "discovery/mdns/testing/mdns_test_util.h"
+#include "platform/test/fake_task_runner.h"
+#include "platform/test/fake_udp_socket.h"
+
+using testing::_;
+using testing::Invoke;
+using testing::Return;
+using testing::StrictMock;
+
+namespace openscreen {
+namespace discovery {
+namespace {
+
+constexpr Clock::duration kAnnounceGoodbyeDelay = std::chrono::milliseconds(25);
+
+bool ContainsRecord(const std::vector<MdnsRecord::ConstRef>& records,
+ MdnsRecord record) {
+ return std::find_if(records.begin(), records.end(),
+ [&record](const MdnsRecord& ref) {
+ return ref == record;
+ }) != records.end();
+}
+
+} // namespace
+
+class MockMdnsSender : public MdnsSender {
+ public:
+ explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {}
+
+ MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message));
+ MOCK_METHOD2(SendMessage,
+ Error(const MdnsMessage& message, const IPEndpoint& endpoint));
+};
+
+class MockProbeManager : public MdnsProbeManager {
+ public:
+ MOCK_CONST_METHOD1(IsDomainClaimed, bool(const DomainName&));
+ MOCK_METHOD2(RespondToProbeQuery,
+ void(const MdnsMessage&, const IPEndpoint&));
+};
+
+class MdnsPublisherTesting : public MdnsPublisher {
+ public:
+ using MdnsPublisher::GetPtrRecords;
+ using MdnsPublisher::GetRecords;
+ using MdnsPublisher::MdnsPublisher;
+
+ bool IsNonPtrRecordPresent(const DomainName& name) {
+ auto it = records_.find(name);
+ if (it == records_.end()) {
+ return false;
+ }
+
+ return std::find_if(it->second.begin(), it->second.end(),
+ [](const RecordAnnouncerPtr& announcer) {
+ return announcer->record().dns_type() !=
+ DnsType::kPTR;
+ }) != it->second.end();
+ }
+};
+
+class MdnsPublisherTest : public testing::Test {
+ public:
+ MdnsPublisherTest()
+ : clock_(Clock::now()),
+ task_runner_(&clock_),
+ socket_(&task_runner_),
+ sender_(&socket_),
+ publisher_(&sender_,
+ &probe_manager_,
+ &task_runner_,
+ FakeClock::now,
+ config_) {}
+
+ ~MdnsPublisherTest() {
+ // Clear out any remaining calls in the task runner queue.
+ clock_.Advance(
+ std::chrono::duration_cast<Clock::duration>(std::chrono::seconds(1)));
+ }
+
+ protected:
+ Error IsAnnounced(const MdnsRecord& original, const MdnsMessage& message) {
+ EXPECT_EQ(message.type(), MessageType::Response);
+ EXPECT_EQ(message.questions().size(), size_t{0});
+ EXPECT_EQ(message.authority_records().size(), size_t{0});
+ EXPECT_EQ(message.additional_records().size(), size_t{0});
+ EXPECT_EQ(message.answers().size(), size_t{1});
+
+ const MdnsRecord& sent = message.answers()[0];
+ EXPECT_EQ(original.name(), sent.name());
+ EXPECT_EQ(original.dns_type(), sent.dns_type());
+ EXPECT_EQ(original.dns_class(), sent.dns_class());
+ EXPECT_EQ(original.record_type(), sent.record_type());
+ EXPECT_EQ(original.rdata(), sent.rdata());
+ EXPECT_EQ(original.ttl(), sent.ttl());
+ return Error::None();
+ }
+
+ Error IsGoodbyeRecord(const MdnsRecord& original,
+ const MdnsMessage& message) {
+ EXPECT_EQ(message.type(), MessageType::Response);
+ EXPECT_EQ(message.questions().size(), size_t{0});
+ EXPECT_EQ(message.authority_records().size(), size_t{0});
+ EXPECT_EQ(message.additional_records().size(), size_t{0});
+ EXPECT_EQ(message.answers().size(), size_t{1});
+
+ const MdnsRecord& sent = message.answers()[0];
+ EXPECT_EQ(original.name(), sent.name());
+ EXPECT_EQ(original.dns_type(), sent.dns_type());
+ EXPECT_EQ(original.dns_class(), sent.dns_class());
+ EXPECT_EQ(original.record_type(), sent.record_type());
+ EXPECT_EQ(original.rdata(), sent.rdata());
+ EXPECT_EQ(std::chrono::seconds(0), sent.ttl());
+ return Error::None();
+ }
+
+ void CheckPublishedRecords(const DomainName& domain,
+ DnsType type,
+ std::vector<MdnsRecord> expected_records) {
+ EXPECT_EQ(publisher_.GetRecordCount(), expected_records.size());
+ auto records = publisher_.GetRecords(domain, type, DnsClass::kIN);
+ for (const auto& record : expected_records) {
+ EXPECT_TRUE(ContainsRecord(records, record));
+ }
+ }
+
+ void TestUniqueRecordRegistrationWorkflow(MdnsRecord record,
+ MdnsRecord record2) {
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(domain_))
+ .WillRepeatedly(Return(true));
+ DnsType type = record.dns_type();
+
+ // Check preconditions.
+ ASSERT_EQ(record.dns_type(), record2.dns_type());
+ auto records = publisher_.GetRecords(domain_, type, DnsClass::kIN);
+ ASSERT_EQ(publisher_.GetRecordCount(), size_t{0});
+ ASSERT_EQ(records.size(), size_t{0});
+ ASSERT_NE(record, record2);
+ ASSERT_TRUE(records.empty());
+
+ // Register a new record.
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record](const MdnsMessage& message) -> Error {
+ return IsAnnounced(record, message);
+ });
+ EXPECT_TRUE(publisher_.RegisterRecord(record).ok());
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ CheckPublishedRecords(domain_, type, {record});
+ EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_));
+
+ // Re-register the same record.
+ EXPECT_FALSE(publisher_.RegisterRecord(record).ok());
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ CheckPublishedRecords(domain_, type, {record});
+ EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_));
+
+ // Update a record that doesn't exist
+ EXPECT_FALSE(publisher_.UpdateRegisteredRecord(record2, record).ok());
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ CheckPublishedRecords(domain_, type, {record});
+ EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_));
+
+ // Update an existing record.
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record2](const MdnsMessage& message) -> Error {
+ return IsAnnounced(record2, message);
+ });
+ EXPECT_TRUE(publisher_.UpdateRegisteredRecord(record, record2).ok());
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ CheckPublishedRecords(domain_, type, {record2});
+ EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_));
+
+ // Add back the original record
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record](const MdnsMessage& message) -> Error {
+ return IsAnnounced(record, message);
+ });
+ EXPECT_TRUE(publisher_.RegisterRecord(record).ok());
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ CheckPublishedRecords(domain_, type, {record, record2});
+ EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_));
+
+ // Delete an existing record.
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record2](const MdnsMessage& message) -> Error {
+ return IsGoodbyeRecord(record2, message);
+ });
+ EXPECT_TRUE(publisher_.UnregisterRecord(record2).ok());
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ CheckPublishedRecords(domain_, type, {record});
+ EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_));
+
+ // Delete a non-existing record.
+ EXPECT_FALSE(publisher_.UnregisterRecord(record2).ok());
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ CheckPublishedRecords(domain_, type, {record});
+ EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_));
+
+ // Delete the last record
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record](const MdnsMessage& message) -> Error {
+ return IsGoodbyeRecord(record, message);
+ });
+ EXPECT_TRUE(publisher_.UnregisterRecord(record).ok());
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ CheckPublishedRecords(domain_, type, {});
+ EXPECT_FALSE(publisher_.IsNonPtrRecordPresent(domain_));
+ }
+
+ FakeClock clock_;
+ FakeTaskRunner task_runner_;
+ FakeUdpSocket socket_;
+ StrictMock<MockMdnsSender> sender_;
+ StrictMock<MockProbeManager> probe_manager_;
+ Config config_;
+ MdnsPublisherTesting publisher_;
+
+ DomainName domain_{"instance", "_googlecast", "_tcp", "local"};
+ DomainName ptr_domain_{"_googlecast", "_tcp", "local"};
+};
+
+TEST_F(MdnsPublisherTest, ARecordRegistrationWorkflow) {
+ const MdnsRecord record1 = GetFakeARecord(domain_);
+ const MdnsRecord record2 =
+ GetFakeARecord(domain_, std::chrono::seconds(1000));
+ TestUniqueRecordRegistrationWorkflow(record1, record2);
+}
+
+TEST_F(MdnsPublisherTest, AAAARecordRegistrationWorkflow) {
+ const MdnsRecord record1 = GetFakeAAAARecord(domain_);
+ const MdnsRecord record2 =
+ GetFakeAAAARecord(domain_, std::chrono::seconds(1000));
+ TestUniqueRecordRegistrationWorkflow(record1, record2);
+}
+
+TEST_F(MdnsPublisherTest, TXTRecordRegistrationWorkflow) {
+ const MdnsRecord record1 = GetFakeTxtRecord(domain_);
+ const MdnsRecord record2 =
+ GetFakeTxtRecord(domain_, std::chrono::seconds(1000));
+ TestUniqueRecordRegistrationWorkflow(record1, record2);
+}
+
+TEST_F(MdnsPublisherTest, SRVRecordRegistrationWorkflow) {
+ const MdnsRecord record1 = GetFakeSrvRecord(domain_);
+ const MdnsRecord record2 =
+ GetFakeSrvRecord(domain_, std::chrono::seconds(1000));
+ TestUniqueRecordRegistrationWorkflow(record1, record2);
+}
+
+TEST_F(MdnsPublisherTest, PTRRecordRegistrationWorkflow) {
+ const MdnsRecord record = GetFakePtrRecord(domain_);
+ const MdnsRecord record2 =
+ GetFakePtrRecord(domain_, std::chrono::seconds(1000));
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(domain_))
+ .WillRepeatedly(Return(true));
+ DnsType type = DnsType::kPTR;
+
+ // Check preconditions.
+ ASSERT_EQ(record.dns_type(), record2.dns_type());
+ ASSERT_EQ(publisher_.GetRecordCount(), size_t{0});
+ auto records = publisher_.GetRecords(domain_, type, DnsClass::kIN);
+ ASSERT_EQ(records.size(), size_t{0});
+ records = publisher_.GetRecords(ptr_domain_, type, DnsClass::kIN);
+ ASSERT_EQ(records.size(), size_t{0});
+ ASSERT_NE(record, record2);
+ ASSERT_TRUE(records.empty());
+ ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{0});
+
+ // Register a new record.
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record](const MdnsMessage& message) -> Error {
+ return IsAnnounced(record, message);
+ });
+ EXPECT_TRUE(publisher_.RegisterRecord(record).ok());
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ CheckPublishedRecords(ptr_domain_, type, {record});
+ ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{1});
+
+ // Re-register the same record.
+ EXPECT_FALSE(publisher_.RegisterRecord(record).ok());
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ CheckPublishedRecords(ptr_domain_, type, {record});
+ ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{1});
+
+ // Register a second record.
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record2](const MdnsMessage& message) -> Error {
+ return IsAnnounced(record2, message);
+ });
+ EXPECT_TRUE(publisher_.RegisterRecord(record2).ok());
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ CheckPublishedRecords(ptr_domain_, type, {record, record2});
+ ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{2});
+
+ // Delete an existing record.
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record2](const MdnsMessage& message) -> Error {
+ return IsGoodbyeRecord(record2, message);
+ });
+ EXPECT_TRUE(publisher_.UnregisterRecord(record2).ok());
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ CheckPublishedRecords(ptr_domain_, type, {record});
+ ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{1});
+
+ // Delete a non-existing record.
+ EXPECT_FALSE(publisher_.UnregisterRecord(record2).ok());
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ CheckPublishedRecords(ptr_domain_, type, {record});
+ ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{1});
+
+ // Delete the last record
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record](const MdnsMessage& message) -> Error {
+ return IsGoodbyeRecord(record, message);
+ });
+ EXPECT_TRUE(publisher_.UnregisterRecord(record).ok());
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ CheckPublishedRecords(ptr_domain_, type, {});
+ ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{0});
+}
+
+TEST_F(MdnsPublisherTest, RegisteringUnownedRecordsFail) {
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(domain_))
+ .WillRepeatedly(Return(false));
+ EXPECT_FALSE(publisher_.RegisterRecord(GetFakePtrRecord(domain_)).ok());
+ EXPECT_FALSE(publisher_.RegisterRecord(GetFakeSrvRecord(domain_)).ok());
+ EXPECT_FALSE(publisher_.RegisterRecord(GetFakeTxtRecord(domain_)).ok());
+ EXPECT_FALSE(publisher_.RegisterRecord(GetFakeARecord(domain_)).ok());
+ EXPECT_FALSE(publisher_.RegisterRecord(GetFakeAAAARecord(domain_)).ok());
+}
+
+TEST_F(MdnsPublisherTest, RegistrationAnnouncesEightTimes) {
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(domain_))
+ .WillRepeatedly(Return(true));
+ constexpr Clock::duration kOneSecond =
+ std::chrono::duration_cast<Clock::duration>(std::chrono::seconds(1));
+
+ // First announce, at registration.
+ const MdnsRecord record = GetFakeARecord(domain_);
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record](const MdnsMessage& message) -> Error {
+ return IsAnnounced(record, message);
+ });
+ EXPECT_TRUE(publisher_.RegisterRecord(record).ok());
+ clock_.Advance(kAnnounceGoodbyeDelay);
+
+ // Second announce, at 2 seconds.
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record](const MdnsMessage& message) -> Error {
+ return IsAnnounced(record, message);
+ });
+ clock_.Advance(kOneSecond);
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+
+ // Third announce, at 4 seconds.
+ clock_.Advance(kOneSecond);
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record](const MdnsMessage& message) -> Error {
+ return IsAnnounced(record, message);
+ });
+ clock_.Advance(kOneSecond);
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+
+ // Fourth announce, at 8 seconds.
+ clock_.Advance(kOneSecond * 3);
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record](const MdnsMessage& message) -> Error {
+ return IsAnnounced(record, message);
+ });
+ clock_.Advance(kOneSecond);
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+
+ // Fifth announce, at 16 seconds.
+ clock_.Advance(kOneSecond * 7);
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record](const MdnsMessage& message) -> Error {
+ return IsAnnounced(record, message);
+ });
+ clock_.Advance(kOneSecond);
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+
+ // Sixth announce, at 32 seconds.
+ clock_.Advance(kOneSecond * 15);
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record](const MdnsMessage& message) -> Error {
+ return IsAnnounced(record, message);
+ });
+ clock_.Advance(kOneSecond);
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+
+ // Seventh announce, at 64 seconds.
+ clock_.Advance(kOneSecond * 31);
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record](const MdnsMessage& message) -> Error {
+ return IsAnnounced(record, message);
+ });
+ clock_.Advance(kOneSecond);
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+
+ // Eighth announce, at 128 seconds.
+ clock_.Advance(kOneSecond * 63);
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record](const MdnsMessage& message) -> Error {
+ return IsAnnounced(record, message);
+ });
+ clock_.Advance(kOneSecond);
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+
+ // No more announcements
+ clock_.Advance(kOneSecond * 1024);
+ clock_.Advance(kAnnounceGoodbyeDelay);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+
+ // Sends goodbye message when removed.
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &record](const MdnsMessage& message) -> Error {
+ return IsGoodbyeRecord(record, message);
+ });
+ EXPECT_TRUE(publisher_.UnregisterRecord(record).ok());
+ clock_.Advance(kAnnounceGoodbyeDelay);
+}
+
+} // namespace discovery
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.cc
index 3c4c635369e..a16b7ef873c 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.cc
@@ -4,36 +4,230 @@
#include "discovery/mdns/mdns_querier.h"
+#include <vector>
+
+#include "discovery/common/config.h"
+#include "discovery/common/reporting_client.h"
#include "discovery/mdns/mdns_random.h"
#include "discovery/mdns/mdns_receiver.h"
#include "discovery/mdns/mdns_sender.h"
-#include "discovery/mdns/mdns_trackers.h"
namespace openscreen {
namespace discovery {
+namespace {
+
+const std::vector<DnsType> kTranslatedNsecAnyQueryTypes = {
+ DnsType::kA, DnsType::kPTR, DnsType::kTXT, DnsType::kAAAA, DnsType::kSRV};
+
+bool IsNegativeResponseFor(const MdnsRecord& record, DnsType type) {
+ if (record.dns_type() != DnsType::kNSEC) {
+ return false;
+ }
+
+ const NsecRecordRdata& nsec = absl::get<NsecRecordRdata>(record.rdata());
+
+ // RFC 6762 section 6.1, the NSEC bit must NOT be set in the received NSEC
+ // record to indicate this is an mDNS NSEC record rather than a traditional
+ // DNS NSEC record.
+ if (std::find(nsec.types().begin(), nsec.types().end(), DnsType::kNSEC) !=
+ nsec.types().end()) {
+ return false;
+ }
+
+ return std::find_if(nsec.types().begin(), nsec.types().end(),
+ [type](DnsType stored_type) {
+ return stored_type == type ||
+ stored_type == DnsType::kANY;
+ }) != nsec.types().end();
+}
+
+} // namespace
+
+MdnsQuerier::RecordTrackerLruCache::RecordTrackerLruCache(
+ MdnsQuerier* querier,
+ MdnsSender* sender,
+ MdnsRandom* random_delay,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ ReportingClient* reporting_client,
+ const Config& config)
+ : querier_(querier),
+ sender_(sender),
+ random_delay_(random_delay),
+ task_runner_(task_runner),
+ now_function_(now_function),
+ reporting_client_(reporting_client),
+ config_(config) {
+ OSP_DCHECK(sender_);
+ OSP_DCHECK(random_delay_);
+ OSP_DCHECK(task_runner_);
+ OSP_DCHECK(reporting_client_);
+ OSP_DCHECK_GT(config_.querier_max_records_cached, 0);
+}
+
+std::vector<std::reference_wrapper<const MdnsRecordTracker>>
+MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name) {
+ return Find(name, DnsType::kANY, DnsClass::kANY);
+}
+
+std::vector<std::reference_wrapper<const MdnsRecordTracker>>
+MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name,
+ DnsType dns_type,
+ DnsClass dns_class) {
+ std::vector<RecordTrackerConstRef> results;
+ auto pair = records_.equal_range(name);
+ for (auto it = pair.first; it != pair.second; it++) {
+ const MdnsRecordTracker& tracker = *it->second;
+ if ((dns_type == DnsType::kANY || dns_type == tracker.dns_type()) &&
+ (dns_class == DnsClass::kANY || dns_class == tracker.dns_class())) {
+ results.push_back(std::cref(tracker));
+ }
+ }
+
+ return results;
+}
+
+int MdnsQuerier::RecordTrackerLruCache::Erase(const DomainName& domain,
+ TrackerApplicableCheck check) {
+ auto pair = records_.equal_range(domain);
+ int count = 0;
+ for (RecordMap::iterator it = pair.first; it != pair.second;) {
+ if (check(*it->second)) {
+ lru_order_.erase(it->second);
+ it = records_.erase(it);
+ count++;
+ } else {
+ it++;
+ }
+ }
+
+ return count;
+}
+
+int MdnsQuerier::RecordTrackerLruCache::ExpireSoon(
+ const DomainName& domain,
+ TrackerApplicableCheck check) {
+ auto pair = records_.equal_range(domain);
+ int count = 0;
+ for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
+ if (check(*it->second)) {
+ MoveToEnd(it);
+ it->second->ExpireSoon();
+ count++;
+ }
+ }
+
+ return count;
+}
+
+int MdnsQuerier::RecordTrackerLruCache::Update(const MdnsRecord& record,
+ TrackerApplicableCheck check) {
+ return Update(record, check, [](const MdnsRecordTracker& t) {});
+}
+
+int MdnsQuerier::RecordTrackerLruCache::Update(
+ const MdnsRecord& record,
+ TrackerApplicableCheck check,
+ TrackerChangeCallback on_rdata_update) {
+ auto pair = records_.equal_range(record.name());
+ int count = 0;
+ for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
+ if (check(*it->second)) {
+ auto result = it->second->Update(record);
+
+ if (result.is_error()) {
+ reporting_client_->OnRecoverableError(
+ Error(Error::Code::kUpdateReceivedRecordFailure,
+ result.error().ToString()));
+ continue;
+ }
+
+ count++;
+ if (result.value() == MdnsRecordTracker::UpdateType::kGoodbye) {
+ it->second->ExpireSoon();
+ MoveToEnd(it);
+ } else {
+ MoveToBeginning(it);
+ if (result.value() == MdnsRecordTracker::UpdateType::kRdata) {
+ on_rdata_update(*it->second);
+ }
+ }
+ }
+ }
+
+ return count;
+}
+
+const MdnsRecordTracker& MdnsQuerier::RecordTrackerLruCache::StartTracking(
+ MdnsRecord record,
+ DnsType dns_type) {
+ auto expiration_callback = [this](const MdnsRecordTracker* tracker,
+ const MdnsRecord& record) {
+ querier_->OnRecordExpired(tracker, record);
+ };
+
+ while (lru_order_.size() >=
+ static_cast<size_t>(config_.querier_max_records_cached)) {
+ // This call erases one of the tracked records.
+ OSP_DVLOG << "Maximum cacheable record count exceeded ("
+ << config_.querier_max_records_cached << ")";
+ lru_order_.back().ExpireNow();
+ }
+
+ auto name = record.name();
+ lru_order_.emplace_front(std::move(record), dns_type, sender_, task_runner_,
+ now_function_, random_delay_,
+ std::move(expiration_callback));
+ records_.emplace(std::move(name), lru_order_.begin());
+
+ return lru_order_.front();
+}
+
+void MdnsQuerier::RecordTrackerLruCache::MoveToBeginning(
+ MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
+ lru_order_.splice(lru_order_.begin(), lru_order_, it->second);
+ it->second = lru_order_.begin();
+}
+
+void MdnsQuerier::RecordTrackerLruCache::MoveToEnd(
+ MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
+ lru_order_.splice(lru_order_.end(), lru_order_, it->second);
+ it->second = --lru_order_.end();
+}
MdnsQuerier::MdnsQuerier(MdnsSender* sender,
MdnsReceiver* receiver,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
- MdnsRandom* random_delay)
+ MdnsRandom* random_delay,
+ ReportingClient* reporting_client,
+ Config config)
: sender_(sender),
receiver_(receiver),
task_runner_(task_runner),
now_function_(now_function),
- random_delay_(random_delay) {
+ random_delay_(random_delay),
+ reporting_client_(reporting_client),
+ config_(std::move(config)),
+ records_(this,
+ sender_,
+ random_delay_,
+ task_runner_,
+ now_function_,
+ reporting_client_,
+ config_) {
OSP_DCHECK(sender_);
OSP_DCHECK(receiver_);
OSP_DCHECK(task_runner_);
OSP_DCHECK(now_function_);
OSP_DCHECK(random_delay_);
+ OSP_DCHECK(reporting_client_);
- receiver_->SetResponseCallback(
- [this](const MdnsMessage& message) { OnMessageReceived(message); });
+ receiver_->AddResponseCallback(this);
}
MdnsQuerier::~MdnsQuerier() {
- receiver_->SetResponseCallback(nullptr);
+ receiver_->RemoveResponseCallback(this);
}
// NOTE: The code below is range loops instead of std:find_if, for better
@@ -46,6 +240,7 @@ void MdnsQuerier::StartQuery(const DomainName& name,
MdnsRecordChangedCallback* callback) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DCHECK(callback);
+ OSP_DCHECK(dns_type != DnsType::kNSEC);
// Add a new callback if haven't seen it before
auto callbacks_it = callbacks_.equal_range(name);
@@ -63,12 +258,15 @@ void MdnsQuerier::StartQuery(const DomainName& name,
// Notify the new callback with previously cached records.
// NOTE: In the future, could allow callers to fetch cached records after
// adding a callback, for example to prime the UI.
- auto records_it = records_.equal_range(name);
- for (auto entry = records_it.first; entry != records_it.second; ++entry) {
- const MdnsRecord& record = entry->second->record();
- if ((dns_type == DnsType::kANY || dns_type == record.dns_type()) &&
- (dns_class == DnsClass::kANY || dns_class == record.dns_class())) {
- callback->OnRecordChanged(record, RecordChangedEvent::kCreated);
+ const std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
+ records_.Find(name, dns_type, dns_class);
+ for (const MdnsRecordTracker& tracker : trackers) {
+ if (!tracker.is_negative_response()) {
+ MdnsRecord stored_record(name, tracker.dns_type(), tracker.dns_class(),
+ tracker.record_type(), tracker.ttl(),
+ tracker.rdata());
+ callback->OnRecordChanged(std::move(stored_record),
+ RecordChangedEvent::kCreated);
}
}
@@ -92,6 +290,7 @@ void MdnsQuerier::StopQuery(const DomainName& name,
MdnsRecordChangedCallback* callback) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DCHECK(callback);
+ OSP_DCHECK(dns_type != DnsType::kNSEC);
// Find and remove the callback.
int callbacks_for_key = 0;
@@ -132,154 +331,275 @@ void MdnsQuerier::StopQuery(const DomainName& name,
// be configurable by the caller.
}
+void MdnsQuerier::ReinitializeQueries(const DomainName& name) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ // Get the ongoing queries and their callbacks.
+ std::vector<CallbackInfo> callbacks;
+ auto its = callbacks_.equal_range(name);
+ for (auto it = its.first; it != its.second; it++) {
+ callbacks.push_back(std::move(it->second));
+ }
+ callbacks_.erase(name);
+
+ // Remove all known questions and answers.
+ questions_.erase(name);
+ records_.Erase(name, [](const MdnsRecordTracker& tracker) { return true; });
+
+ // Restart the queries.
+ for (const auto& cb : callbacks) {
+ StartQuery(name, cb.dns_type, cb.dns_class, cb.callback);
+ }
+}
+
void MdnsQuerier::OnMessageReceived(const MdnsMessage& message) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DCHECK(message.type() == MessageType::Response);
- // TODO(crbug.com/openscreen/83): Drop answers and additional records if
- // answer records do not answer any existing questions
- // TODO(crbug.com/openscreen/83): Check authority records
- // TODO(crbug.com/openscreen/84): Cap size of cache, to avoid memory blowups
- // when publishers misbehave.
- ProcessRecords(message.answers());
- ProcessRecords(message.additional_records());
+ OSP_DVLOG << "Received mDNS Response message with "
+ << message.answers().size() << " answers and "
+ << message.additional_records().size()
+ << " additional records. Processing...";
+
+ // Add any records that are relevant for this querier.
+ bool found_relevant_records = false;
+ int processed_count = 0;
+ for (const MdnsRecord& record : message.answers()) {
+ if (ShouldAnswerRecordBeProcessed(record)) {
+ ProcessRecord(record);
+ OSP_DVLOG << "\tProcessing answer record for domain '"
+ << record.name().ToString() << "' of type '"
+ << record.dns_type() << "'...";
+ found_relevant_records = true;
+ processed_count++;
+ }
+ }
+
+ // If any of the message's answers are relevant, add all additional records.
+ // Else, since the message has already been received and parsed, use any
+ // individual records relevant to this querier to update the cache.
+ for (const MdnsRecord& record : message.additional_records()) {
+ if (found_relevant_records || ShouldAnswerRecordBeProcessed(record)) {
+ OSP_DVLOG << "\tProcessing additional record for domain '"
+ << record.name().ToString() << "' of type '"
+ << record.dns_type() << "'...";
+ ProcessRecord(record);
+ processed_count++;
+ }
+ }
+
+ OSP_DVLOG << "\tmDNS Response processed (" << processed_count
+ << " records accepted)!";
+
+ // TODO(crbug.com/openscreen/83): Check authority records.
}
-void MdnsQuerier::OnRecordExpired(const MdnsRecord& record) {
- OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+bool MdnsQuerier::ShouldAnswerRecordBeProcessed(const MdnsRecord& answer) {
+ // First, accept the record if it's associated with an ongoing question.
+ const auto questions_range = questions_.equal_range(answer.name());
+ const auto it = std::find_if(
+ questions_range.first, questions_range.second,
+ [&answer](const auto& pair) {
+ return (pair.second->question().dns_type() == DnsType::kANY ||
+ IsNegativeResponseFor(answer,
+ pair.second->question().dns_type()) ||
+ pair.second->question().dns_type() == answer.dns_type()) &&
+ (pair.second->question().dns_class() == DnsClass::kANY ||
+ pair.second->question().dns_class() == answer.dns_class());
+ });
+ if (it != questions_range.second) {
+ return true;
+ }
+
+ // If not, check if it corresponds to an already existing record. This is
+ // required because records which are already stored may either have been
+ // received in an additional records section, or are associated with a query
+ // which is no longer active.
+ std::vector<DnsType> types{answer.dns_type()};
+ if (answer.dns_type() == DnsType::kNSEC) {
+ const auto& nsec_rdata = absl::get<NsecRecordRdata>(answer.rdata());
+ types = nsec_rdata.types();
+ }
- ProcessCallbacks(record, RecordChangedEvent::kExpired);
-
- auto records_it = records_.equal_range(record.name());
- for (auto entry = records_it.first; entry != records_it.second; ++entry) {
- MdnsRecordTracker* tracker = entry->second.get();
- const MdnsRecord& tracked_record = tracker->record();
- if (record.dns_type() == tracked_record.dns_type() &&
- record.dns_class() == tracked_record.dns_class() &&
- record.rdata() == tracked_record.rdata()) {
- records_.erase(entry);
- break;
+ for (DnsType type : types) {
+ std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
+ records_.Find(answer.name(), type, answer.dns_class());
+ if (!trackers.empty()) {
+ return true;
}
}
+
+ // In all other cases, the record isn't relevant. Drop it.
+ return false;
+}
+
+void MdnsQuerier::OnRecordExpired(const MdnsRecordTracker* tracker,
+ const MdnsRecord& record) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ if (!tracker->is_negative_response()) {
+ ProcessCallbacks(record, RecordChangedEvent::kExpired);
+ }
+
+ records_.Erase(record.name(), [tracker](const MdnsRecordTracker& it_tracker) {
+ return tracker == &it_tracker;
+ });
}
-void MdnsQuerier::ProcessRecords(const std::vector<MdnsRecord>& records) {
+void MdnsQuerier::ProcessRecord(const MdnsRecord& record) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
- for (const MdnsRecord& record : records) {
+ // Get the types which the received record is associated with. In most cases
+ // this will only be the type of the provided record, but in the case of
+ // NSEC records this will be all records which the record dictates the
+ // nonexistence of.
+ std::vector<DnsType> types;
+ const std::vector<DnsType>* types_ptr = &types;
+ if (record.dns_type() == DnsType::kNSEC) {
+ const auto& nsec_rdata = absl::get<NsecRecordRdata>(record.rdata());
+ if (std::find(nsec_rdata.types().begin(), nsec_rdata.types().end(),
+ DnsType::kANY) != nsec_rdata.types().end()) {
+ types_ptr = &kTranslatedNsecAnyQueryTypes;
+ } else {
+ types_ptr = &nsec_rdata.types();
+ }
+ } else {
+ types.push_back(record.dns_type());
+ }
+
+ // Apply the update for each type that the record is associated with.
+ for (DnsType dns_type : *types_ptr) {
switch (record.record_type()) {
case RecordType::kShared: {
- ProcessSharedRecord(record);
+ ProcessSharedRecord(record, dns_type);
break;
}
case RecordType::kUnique: {
- ProcessUniqueRecord(record);
+ ProcessUniqueRecord(record, dns_type);
break;
}
}
}
}
-void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record) {
+void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record,
+ DnsType dns_type) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DCHECK(record.record_type() == RecordType::kShared);
- auto records_it = records_.equal_range(record.name());
- for (auto entry = records_it.first; entry != records_it.second; ++entry) {
- MdnsRecordTracker* tracker = entry->second.get();
- const MdnsRecord& tracked_record = tracker->record();
- if (record.dns_type() == tracked_record.dns_type() &&
- record.dns_class() == tracked_record.dns_class() &&
- record.rdata() == tracked_record.rdata()) {
- // Already have this shared record, update the existing one.
- // This is a TTL only update since we've already checked that RDATA
- // matches. No notification is necessary on a TTL only update.
- // TODO(crbug.com/openscreen/87): Handle errors returned by Update().
- tracker->Update(record);
- return;
- }
+ // By design, NSEC records are never shared records.
+ if (record.dns_type() == DnsType::kNSEC) {
+ return;
+ }
+
+ // For any records updated, this host already has this shared record. Since
+ // the RDATA matches, this is only a TTL update.
+ auto check = [&record](const MdnsRecordTracker& tracker) {
+ return record.dns_type() == tracker.dns_type() &&
+ record.dns_class() == tracker.dns_class() &&
+ record.rdata() == tracker.rdata();
+ };
+ auto updated_count = records_.Update(record, std::move(check));
+
+ if (!updated_count) {
+ // Have never before seen this shared record, insert a new one.
+ AddRecord(record, dns_type);
+ ProcessCallbacks(record, RecordChangedEvent::kCreated);
}
- // Have never before seen this shared record, insert a new one.
- AddRecord(record);
- ProcessCallbacks(record, RecordChangedEvent::kCreated);
}
-void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record) {
+void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record,
+ DnsType dns_type) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DCHECK(record.record_type() == RecordType::kUnique);
- int records_for_key = 0;
- auto records_it = records_.equal_range(record.name());
- for (auto entry = records_it.first; entry != records_it.second; ++entry) {
- const MdnsRecord& tracked_record = entry->second->record();
- if (record.dns_type() == tracked_record.dns_type() &&
- record.dns_class() == tracked_record.dns_class()) {
- ++records_for_key;
+ std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
+ records_.Find(record.name(), dns_type, record.dns_class());
+ size_t num_records_for_key = trackers.size();
+
+ // Have not seen any records with this key before. This case is expected the
+ // first time a record is received.
+ if (num_records_for_key == size_t{0}) {
+ const bool will_exist = record.dns_type() != DnsType::kNSEC;
+ AddRecord(record, dns_type);
+ if (will_exist) {
+ ProcessCallbacks(record, RecordChangedEvent::kCreated);
}
}
- if (records_for_key == 0) {
- // Have not seen any records with this key before.
- AddRecord(record);
- ProcessCallbacks(record, RecordChangedEvent::kCreated);
- } else if (records_for_key == 1) {
- // There's only one record with this key.
- MdnsRecordTracker* tracker = records_it.first->second.get();
- // TODO(crbug.com/openscreen/87): Handle errors returned by Update().
- ErrorOr<MdnsRecordTracker::UpdateType> result = tracker->Update(record);
- if (result.is_value()) {
- switch (result.value()) {
- case MdnsRecordTracker::UpdateType::kGoodbye:
- tracker->ExpireSoon();
- break;
- case MdnsRecordTracker::UpdateType::kTTLOnly:
- // TTL has been updated. No action required.
- break;
- case MdnsRecordTracker::UpdateType::kRdata:
- // If RDATA on the record is different, notify that the record has
- // been updated.
- ProcessCallbacks(record, RecordChangedEvent::kUpdated);
- break;
- }
- }
- } else {
- // Multiple records with the same key. Expire all record with non-matching
- // RDATA. Update the record with the matching RDATA if it exists, otherwise
- // insert a new record.
- bool is_new_record = true;
- for (auto entry = records_it.first; entry != records_it.second; ++entry) {
- MdnsRecordTracker* tracker = entry->second.get();
- const MdnsRecord& tracked_record = tracker->record();
- if (record.dns_type() == tracked_record.dns_type() &&
- record.dns_class() == tracked_record.dns_class()) {
- if (record.rdata() == tracked_record.rdata()) {
- is_new_record = false;
- // TODO(crbug.com/openscreen/87): Handle errors returned by Update().
- ErrorOr<MdnsRecordTracker::UpdateType> result =
- tracker->Update(record);
- if (result.is_value()) {
- switch (result.value()) {
- case MdnsRecordTracker::UpdateType::kGoodbye:
- tracker->ExpireSoon();
- break;
- case MdnsRecordTracker::UpdateType::kTTLOnly:
- // No notification is necessary on a TTL only update.
- break;
- case MdnsRecordTracker::UpdateType::kRdata:
- // Not possible - we already checked that the RDATA matches.
- OSP_NOTREACHED();
- break;
- }
- }
- } else {
- tracker->ExpireSoon();
- }
- }
+ // There is exactly one tracker associated with this key. This is the expected
+ // case when a record matching this one has already been seen.
+ else if (num_records_for_key == size_t{1}) {
+ ProcessSinglyTrackedUniqueRecord(record, trackers[0]);
+ }
+
+ // Multiple records with the same key.
+ else {
+ ProcessMultiTrackedUniqueRecord(record, dns_type);
+ }
+}
+
+void MdnsQuerier::ProcessSinglyTrackedUniqueRecord(
+ const MdnsRecord& record,
+ const MdnsRecordTracker& tracker) {
+ const bool existed_previously = !tracker.is_negative_response();
+ const bool will_exist = record.dns_type() != DnsType::kNSEC;
+
+ // Calculate the callback to call on record update success while the old
+ // record still exists.
+ MdnsRecord record_for_callback = record;
+ if (existed_previously && !will_exist) {
+ record_for_callback =
+ MdnsRecord(record.name(), tracker.dns_type(), tracker.dns_class(),
+ tracker.record_type(), tracker.ttl(), tracker.rdata());
+ }
+
+ auto on_rdata_change = [this, r = std::move(record_for_callback),
+ existed_previously,
+ will_exist](const MdnsRecordTracker& tracker) {
+ // If RDATA on the record is different, notify that the record has
+ // been updated.
+ if (existed_previously && will_exist) {
+ ProcessCallbacks(r, RecordChangedEvent::kUpdated);
+ } else if (existed_previously) {
+ // Do not expire the tracker, because it still holds an NSEC record.
+ ProcessCallbacks(r, RecordChangedEvent::kExpired);
+ } else if (will_exist) {
+ ProcessCallbacks(r, RecordChangedEvent::kCreated);
}
+ };
+
+ int updated_count = records_.Update(
+ record, [&tracker](const MdnsRecordTracker& t) { return &tracker == &t; },
+ std::move(on_rdata_change));
+ OSP_DCHECK_EQ(updated_count, 1);
+}
- if (is_new_record) {
- // Did not find an existing record to update.
- AddRecord(record);
+void MdnsQuerier::ProcessMultiTrackedUniqueRecord(const MdnsRecord& record,
+ DnsType dns_type) {
+ auto update_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
+ return tracker.dns_type() == dns_type &&
+ tracker.dns_class() == record.dns_class() &&
+ tracker.rdata() == record.rdata();
+ };
+ int update_count = records_.Update(
+ record, std::move(update_check),
+ [](const MdnsRecordTracker& tracker) { OSP_NOTREACHED(); });
+ OSP_DCHECK_LE(update_count, 1);
+
+ auto expire_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
+ return tracker.dns_type() == dns_type &&
+ tracker.dns_class() == record.dns_class() &&
+ tracker.rdata() != record.rdata();
+ };
+ int expire_count =
+ records_.ExpireSoon(record.name(), std::move(expire_check));
+ OSP_DCHECK_GE(expire_count, 1);
+
+ // Did not find an existing record to update.
+ if (!update_count && !expire_count) {
+ AddRecord(record, dns_type);
+ if (record.dns_type() != DnsType::kNSEC) {
ProcessCallbacks(record, RecordChangedEvent::kCreated);
}
}
@@ -299,25 +619,45 @@ void MdnsQuerier::ProcessCallbacks(const MdnsRecord& record,
callback_info.callback->OnRecordChanged(record, event);
}
}
-
- // TODO(crbug.com/openscreen/83): Update known answers for relevant questions.
}
void MdnsQuerier::AddQuestion(const MdnsQuestion& question) {
- questions_.emplace(question.name(),
- std::make_unique<MdnsQuestionTracker>(
- std::move(question), sender_, task_runner_,
- now_function_, random_delay_));
+ auto tracker = std::make_unique<MdnsQuestionTracker>(
+ std::move(question), sender_, task_runner_, now_function_, random_delay_,
+ config_);
+ MdnsQuestionTracker* ptr = tracker.get();
+ questions_.emplace(question.name(), std::move(tracker));
+
+ // Let all records associated with this question know that there is a new
+ // query that can be used for their refresh.
+ std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
+ records_.Find(question.name(), question.dns_type(), question.dns_class());
+ for (const MdnsRecordTracker& tracker : trackers) {
+ // NOTE: When the pointed to object is deleted, its dtor removes itself
+ // from all associated records.
+ ptr->AddAssociatedRecord(&tracker);
+ }
}
-void MdnsQuerier::AddRecord(const MdnsRecord& record) {
- auto expiration_callback = [this](const MdnsRecord& record) {
- MdnsQuerier::OnRecordExpired(record);
- };
- records_.emplace(record.name(),
- std::make_unique<MdnsRecordTracker>(
- std::move(record), sender_, task_runner_, now_function_,
- random_delay_, expiration_callback));
+void MdnsQuerier::AddRecord(const MdnsRecord& record, DnsType type) {
+ // Add the new record.
+ const auto& tracker = records_.StartTracking(record, type);
+
+ // Let all questions associated with this record know that there is a new
+ // record that answers them (for known answer suppression).
+ auto query_it = questions_.equal_range(record.name());
+ for (auto entry = query_it.first; entry != query_it.second; ++entry) {
+ const MdnsQuestion& query = entry->second->question();
+ const bool is_relevant_type =
+ type == DnsType::kANY || type == query.dns_type();
+ const bool is_relevant_class = record.dns_class() == DnsClass::kANY ||
+ record.dns_class() == query.dns_class();
+ if (is_relevant_type && is_relevant_class) {
+ // NOTE: When the pointed to object is deleted, its dtor removes itself
+ // from all associated queries.
+ entry->second->AddAssociatedRecord(&tracker);
+ }
+ }
}
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.h
index 2e50978ffb2..1125815274d 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.h
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.h
@@ -5,40 +5,44 @@
#ifndef DISCOVERY_MDNS_MDNS_QUERIER_H_
#define DISCOVERY_MDNS_MDNS_QUERIER_H_
+#include <list>
#include <map>
+#include "discovery/common/config.h"
+#include "discovery/mdns/mdns_receiver.h"
#include "discovery/mdns/mdns_record_changed_callback.h"
#include "discovery/mdns/mdns_records.h"
+#include "discovery/mdns/mdns_trackers.h"
#include "platform/api/task_runner.h"
namespace openscreen {
namespace discovery {
class MdnsRandom;
-class MdnsReceiver;
class MdnsSender;
class MdnsQuestionTracker;
class MdnsRecordTracker;
+class ReportingClient;
-class MdnsQuerier {
+class MdnsQuerier : public MdnsReceiver::ResponseClient {
public:
- using ClockNowFunctionPtr = openscreen::platform::ClockNowFunctionPtr;
- using TaskRunner = openscreen::platform::TaskRunner;
-
MdnsQuerier(MdnsSender* sender,
MdnsReceiver* receiver,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
- MdnsRandom* random_delay);
+ MdnsRandom* random_delay,
+ ReportingClient* reporting_client,
+ Config config);
MdnsQuerier(const MdnsQuerier& other) = delete;
MdnsQuerier(MdnsQuerier&& other) noexcept = delete;
MdnsQuerier& operator=(const MdnsQuerier& other) = delete;
MdnsQuerier& operator=(MdnsQuerier&& other) noexcept = delete;
- ~MdnsQuerier();
+ ~MdnsQuerier() override;
// Starts an mDNS query with the given name, DNS type, and DNS class. Updated
// records are passed to |callback|. The caller must ensure |callback|
// remains alive while it is registered with a query.
+ // NOTE: NSEC records cannot be queried for.
void StartQuery(const DomainName& name,
DnsType dns_type,
DnsClass dns_class,
@@ -52,6 +56,11 @@ class MdnsQuerier {
DnsClass dns_class,
MdnsRecordChangedCallback* callback);
+ // Re-initializes the process of service discovery for the provided domain
+ // name. All ongoing queries for this domain are restarted and any previously
+ // received query results are discarded.
+ void ReinitializeQueries(const DomainName& name);
+
private:
struct CallbackInfo {
MdnsRecordChangedCallback* const callback;
@@ -59,25 +68,138 @@ class MdnsQuerier {
const DnsClass dns_class;
};
- // Callback passed to MdnsReceiver
- void OnMessageReceived(const MdnsMessage& message);
+ // Represents a Least Recently Used cache of MdnsRecordTrackers.
+ class RecordTrackerLruCache {
+ public:
+ using RecordTrackerConstRef =
+ std::reference_wrapper<const MdnsRecordTracker>;
+ using TrackerApplicableCheck =
+ std::function<bool(const MdnsRecordTracker&)>;
+ using TrackerChangeCallback = std::function<void(const MdnsRecordTracker&)>;
+
+ RecordTrackerLruCache(MdnsQuerier* querier,
+ MdnsSender* sender,
+ MdnsRandom* random_delay,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ ReportingClient* reporting_client,
+ const Config& config);
+
+ // Returns all trackers with the associated |name| such that its type
+ // represents a type corresponding to |dns_type| and class corresponding to
+ // |dns_class|.
+ std::vector<RecordTrackerConstRef> Find(const DomainName& name);
+ std::vector<RecordTrackerConstRef> Find(const DomainName& name,
+ DnsType dns_type,
+ DnsClass dns_class);
+
+ // Calls ExpireSoon on all record trackers in the provided domain which
+ // match the provided applicability check. Returns the number of trackers
+ // marked for expiry.
+ int ExpireSoon(const DomainName& name, TrackerApplicableCheck check);
+
+ // Erases all record trackers in the provided domain which match the
+ // provided applicability check. Returns the number of trackers erased.
+ int Erase(const DomainName& name, TrackerApplicableCheck check);
+
+ // Updates all record trackers in the domain |record.name()| which match the
+ // provided applicability check using the provided record. Returns the
+ // number of records successfully updated.
+ int Update(const MdnsRecord& record, TrackerApplicableCheck check);
+ int Update(const MdnsRecord& record,
+ TrackerApplicableCheck check,
+ TrackerChangeCallback on_rdata_update);
+
+ // Creates a record tracker of the given type associated with the provided
+ // record.
+ const MdnsRecordTracker& StartTracking(MdnsRecord record, DnsType type);
+
+ size_t size() { return records_.size(); }
+
+ private:
+ using LruList = std::list<MdnsRecordTracker>;
+ using RecordMap = std::multimap<DomainName, LruList::iterator>;
+
+ void MoveToBeginning(RecordMap::iterator iterator);
+ void MoveToEnd(RecordMap::iterator iterator);
+
+ MdnsQuerier* const querier_;
+ MdnsSender* const sender_;
+ MdnsRandom* const random_delay_;
+ TaskRunner* const task_runner_;
+ ClockNowFunctionPtr now_function_;
+ ReportingClient* reporting_client_;
+ const Config& config_;
+
+ // List of RecordTracker instances used by this instance where the least
+ // recently updated element (or next to be deleted element) appears at the
+ // end of the list.
+ LruList lru_order_;
+
+ // A collection of active known record trackers, each is identified by
+ // domain name, DNS record type, and DNS record class. Multimap key is
+ // domain name only to allow easy support for wildcard processing for DNS
+ // record type and class and allow storing shared records that differ only
+ // in RDATA.
+ //
+ // MdnsRecordTracker instances are stored as unique_ptr so they are not
+ // moved around in memory when the collection is modified. This allows
+ // passing a pointer to MdnsQuestionTracker to a task running on the
+ // TaskRunner.
+ RecordMap records_;
+ };
+
+ friend class MdnsQuerierTest;
+
+ // MdnsReceiver::ResponseClient overrides.
+ void OnMessageReceived(const MdnsMessage& message) override;
+
+ // Expires the record tracker provided. This callback is passed to owned
+ // MdnsRecordTracker instances in |records_|.
+ void OnRecordExpired(const MdnsRecordTracker* tracker,
+ const MdnsRecord& record);
- // Callback passed to owned MdnsRecordTrackers
- void OnRecordExpired(const MdnsRecord& record);
+ // Determines whether a record received by this querier should be processed
+ // or dropped.
+ bool ShouldAnswerRecordBeProcessed(const MdnsRecord& answer);
- void ProcessRecords(const std::vector<MdnsRecord>& records);
- void ProcessSharedRecord(const MdnsRecord& record);
- void ProcessUniqueRecord(const MdnsRecord& record);
+ // Processes any record update, calling into the below methods as needed.
+ void ProcessRecord(const MdnsRecord& records);
+
+ // Processes a shared record update as a record of type |type|.
+ void ProcessSharedRecord(const MdnsRecord& record, DnsType type);
+
+ // Processes a unique record update as a record of type |type|.
+ void ProcessUniqueRecord(const MdnsRecord& record, DnsType type);
+
+ // Called when exactly one tracker is associated with a provided key.
+ // Determines the type of update being executed by this update call, then
+ // fires the appropriate callback.
+ void ProcessSinglyTrackedUniqueRecord(const MdnsRecord& record,
+ const MdnsRecordTracker& tracker);
+
+ // Called when multiple records are associated with the same key. Expire all
+ // record with non-matching RDATA. Update the record with the matching RDATA
+ // if it exists, otherwise insert a new record.
+ void ProcessMultiTrackedUniqueRecord(const MdnsRecord& record,
+ DnsType dns_type);
+
+ // Calls all callbacks associated with the provided record.
void ProcessCallbacks(const MdnsRecord& record, RecordChangedEvent event);
+ // Begins tracking the provided question.
void AddQuestion(const MdnsQuestion& question);
- void AddRecord(const MdnsRecord& record);
+
+ // Begins tracking the provided record.
+ void AddRecord(const MdnsRecord& record, DnsType type);
MdnsSender* const sender_;
MdnsReceiver* const receiver_;
TaskRunner* const task_runner_;
const ClockNowFunctionPtr now_function_;
MdnsRandom* const random_delay_;
+ ReportingClient* reporting_client_;
+ Config config_;
// A collection of active question trackers, each is uniquely identified by
// domain name, DNS record type, and DNS record class. Multimap key is domain
@@ -88,14 +210,8 @@ class MdnsQuerier {
// TaskRunner.
std::multimap<DomainName, std::unique_ptr<MdnsQuestionTracker>> questions_;
- // A collection of active known record trackers, each is identified by domain
- // name, DNS record type, and DNS record class. Multimap key is domain name
- // only to allow easy support for wildcard processing for DNS record type and
- // class and allow storing shared records that differ only in RDATA.
- // MdnsRecordTracker instances are stored as unique_ptr so they are not moved
- // around in memory when the collection is modified. This allows passing a
- // pointer to MdnsQuestionTracker to a task running on the TaskRunner.
- std::multimap<DomainName, std::unique_ptr<MdnsRecordTracker>> records_;
+ // Set of records tracked by this querier.
+ RecordTrackerLruCache records_;
// A collection of callbacks passed to StartQuery method. Each is identified
// by domain name, DNS record type, and DNS record class, but there can be
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier_unittest.cc
index a24078221b1..b48c900f771 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier_unittest.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier_unittest.cc
@@ -4,31 +4,31 @@
#include "discovery/mdns/mdns_querier.h"
+#include <memory>
+
+#include "discovery/common/config.h"
+#include "discovery/common/testing/mock_reporting_client.h"
#include "discovery/mdns/mdns_random.h"
#include "discovery/mdns/mdns_receiver.h"
#include "discovery/mdns/mdns_record_changed_callback.h"
#include "discovery/mdns/mdns_sender.h"
+#include "discovery/mdns/mdns_trackers.h"
#include "discovery/mdns/mdns_writer.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "platform/base/udp_packet.h"
#include "platform/test/fake_clock.h"
#include "platform/test/fake_task_runner.h"
+#include "platform/test/mock_udp_socket.h"
namespace openscreen {
namespace discovery {
-using openscreen::platform::Clock;
-using openscreen::platform::FakeClock;
-using openscreen::platform::FakeTaskRunner;
-using openscreen::platform::NetworkInterfaceIndex;
-using openscreen::platform::TaskRunner;
-using openscreen::platform::UdpPacket;
-using openscreen::platform::UdpSocket;
using testing::_;
using testing::Args;
using testing::Invoke;
using testing::Return;
+using testing::StrictMock;
using testing::WithArgs;
// Only compare NAME, CLASS, TYPE and RDATA
@@ -40,27 +40,6 @@ ACTION_P(PartialCompareRecords, expected) {
EXPECT_TRUE(actual.rdata() == expected.rdata());
}
-class MockUdpSocket : public UdpSocket {
- public:
- MOCK_METHOD(bool, IsIPv4, (), (const, override));
- MOCK_METHOD(bool, IsIPv6, (), (const, override));
- MOCK_METHOD(IPEndpoint, GetLocalEndpoint, (), (const, override));
- MOCK_METHOD(void, Bind, (), (override));
- MOCK_METHOD(void,
- SetMulticastOutboundInterface,
- (NetworkInterfaceIndex),
- (override));
- MOCK_METHOD(void,
- JoinMulticastGroup,
- (const IPAddress&, NetworkInterfaceIndex),
- (override));
- MOCK_METHOD(void,
- SendMessage,
- (const void*, size_t, const IPEndpoint&),
- (override));
- MOCK_METHOD(void, SetDscp, (DscpMode), (override));
-};
-
class MockRecordChangedCallback : public MdnsRecordChangedCallback {
public:
MOCK_METHOD(void,
@@ -75,7 +54,6 @@ class MdnsQuerierTest : public testing::Test {
: clock_(Clock::now()),
task_runner_(&clock_),
sender_(&socket_),
- receiver_(&socket_),
record0_created_(DomainName{"testing", "local"},
DnsType::kA,
DnsClass::kIN,
@@ -105,38 +83,89 @@ class MdnsQuerierTest : public testing::Test {
DnsClass::kIN,
RecordType::kShared,
std::chrono::seconds(0), // a goodbye record
- ARecordRdata(IPAddress{192, 168, 0, 1})) {
+ ARecordRdata(IPAddress{192, 168, 0, 1})),
+ record2_created_(DomainName{"testing", "local"},
+ DnsType::kAAAA,
+ DnsClass::kIN,
+ RecordType::kUnique,
+ std::chrono::seconds(120),
+ AAAARecordRdata(IPAddress{1, 2, 3, 4, 5, 6, 7, 8})),
+ nsec_record_created_(
+ DomainName{"testing", "local"},
+ DnsType::kNSEC,
+ DnsClass::kIN,
+ RecordType::kUnique,
+ std::chrono::seconds(120),
+ NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA)) {
receiver_.Start();
}
std::unique_ptr<MdnsQuerier> CreateQuerier() {
return std::make_unique<MdnsQuerier>(&sender_, &receiver_, &task_runner_,
- &FakeClock::now, &random_);
+ &FakeClock::now, &random_,
+ &reporting_client_, config_);
}
protected:
- UdpPacket CreatePacketWithRecord(const MdnsRecord& record) {
+ UdpPacket CreatePacketWithRecords(
+ const std::vector<MdnsRecord::ConstRef>& records,
+ std::vector<MdnsRecord::ConstRef> additional_records) {
MdnsMessage message(CreateMessageId(), MessageType::Response);
- message.AddAnswer(record);
+ for (const MdnsRecord& record : records) {
+ message.AddAnswer(record);
+ }
+ for (const MdnsRecord& additional_record : additional_records) {
+ message.AddAdditionalRecord(additional_record);
+ }
UdpPacket packet(message.MaxWireSize());
MdnsWriter writer(packet.data(), packet.size());
- writer.Write(message);
+ EXPECT_TRUE(writer.Write(message));
packet.resize(writer.offset());
return packet;
}
+ UdpPacket CreatePacketWithRecords(
+ const std::vector<MdnsRecord::ConstRef>& records) {
+ return CreatePacketWithRecords(records, {});
+ }
+
+ UdpPacket CreatePacketWithRecord(const MdnsRecord& record) {
+ return CreatePacketWithRecords({MdnsRecord::ConstRef(record)});
+ }
+
+ // NSEC records are never exposed to outside callers, so the below methods are
+ // necessary to validate that they are functioning as expected.
+ bool ContainsRecord(MdnsQuerier* querier,
+ const MdnsRecord& record,
+ DnsType type = DnsType::kANY) {
+ auto record_trackers =
+ querier->records_.Find(record.name(), type, record.dns_class());
+
+ return std::find_if(record_trackers.begin(), record_trackers.end(),
+ [&record](const MdnsRecordTracker& tracker) {
+ return tracker.rdata() == record.rdata() &&
+ tracker.ttl() == record.ttl();
+ }) != record_trackers.end();
+ }
+
+ size_t RecordCount(MdnsQuerier* querier) { return querier->records_.size(); }
+
+ Config config_;
FakeClock clock_;
FakeTaskRunner task_runner_;
testing::NiceMock<MockUdpSocket> socket_;
MdnsSender sender_;
MdnsReceiver receiver_;
MdnsRandom random_;
+ StrictMock<MockReportingClient> reporting_client_;
MdnsRecord record0_created_;
MdnsRecord record0_updated_;
MdnsRecord record0_deleted_;
MdnsRecord record1_created_;
MdnsRecord record1_deleted_;
+ MdnsRecord record2_created_;
+ MdnsRecord nsec_record_created_;
};
TEST_F(MdnsQuerierTest, UniqueRecordCreatedUpdatedDeleted) {
@@ -330,5 +359,257 @@ TEST_F(MdnsQuerierTest, SameCallerDifferentQuestions) {
receiver_.OnRead(&socket_, CreatePacketWithRecord(record1_created_));
}
+TEST_F(MdnsQuerierTest, ReinitializeQueries) {
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
+ MockRecordChangedCallback callback;
+
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+
+ EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
+ .WillOnce(WithArgs<0>(PartialCompareRecords(record0_created_)));
+
+ receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_));
+ // Receiving the same record should only reset TTL, no callback
+ receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_));
+ testing::Mock::VerifyAndClearExpectations(&receiver_);
+
+ // Queries should still be ongoing but all received records should have been
+ // deleted.
+ querier->ReinitializeQueries(DomainName{"testing", "local"});
+ EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
+ .WillOnce(WithArgs<0>(PartialCompareRecords(record0_created_)));
+ receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_));
+ testing::Mock::VerifyAndClearExpectations(&receiver_);
+
+ // Reinitializing a different domain should not affect other queries.
+ querier->ReinitializeQueries(DomainName{"testing2", "local"});
+ receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_));
+}
+
+TEST_F(MdnsQuerierTest, MessagesForUnknownQueriesDropped) {
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
+ MockRecordChangedCallback callback;
+
+ // Message for unknown query does not get processed.
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+ receiver_.OnRead(&socket_, CreatePacketWithRecord(record1_created_));
+ querier->StartQuery(DomainName{"poking", "local"}, DnsType::kA, DnsClass::kIN,
+ &callback);
+ testing::Mock::VerifyAndClearExpectations(&callback);
+
+ querier->StopQuery(DomainName{"poking", "local"}, DnsType::kA, DnsClass::kIN,
+ &callback);
+
+ // Only known records from the message are processed.
+ EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
+ .Times(1);
+ receiver_.OnRead(
+ &socket_, CreatePacketWithRecords({record0_created_, record1_created_}));
+ querier->StartQuery(DomainName{"poking", "local"}, DnsType::kA, DnsClass::kIN,
+ &callback);
+}
+
+TEST_F(MdnsQuerierTest, MessagesForKnownRecordsAllowed) {
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
+ MockRecordChangedCallback callback;
+
+ // Store a message for a known query.
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+ receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_));
+ testing::Mock::VerifyAndClearExpectations(&callback);
+
+ // Stop the query and validate that record updates are still received.
+ querier->StopQuery(DomainName{"testing", "local"}, DnsType::kA, DnsClass::kIN,
+ &callback);
+ receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_updated_));
+ testing::Mock::VerifyAndClearExpectations(&callback);
+
+ querier->StopQuery(DomainName{"poking", "local"}, DnsType::kA, DnsClass::kIN,
+ &callback);
+
+ // Only known records from the message are processed.
+ EXPECT_CALL(callback,
+ OnRecordChanged(record0_updated_, RecordChangedEvent::kCreated))
+ .Times(1);
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+}
+
+TEST_F(MdnsQuerierTest, MessagesForUnknownKnownRecordsAllowsAdditionalRecords) {
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
+ MockRecordChangedCallback callback;
+
+ // Store a message for a known query.
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+ EXPECT_CALL(callback,
+ OnRecordChanged(record0_created_, RecordChangedEvent::kCreated))
+ .Times(1);
+ receiver_.OnRead(&socket_, CreatePacketWithRecords({record1_created_},
+ {record0_created_}));
+ testing::Mock::VerifyAndClearExpectations(&callback);
+}
+
+TEST_F(MdnsQuerierTest, CallbackNotCalledOnStartQueryForNsecRecords) {
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
+
+ // Set up so an NSEC record has been received
+ StrictMock<MockRecordChangedCallback> callback;
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+ auto packet = CreatePacketWithRecord(nsec_record_created_);
+ receiver_.OnRead(&socket_, std::move(packet));
+ ASSERT_EQ(RecordCount(querier.get()), size_t{1});
+ EXPECT_TRUE(ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA));
+
+ // Start new query
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+}
+
+TEST_F(MdnsQuerierTest, ReceiveNsecRecordFansOutToEachType) {
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
+
+ StrictMock<MockRecordChangedCallback> callback;
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+ MdnsRecord multi_type_nsec =
+ MdnsRecord(nsec_record_created_.name(), nsec_record_created_.dns_type(),
+ nsec_record_created_.dns_class(),
+ nsec_record_created_.record_type(), nsec_record_created_.ttl(),
+ NsecRecordRdata(nsec_record_created_.name(), DnsType::kA,
+ DnsType::kSRV, DnsType::kAAAA));
+ auto packet = CreatePacketWithRecord(multi_type_nsec);
+ receiver_.OnRead(&socket_, std::move(packet));
+ ASSERT_EQ(RecordCount(querier.get()), size_t{3});
+ EXPECT_TRUE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kA));
+ EXPECT_TRUE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kAAAA));
+ EXPECT_TRUE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kSRV));
+}
+
+TEST_F(MdnsQuerierTest, ReceiveNsecKAnyRecordFansOutToAllTypes) {
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
+
+ StrictMock<MockRecordChangedCallback> callback;
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+ MdnsRecord any_type_nsec =
+ MdnsRecord(nsec_record_created_.name(), nsec_record_created_.dns_type(),
+ nsec_record_created_.dns_class(),
+ nsec_record_created_.record_type(), nsec_record_created_.ttl(),
+ NsecRecordRdata(nsec_record_created_.name(), DnsType::kANY));
+ auto packet = CreatePacketWithRecord(any_type_nsec);
+ receiver_.OnRead(&socket_, std::move(packet));
+ ASSERT_EQ(RecordCount(querier.get()), size_t{5});
+ EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kA));
+ EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kAAAA));
+ EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kSRV));
+ EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kTXT));
+ EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kPTR));
+}
+
+TEST_F(MdnsQuerierTest, CorrectCallbackCalledWhenNsecRecordReplacesNonNsec) {
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
+
+ // Set up so an A record has been received
+ StrictMock<MockRecordChangedCallback> callback;
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+ EXPECT_CALL(callback,
+ OnRecordChanged(record0_created_, RecordChangedEvent::kCreated));
+ auto packet = CreatePacketWithRecord(record0_created_);
+ receiver_.OnRead(&socket_, std::move(packet));
+ testing::Mock::VerifyAndClearExpectations(&callback);
+ ASSERT_TRUE(ContainsRecord(querier.get(), record0_created_, DnsType::kA));
+ EXPECT_FALSE(
+ ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA));
+
+ EXPECT_CALL(callback,
+ OnRecordChanged(record0_created_, RecordChangedEvent::kExpired));
+ packet = CreatePacketWithRecord(nsec_record_created_);
+ receiver_.OnRead(&socket_, std::move(packet));
+ EXPECT_FALSE(ContainsRecord(querier.get(), record0_created_, DnsType::kA));
+ EXPECT_TRUE(ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA));
+}
+
+TEST_F(MdnsQuerierTest, CorrectCallbackCalledWhenNonNsecRecordReplacesNsec) {
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
+
+ // Set up so an A record has been received
+ StrictMock<MockRecordChangedCallback> callback;
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+ auto packet = CreatePacketWithRecord(nsec_record_created_);
+ receiver_.OnRead(&socket_, std::move(packet));
+ ASSERT_TRUE(ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA));
+ EXPECT_FALSE(ContainsRecord(querier.get(), record0_created_, DnsType::kA));
+
+ EXPECT_CALL(callback,
+ OnRecordChanged(record0_created_, RecordChangedEvent::kCreated));
+ packet = CreatePacketWithRecord(record0_created_);
+ receiver_.OnRead(&socket_, std::move(packet));
+ EXPECT_FALSE(
+ ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA));
+ EXPECT_TRUE(ContainsRecord(querier.get(), record0_created_, DnsType::kA));
+}
+
+TEST_F(MdnsQuerierTest, NoCallbackCalledWhenSecondNsecRecordReceived) {
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
+ MdnsRecord multi_type_nsec =
+ MdnsRecord(nsec_record_created_.name(), nsec_record_created_.dns_type(),
+ nsec_record_created_.dns_class(),
+ nsec_record_created_.record_type(), nsec_record_created_.ttl(),
+ NsecRecordRdata(nsec_record_created_.name(), DnsType::kA,
+ DnsType::kSRV, DnsType::kAAAA));
+
+ // Set up so an A record has been received
+ StrictMock<MockRecordChangedCallback> callback;
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+ auto packet = CreatePacketWithRecord(nsec_record_created_);
+ receiver_.OnRead(&socket_, std::move(packet));
+ ASSERT_TRUE(ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA));
+ EXPECT_FALSE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kA));
+
+ packet = CreatePacketWithRecord(multi_type_nsec);
+ receiver_.OnRead(&socket_, std::move(packet));
+ EXPECT_FALSE(
+ ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA));
+ EXPECT_TRUE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kA));
+}
+
+TEST_F(MdnsQuerierTest, TestMaxRecordsRespected) {
+ config_.querier_max_records_cached = 1;
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
+
+ // Set up so an A record has been received
+ StrictMock<MockRecordChangedCallback> callback;
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kANY,
+ DnsClass::kIN, &callback);
+ querier->StartQuery(DomainName{"poking", "local"}, DnsType::kANY,
+ DnsClass::kIN, &callback);
+ auto packet = CreatePacketWithRecord(record0_created_);
+ EXPECT_CALL(callback,
+ OnRecordChanged(record0_created_, RecordChangedEvent::kCreated));
+ receiver_.OnRead(&socket_, std::move(packet));
+ ASSERT_EQ(RecordCount(querier.get()), size_t{1});
+ EXPECT_TRUE(ContainsRecord(querier.get(), record0_created_, DnsType::kA));
+ EXPECT_FALSE(ContainsRecord(querier.get(), record1_created_, DnsType::kA));
+ testing::Mock::VerifyAndClearExpectations(&callback);
+
+ EXPECT_CALL(callback,
+ OnRecordChanged(record0_created_, RecordChangedEvent::kExpired));
+ EXPECT_CALL(callback,
+ OnRecordChanged(record1_created_, RecordChangedEvent::kCreated));
+ packet = CreatePacketWithRecord(record1_created_);
+ receiver_.OnRead(&socket_, std::move(packet));
+ ASSERT_EQ(RecordCount(querier.get()), size_t{1});
+ EXPECT_FALSE(ContainsRecord(querier.get(), record0_created_, DnsType::kA));
+ EXPECT_TRUE(ContainsRecord(querier.get(), record1_created_, DnsType::kA));
+}
+
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_random.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_random.h
index b422f998683..bf33f1fdc80 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_random.h
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_random.h
@@ -14,8 +14,6 @@ namespace discovery {
class MdnsRandom {
public:
- using Clock = openscreen::platform::Clock;
-
// RFC 6762 Section 5.2
// https://tools.ietf.org/html/rfc6762#section-5.2
@@ -40,6 +38,11 @@ class MdnsRandom {
truncated_query_response_delay_(random_engine_)};
}
+ Clock::duration GetInitialProbeDelay() {
+ return std::chrono::milliseconds{
+ probe_initialization_delay_(random_engine_)};
+ }
+
private:
static constexpr int64_t kMinimumInitialQueryDelayMs = 20;
static constexpr int64_t kMaximumInitialQueryDelayMs = 120;
@@ -53,6 +56,9 @@ class MdnsRandom {
static constexpr int64_t kMinimumTruncatedQueryResponseDelayMs = 400;
static constexpr int64_t kMaximumTruncatedQueryResponseDelayMs = 500;
+ static constexpr int64_t kMinimumProbeInitializationDelayMs = 0;
+ static constexpr int64_t kMaximumProbeInitializationDelayMs = 250;
+
std::default_random_engine random_engine_{std::random_device{}()};
std::uniform_int_distribution<int64_t> initial_query_delay_{
kMinimumInitialQueryDelayMs, kMaximumInitialQueryDelayMs};
@@ -63,6 +69,8 @@ class MdnsRandom {
std::uniform_int_distribution<int64_t> truncated_query_response_delay_{
kMinimumTruncatedQueryResponseDelayMs,
kMaximumTruncatedQueryResponseDelayMs};
+ std::uniform_int_distribution<int64_t> probe_initialization_delay_{
+ kMinimumProbeInitializationDelayMs, kMaximumProbeInitializationDelayMs};
};
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_random_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_random_unittest.cc
index 1ba58c90bdf..e0ad261154d 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_random_unittest.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_random_unittest.cc
@@ -10,8 +10,6 @@
namespace openscreen {
namespace discovery {
-using openscreen::platform::Clock;
-
namespace {
constexpr int kIterationCount = 100;
}
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.cc
index 21174e742ef..e7de282769a 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.cc
@@ -7,10 +7,25 @@
#include <algorithm>
#include <utility>
+#include "discovery/mdns/public/mdns_constants.h"
#include "util/logging.h"
namespace openscreen {
namespace discovery {
+namespace {
+
+bool TryParseDnsType(uint16_t to_parse, DnsType* type) {
+ auto it = std::find(kSupportedDnsTypes.begin(), kSupportedDnsTypes.end(),
+ static_cast<DnsType>(to_parse));
+ if (it == kSupportedDnsTypes.end()) {
+ return false;
+ }
+
+ *type = *it;
+ return true;
+}
+
+} // namespace
bool MdnsReader::Read(TxtRecordRdata::Entry* out) {
Cursor cursor(this);
@@ -52,7 +67,12 @@ bool MdnsReader::Read(DomainName* out) {
bytes_processed <= length()) {
const uint8_t label_type = ReadBigEndian<uint8_t>(position);
if (IsTerminationLabel(label_type)) {
- *out = DomainName(labels);
+ ErrorOr<DomainName> domain =
+ DomainName::TryCreate(labels.begin(), labels.end());
+ if (domain.is_error()) {
+ return false;
+ }
+ *out = std::move(domain.value());
if (!bytes_consumed) {
bytes_consumed = position + sizeof(uint8_t) - current();
}
@@ -100,7 +120,12 @@ bool MdnsReader::Read(RawRecordRdata* out) {
if (Read(&record_length)) {
std::vector<uint8_t> buffer(record_length);
if (Read(buffer.size(), buffer.data())) {
- *out = RawRecordRdata(std::move(buffer));
+ ErrorOr<RawRecordRdata> rdata =
+ RawRecordRdata::TryCreate(std::move(buffer));
+ if (rdata.is_error()) {
+ return false;
+ }
+ *out = std::move(rdata.value());
cursor.Commit();
return true;
}
@@ -189,11 +214,47 @@ bool MdnsReader::Read(TxtRecordRdata* out) {
if (cursor.delta() != sizeof(record_length) + record_length) {
return false;
}
- *out = TxtRecordRdata(std::move(texts));
+ ErrorOr<TxtRecordRdata> rdata = TxtRecordRdata::TryCreate(std::move(texts));
+ if (rdata.is_error()) {
+ return false;
+ }
+ *out = std::move(rdata.value());
cursor.Commit();
return true;
}
+bool MdnsReader::Read(NsecRecordRdata* out) {
+ OSP_DCHECK(out);
+ Cursor cursor(this);
+
+ const uint8_t* start_position = current();
+ uint16_t record_length;
+ DomainName next_record_name;
+ if (!Read(&record_length) || !Read(&next_record_name)) {
+ return false;
+ }
+
+ // Calculate the next record name length. This may not be equal to the length
+ // of |next_record_name| due to domain name compression.
+ const int encoded_next_name_length =
+ current() - start_position - sizeof(record_length);
+ const int remaining_length = record_length - encoded_next_name_length;
+ if (remaining_length <= 0) {
+ // This means either the length is invalid or the NSEC record has no
+ // associated types.
+ return false;
+ }
+
+ std::vector<DnsType> types;
+ if (Read(&types, remaining_length)) {
+ *out = NsecRecordRdata(std::move(next_record_name), std::move(types));
+ cursor.Commit();
+ return true;
+ }
+
+ return false;
+}
+
bool MdnsReader::Read(MdnsRecord* out) {
OSP_DCHECK(out);
Cursor cursor(this);
@@ -204,9 +265,14 @@ bool MdnsReader::Read(MdnsRecord* out) {
Rdata rdata;
if (Read(&name) && Read(&type) && Read(&rrclass) && Read(&ttl) &&
Read(static_cast<DnsType>(type), &rdata)) {
- *out = MdnsRecord(std::move(name), static_cast<DnsType>(type),
- GetDnsClass(rrclass), GetRecordType(rrclass),
- std::chrono::seconds(ttl), std::move(rdata));
+ ErrorOr<MdnsRecord> record = MdnsRecord::TryCreate(
+ std::move(name), static_cast<DnsType>(type), GetDnsClass(rrclass),
+ GetRecordType(rrclass), std::chrono::seconds(ttl), std::move(rdata));
+ if (record.is_error()) {
+ return false;
+ }
+ *out = std::move(record.value());
+
cursor.Commit();
return true;
}
@@ -220,8 +286,14 @@ bool MdnsReader::Read(MdnsQuestion* out) {
uint16_t type;
uint16_t rrclass;
if (Read(&name) && Read(&type) && Read(&rrclass)) {
- *out = MdnsQuestion(std::move(name), static_cast<DnsType>(type),
- GetDnsClass(rrclass), GetResponseType(rrclass));
+ ErrorOr<MdnsQuestion> question =
+ MdnsQuestion::TryCreate(std::move(name), static_cast<DnsType>(type),
+ GetDnsClass(rrclass), GetResponseType(rrclass));
+ if (question.is_error()) {
+ return false;
+ }
+ *out = std::move(question.value());
+
cursor.Commit();
return true;
}
@@ -244,8 +316,18 @@ bool MdnsReader::Read(MdnsMessage* out) {
// One way to do this is to change the method signature to return
// ErrorOr<MdnsMessage> and return different error codes for failure to read
// and for messages that were read successfully but are non-conforming.
- *out = MdnsMessage(header.id, GetMessageType(header.flags), questions,
- answers, authority_records, additional_records);
+ ErrorOr<MdnsMessage> message = MdnsMessage::TryCreate(
+ header.id, GetMessageType(header.flags), questions, answers,
+ authority_records, additional_records);
+ if (message.is_error()) {
+ return false;
+ }
+ *out = std::move(message.value());
+
+ if (IsMessageTruncated(header.flags)) {
+ out->set_truncated();
+ }
+
cursor.Commit();
return true;
}
@@ -278,7 +360,11 @@ bool MdnsReader::Read(DnsType type, Rdata* out) {
return Read<PtrRecordRdata>(out);
case DnsType::kTXT:
return Read<TxtRecordRdata>(out);
+ case DnsType::kNSEC:
+ return Read<NsecRecordRdata>(out);
default:
+ OSP_DCHECK(std::find(kSupportedDnsTypes.begin(), kSupportedDnsTypes.end(),
+ type) == kSupportedDnsTypes.end());
return Read<RawRecordRdata>(out);
}
}
@@ -295,5 +381,68 @@ bool MdnsReader::Read(Header* out) {
return false;
}
+bool MdnsReader::Read(std::vector<DnsType>* out, int remaining_size) {
+ OSP_DCHECK(out);
+ Cursor cursor(this);
+
+ // Continue reading bitmaps until the entire input is read. If we have gone
+ // past the end of the record, it's malformed input so fail.
+ *out = std::vector<DnsType>();
+ int processed_bytes = 0;
+ while (processed_bytes < remaining_size) {
+ NsecBitMapField bitmap;
+ if (!Read(&bitmap)) {
+ return false;
+ }
+
+ processed_bytes += bitmap.bitmap_length + 2;
+ if (processed_bytes > remaining_size) {
+ return false;
+ }
+
+ // The ith bit of the bitmap represents DnsType with value i, shifted
+ // a multiple of 0x100 according to the window.
+ for (int32_t i = 0; i < bitmap.bitmap_length * 8; i++) {
+ int current_byte = i / 8;
+ uint8_t bitmask = 0x80 >> i % 8;
+
+ // If this bit flag represents a type we support, add it to the vector.
+ // Else, we won't be able to use it later on in the code anyway, so drop
+ // it.
+ DnsType type;
+ uint16_t type_index = i | (bitmap.window_block << 8);
+ if ((bitmap.bitmap[current_byte] & bitmask) &&
+ TryParseDnsType(type_index, &type)) {
+ out->push_back(type);
+ }
+ }
+ }
+
+ cursor.Commit();
+ return true;
+}
+
+bool MdnsReader::Read(NsecBitMapField* out) {
+ OSP_DCHECK(out);
+ Cursor cursor(this);
+
+ // Read the window and bitmap length, then one byte for each byte called out
+ // by the length.
+ if (Read(&out->window_block) && Read(&out->bitmap_length)) {
+ if (out->bitmap_length == 0 || out->bitmap_length > 32) {
+ return false;
+ }
+
+ out->bitmap = current();
+ if (!Skip(out->bitmap_length)) {
+ return false;
+ }
+ cursor.Commit();
+ return true;
+ }
+
+ return false;
+}
+
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.h
index defcf33abc6..ecf2aafacad 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.h
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.h
@@ -30,6 +30,7 @@ class MdnsReader : public BigEndianReader {
bool Read(AAAARecordRdata* out);
bool Read(PtrRecordRdata* out);
bool Read(TxtRecordRdata* out);
+ bool Read(NsecRecordRdata* out);
// Reads a DNS resource record with its RDATA.
// The correct type of RDATA to be read is determined by the type
// specified in the record.
@@ -40,9 +41,17 @@ class MdnsReader : public BigEndianReader {
bool Read(MdnsMessage* out);
private:
+ struct NsecBitMapField {
+ uint8_t window_block;
+ uint8_t bitmap_length;
+ const uint8_t* bitmap;
+ };
+
bool Read(IPAddress::Version version, IPAddress* out);
bool Read(DnsType type, Rdata* out);
bool Read(Header* out);
+ bool Read(std::vector<DnsType>* types, int remaining_length);
+ bool Read(NsecBitMapField* out);
template <class ItemType>
bool Read(uint16_t count, std::vector<ItemType>* out) {
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader_fuzztest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader_fuzztest.cc
new file mode 100644
index 00000000000..d2e2eb72cfb
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader_fuzztest.cc
@@ -0,0 +1,12 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "discovery/mdns/mdns_reader.h"
+
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
+ openscreen::discovery::MdnsReader reader(data, size);
+ openscreen::discovery::MdnsMessage message;
+ reader.Read(&message);
+ return 0;
+}
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader_unittest.cc
index 8f4259df45c..a802162d3f2 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader_unittest.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader_unittest.cc
@@ -344,25 +344,64 @@ TEST(MdnsReaderTest, ReadTxtRecordRdata_EmptyEntries) {
MakeTxtRecord({"foo=1", "bar=2"}));
}
-TEST(MdnsReaderTest, ReadTxtRecordRdata_TooShort) {
+TEST(MdnsReaderTest, ReadNsecRecordRdata) {
// clang-format off
- constexpr uint8_t kTxtRecordRdata[] = {
- 0x00, 0x0C, // RDLENGTH = 12
- 0x05, 'f', 'o', 'o', '=', '1',
+ constexpr uint8_t kExpectedRdata[] = {
+ 0x00, 0x20, // RDLENGTH = 32
+ 0x08, 'm', 'y', 'd', 'e', 'v', 'i', 'c', 'e',
+ 0x07, 't', 'e', 's', 't', 'i', 'n', 'g',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ // It takes 8 bytes to encode the kA and kSRV records because:
+ // - Both record types have value less than 256, so they are both in window
+ // block 1.
+ // - The bitmap length for this block is always a single byte
+ // - DnsTypes have the following values:
+ // - kA = 1 (encoded in byte 1)
+ // kTXT = 16 (encoded in byte 3)
+ // - kSRV = 33 (encoded in byte 5)
+ // - kNSEC = 47 (encoded in 6 bytes)
+ // - The largest of these is 47, so 6 bytes are needed to encode this data.
+ // So the full encoded version is:
+ // 00000000 00000110 01000000 00000000 10000000 00000000 0100000 00000001
+ // |window| | size | | 0-7 | | 8-15 | |16-23 | |24-31 | |32-39 | |40-47 |
+ 0x00, 0x06, 0x40, 0x00, 0x80, 0x00, 0x40, 0x01
};
// clang-format on
- TestReadEntryFails<TxtRecordRdata>(kTxtRecordRdata, sizeof(kTxtRecordRdata));
+ TestReadEntrySucceeds(
+ kExpectedRdata, sizeof(kExpectedRdata),
+ NsecRecordRdata(DomainName{"mydevice", "testing", "local"}, DnsType::kA,
+ DnsType::kTXT, DnsType::kSRV, DnsType::kNSEC));
}
-TEST(MdnsReaderTest, ReadTxtRecordRdata_WrongLength) {
+TEST(MdnsReaderTest, ReadNsecRecordRdata_TooShort) {
// clang-format off
- constexpr uint8_t kTxtRecordRdata[] = {
- 0x00, 0x0F, // Wrong length specified
- 0x05, 'f', 'o', 'o', '=', '1',
- 0x05, 'b', 'a', 'r', '=', '2',
+ constexpr uint8_t kNsecRecordRdata[] = {
+ 0x00, 0x20, // RDLENGTH = 32
+ 0x08, 'm', 'y', 'd', 'e', 'v', 'i', 'c', 'e',
+ 0x07, 't', 'e', 's', 't', 'i', 'n', 'g',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x06, 0x40, 0x00
+ };
+ // clang-format on
+ TestReadEntryFails<NsecRecordRdata>(kNsecRecordRdata,
+ sizeof(kNsecRecordRdata));
+}
+
+TEST(MdnsReaderTest, ReadNsecRecordRdata_WrongLength) {
+ // clang-format off
+ constexpr uint8_t kNsecRecordRdata[] = {
+ 0x00, 0x21, // RDLENGTH = 33
+ 0x08, 'm', 'y', 'd', 'e', 'v', 'i', 'c', 'e',
+ 0x07, 't', 'e', 's', 't', 'i', 'n', 'g',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x06, 0x40, 0x00, 0x80, 0x00, 0x40, 0x01
};
// clang-format on
- TestReadEntryFails<TxtRecordRdata>(kTxtRecordRdata, sizeof(kTxtRecordRdata));
+ TestReadEntryFails<NsecRecordRdata>(kNsecRecordRdata,
+ sizeof(kNsecRecordRdata));
}
TEST(MdnsReaderTest, ReadMdnsRecord_ARecordRdata) {
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.cc
index 9245b8c10a3..8dc9fdd1fa8 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.cc
@@ -7,19 +7,19 @@
#include "discovery/mdns/mdns_reader.h"
#include "util/trace_logging.h"
-using openscreen::platform::TraceCategory;
-
namespace openscreen {
namespace discovery {
-MdnsReceiver::MdnsReceiver(UdpSocket* socket) : socket_(socket) {
- OSP_DCHECK(socket_);
-}
+MdnsReceiver::ResponseClient::~ResponseClient() = default;
+
+MdnsReceiver::MdnsReceiver() = default;
MdnsReceiver::~MdnsReceiver() {
if (state_ == State::kRunning) {
Stop();
}
+
+ OSP_DCHECK(response_clients_.empty());
}
void MdnsReceiver::SetQueryCallback(
@@ -30,13 +30,20 @@ void MdnsReceiver::SetQueryCallback(
query_callback_ = callback;
}
-void MdnsReceiver::SetResponseCallback(
- std::function<void(const MdnsMessage&)> callback) {
- // This check verifies that either new or stored callback has a target. It
- // will fail in case multiple objects try to set or clear the callback.
- OSP_DCHECK(static_cast<bool>(response_callback_) !=
- static_cast<bool>(callback));
- response_callback_ = callback;
+void MdnsReceiver::AddResponseCallback(ResponseClient* callback) {
+ auto it =
+ std::find(response_clients_.begin(), response_clients_.end(), callback);
+ OSP_DCHECK(it == response_clients_.end());
+
+ response_clients_.push_back(callback);
+}
+
+void MdnsReceiver::RemoveResponseCallback(ResponseClient* callback) {
+ auto it =
+ std::find(response_clients_.begin(), response_clients_.end(), callback);
+ OSP_DCHECK(it != response_clients_.end());
+
+ response_clients_.erase(it);
}
void MdnsReceiver::Start() {
@@ -55,7 +62,7 @@ void MdnsReceiver::OnRead(UdpSocket* socket,
UdpPacket packet = std::move(packet_or_error.value());
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsReceiver::OnRead");
+ TRACE_SCOPED(TraceCategory::kMdns, "MdnsReceiver::OnRead");
MdnsReader reader(packet.data(), packet.size());
MdnsMessage message;
if (!reader.Read(&message)) {
@@ -63,25 +70,20 @@ void MdnsReceiver::OnRead(UdpSocket* socket,
}
if (message.type() == MessageType::Response) {
- if (response_callback_) {
- response_callback_(message);
+ for (ResponseClient* client : response_clients_) {
+ client->OnMessageReceived(message);
+ }
+ if (response_clients_.empty()) {
+ OSP_DVLOG << "Response message dropped. No response client registered...";
}
} else {
if (query_callback_) {
query_callback_(message, packet.source());
+ } else {
+ OSP_DVLOG << "Query message dropped. No query client registered...";
}
}
}
-void MdnsReceiver::OnError(UdpSocket* socket, Error error) {
- // This method should never be called for MdnsReciever.
- OSP_UNIMPLEMENTED();
-}
-
-void MdnsReceiver::OnSendError(UdpSocket* socket, Error error) {
- // This method should never be called for MdnsReciever.
- OSP_UNIMPLEMENTED();
-}
-
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.h
index ca02543e35b..64e1a93da6a 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.h
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.h
@@ -5,6 +5,8 @@
#ifndef DISCOVERY_MDNS_MDNS_RECEIVER_H_
#define DISCOVERY_MDNS_MDNS_RECEIVER_H_
+#include <functional>
+
#include "platform/api/udp_socket.h"
#include "platform/base/error.h"
#include "platform/base/udp_packet.h"
@@ -14,24 +16,29 @@ namespace discovery {
class MdnsMessage;
-class MdnsReceiver : openscreen::platform::UdpSocket::Client {
+class MdnsReceiver {
public:
- using UdpPacket = openscreen::platform::UdpPacket;
- using UdpSocket = openscreen::platform::UdpSocket;
+ class ResponseClient {
+ public:
+ virtual ~ResponseClient();
+
+ virtual void OnMessageReceived(const MdnsMessage& message) = 0;
+ };
// MdnsReceiver does not own |socket| and |delegate|
// and expects that the lifetime of these objects exceeds the lifetime of
// MdnsReceiver.
- explicit MdnsReceiver(UdpSocket* socket);
+ MdnsReceiver();
MdnsReceiver(const MdnsReceiver& other) = delete;
MdnsReceiver(MdnsReceiver&& other) noexcept = delete;
MdnsReceiver& operator=(const MdnsReceiver& other) = delete;
MdnsReceiver& operator=(MdnsReceiver&& other) noexcept = delete;
- ~MdnsReceiver() override;
+ ~MdnsReceiver();
void SetQueryCallback(
std::function<void(const MdnsMessage&, const IPEndpoint& src)> callback);
- void SetResponseCallback(std::function<void(const MdnsMessage&)> callback);
+ void AddResponseCallback(ResponseClient* callback);
+ void RemoveResponseCallback(ResponseClient* callback);
// The receiver can be started and stopped multiple times.
// Start and Stop are both synchronous calls. When MdnsReceiver has not yet
@@ -40,10 +47,7 @@ class MdnsReceiver : openscreen::platform::UdpSocket::Client {
void Start();
void Stop();
- // UdpSocket::Client overrides.
- void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override;
- void OnError(UdpSocket* socket, Error error) override;
- void OnSendError(UdpSocket* socket, Error error) override;
+ void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet);
private:
enum class State {
@@ -51,11 +55,11 @@ class MdnsReceiver : openscreen::platform::UdpSocket::Client {
kRunning,
};
- UdpSocket* const socket_;
std::function<void(const MdnsMessage&, const IPEndpoint& src)>
query_callback_;
- std::function<void(const MdnsMessage&)> response_callback_;
State state_ = State::kStopped;
+
+ std::vector<ResponseClient*> response_clients_;
};
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver_unittest.cc
index 5ef2162584b..df7fae9e80d 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver_unittest.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver_unittest.cc
@@ -4,6 +4,10 @@
#include "discovery/mdns/mdns_receiver.h"
+#include <memory>
+#include <utility>
+#include <vector>
+
#include "discovery/mdns/mdns_records.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -13,13 +17,10 @@
namespace openscreen {
namespace discovery {
-using openscreen::platform::FakeUdpSocket;
-using openscreen::platform::TaskRunner;
-using openscreen::platform::UdpPacket;
using testing::_;
using testing::Return;
-class MockMdnsReceiverDelegate {
+class MockMdnsReceiverDelegate : public MdnsReceiver::ResponseClient {
public:
MOCK_METHOD(void, OnMessageReceived, (const MdnsMessage&));
};
@@ -42,10 +43,9 @@ TEST(MdnsReceiverTest, ReceiveQuery) {
};
// clang-format on
- std::unique_ptr<FakeUdpSocket> socket_info =
- FakeUdpSocket::CreateDefault(IPAddress::Version::kV4);
+ FakeUdpSocket socket;
MockMdnsReceiverDelegate delegate;
- MdnsReceiver receiver(socket_info.get());
+ MdnsReceiver receiver;
receiver.SetQueryCallback(
[&delegate](const MdnsMessage& message, const IPEndpoint& endpoint) {
delegate.OnMessageReceived(message);
@@ -66,7 +66,7 @@ TEST(MdnsReceiverTest, ReceiveQuery) {
// Imitate a call to OnRead from NetworkRunner by calling it manually here
EXPECT_CALL(delegate, OnMessageReceived(message)).Times(1);
- receiver.OnRead(socket_info.get(), std::move(packet));
+ receiver.OnRead(&socket, std::move(packet));
receiver.Stop();
}
@@ -91,19 +91,16 @@ TEST(MdnsReceiverTest, ReceiveResponse) {
0xac, 0x00, 0x00, 0x01, // 172.0.0.1
};
- constexpr uint8_t kIPv6AddressBytes[] = {
- 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
- 0x02, 0x02, 0xb3, 0xff, 0xfe, 0x1e, 0x83, 0x29,
+ constexpr uint16_t kIPv6AddressHextets[] = {
+ 0xfe80, 0x0000, 0x0000, 0x0000,
+ 0x0202, 0xb3ff, 0xfe1e, 0x8329,
};
// clang-format on
- std::unique_ptr<FakeUdpSocket> socket_info =
- FakeUdpSocket::CreateDefault(IPAddress::Version::kV6);
+ FakeUdpSocket socket;
MockMdnsReceiverDelegate delegate;
- MdnsReceiver receiver(socket_info.get());
- receiver.SetResponseCallback([&delegate](const MdnsMessage& message) {
- delegate.OnMessageReceived(message);
- });
+ MdnsReceiver receiver;
+ receiver.AddResponseCallback(&delegate);
receiver.Start();
MdnsRecord record(DomainName{"testing", "local"}, DnsType::kA, DnsClass::kIN,
@@ -116,16 +113,17 @@ TEST(MdnsReceiverTest, ReceiveResponse) {
packet.assign(kResponseBytes.data(),
kResponseBytes.data() + kResponseBytes.size());
packet.set_source(
- IPEndpoint{.address = IPAddress(kIPv6AddressBytes), .port = 31337});
+ IPEndpoint{.address = IPAddress(kIPv6AddressHextets), .port = 31337});
packet.set_destination(
IPEndpoint{.address = IPAddress(kDefaultMulticastGroupIPv6),
.port = kDefaultMulticastPort});
// Imitate a call to OnRead from NetworkRunner by calling it manually here
EXPECT_CALL(delegate, OnMessageReceived(message)).Times(1);
- receiver.OnRead(socket_info.get(), std::move(packet));
+ receiver.OnRead(&socket, std::move(packet));
receiver.Stop();
+ receiver.RemoveResponseCallback(&delegate);
}
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_record_changed_callback.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_record_changed_callback.h
index d5d340b9efa..c8c02d69797 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_record_changed_callback.h
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_record_changed_callback.h
@@ -5,6 +5,8 @@
#ifndef DISCOVERY_MDNS_MDNS_RECORD_CHANGED_CALLBACK_H_
#define DISCOVERY_MDNS_MDNS_RECORD_CHANGED_CALLBACK_H_
+#include "util/logging.h"
+
namespace openscreen {
namespace discovery {
@@ -23,6 +25,21 @@ class MdnsRecordChangedCallback {
RecordChangedEvent event) = 0;
};
+inline std::ostream& operator<<(std::ostream& output,
+ RecordChangedEvent event) {
+ switch (event) {
+ case RecordChangedEvent::kCreated:
+ return output << "Create";
+ case RecordChangedEvent::kUpdated:
+ return output << "Update";
+ case RecordChangedEvent::kExpired:
+ return output << "Expiry";
+ }
+
+ OSP_NOTREACHED();
+ return output;
+}
+
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.cc
index 69f76f24e8c..a04c2694165 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.cc
@@ -4,18 +4,23 @@
#include "discovery/mdns/mdns_records.h"
-#include <atomic>
#include <cctype>
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
+#include "discovery/mdns/mdns_writer.h"
namespace openscreen {
namespace discovery {
namespace {
+constexpr size_t kMaxRawRecordSize = std::numeric_limits<uint16_t>::max();
+
+constexpr size_t kMaxMessageFieldEntryCount =
+ std::numeric_limits<uint16_t>::max();
+
inline int CompareIgnoreCase(const std::string& x, const std::string& y) {
size_t i = 0;
for (; i < x.size(); i++) {
@@ -33,6 +38,54 @@ inline int CompareIgnoreCase(const std::string& x, const std::string& y) {
return i == y.size() ? 0 : -1;
}
+template <typename RDataType>
+bool IsGreaterThan(const Rdata& lhs, const Rdata& rhs) {
+ const RDataType& lhs_cast = absl::get<RDataType>(lhs);
+ const RDataType& rhs_cast = absl::get<RDataType>(rhs);
+
+ size_t lhs_size = lhs_cast.MaxWireSize();
+ size_t rhs_size = rhs_cast.MaxWireSize();
+ size_t min_size = std::min(lhs_size, rhs_size);
+
+ uint8_t lhs_bytes[lhs_size];
+ uint8_t rhs_bytes[rhs_size];
+ MdnsWriter lhs_writer(lhs_bytes, lhs_size);
+ MdnsWriter rhs_writer(rhs_bytes, rhs_size);
+
+ lhs_writer.Write(lhs_cast);
+ rhs_writer.Write(rhs_cast);
+ for (size_t i = 0; i < min_size; i++) {
+ if (lhs_bytes[i] != rhs_bytes[i]) {
+ return lhs_bytes[i] > rhs_bytes[i];
+ }
+ }
+
+ if (lhs_size == rhs_size) {
+ return false;
+ }
+
+ return lhs_size > rhs_size;
+}
+
+bool IsGreaterThan(DnsType type, const Rdata& lhs, const Rdata& rhs) {
+ switch (type) {
+ case DnsType::kA:
+ return IsGreaterThan<ARecordRdata>(lhs, rhs);
+ case DnsType::kPTR:
+ return IsGreaterThan<PtrRecordRdata>(lhs, rhs);
+ case DnsType::kTXT:
+ return IsGreaterThan<TxtRecordRdata>(lhs, rhs);
+ case DnsType::kAAAA:
+ return IsGreaterThan<AAAARecordRdata>(lhs, rhs);
+ case DnsType::kSRV:
+ return IsGreaterThan<SrvRecordRdata>(lhs, rhs);
+ case DnsType::kNSEC:
+ return IsGreaterThan<NsecRecordRdata>(lhs, rhs);
+ default:
+ return IsGreaterThan<RawRecordRdata>(lhs, rhs);
+ }
+}
+
} // namespace
bool IsValidDomainLabel(absl::string_view label) {
@@ -51,6 +104,9 @@ DomainName::DomainName(const std::vector<absl::string_view>& labels)
DomainName::DomainName(std::initializer_list<absl::string_view> labels)
: DomainName(labels.begin(), labels.end()) {}
+DomainName::DomainName(std::vector<std::string> labels, size_t max_wire_size)
+ : max_wire_size_(max_wire_size), labels_(std::move(labels)) {}
+
DomainName::DomainName(const DomainName& other) = default;
DomainName::DomainName(DomainName&& other) = default;
@@ -112,12 +168,21 @@ size_t DomainName::MaxWireSize() const {
return max_wire_size_;
}
+// static
+ErrorOr<RawRecordRdata> RawRecordRdata::TryCreate(std::vector<uint8_t> rdata) {
+ if (rdata.size() > kMaxRawRecordSize) {
+ return Error::Code::kIndexOutOfBounds;
+ } else {
+ return RawRecordRdata(std::move(rdata));
+ }
+}
+
RawRecordRdata::RawRecordRdata() = default;
RawRecordRdata::RawRecordRdata(std::vector<uint8_t> rdata)
: rdata_(std::move(rdata)) {
// Ensure RDATA length does not exceed the maximum allowed.
- OSP_DCHECK(rdata_.size() <= std::numeric_limits<uint16_t>::max());
+ OSP_DCHECK(rdata_.size() <= kMaxRawRecordSize);
}
RawRecordRdata::RawRecordRdata(const uint8_t* begin, size_t size)
@@ -180,8 +245,10 @@ size_t SrvRecordRdata::MaxWireSize() const {
ARecordRdata::ARecordRdata() = default;
-ARecordRdata::ARecordRdata(IPAddress ipv4_address)
- : ipv4_address_(std::move(ipv4_address)) {
+ARecordRdata::ARecordRdata(IPAddress ipv4_address,
+ NetworkInterfaceIndex interface_index)
+ : ipv4_address_(std::move(ipv4_address)),
+ interface_index_(interface_index) {
OSP_CHECK(ipv4_address_.IsV4());
}
@@ -194,7 +261,8 @@ ARecordRdata& ARecordRdata::operator=(const ARecordRdata& rhs) = default;
ARecordRdata& ARecordRdata::operator=(ARecordRdata&& rhs) = default;
bool ARecordRdata::operator==(const ARecordRdata& rhs) const {
- return ipv4_address_ == rhs.ipv4_address_;
+ return ipv4_address_ == rhs.ipv4_address_ &&
+ interface_index_ == rhs.interface_index_;
}
bool ARecordRdata::operator!=(const ARecordRdata& rhs) const {
@@ -208,8 +276,10 @@ size_t ARecordRdata::MaxWireSize() const {
AAAARecordRdata::AAAARecordRdata() = default;
-AAAARecordRdata::AAAARecordRdata(IPAddress ipv6_address)
- : ipv6_address_(std::move(ipv6_address)) {
+AAAARecordRdata::AAAARecordRdata(IPAddress ipv6_address,
+ NetworkInterfaceIndex interface_index)
+ : ipv6_address_(std::move(ipv6_address)),
+ interface_index_(interface_index) {
OSP_CHECK(ipv6_address_.IsV6());
}
@@ -223,7 +293,8 @@ AAAARecordRdata& AAAARecordRdata::operator=(const AAAARecordRdata& rhs) =
AAAARecordRdata& AAAARecordRdata::operator=(AAAARecordRdata&& rhs) = default;
bool AAAARecordRdata::operator==(const AAAARecordRdata& rhs) const {
- return ipv6_address_ == rhs.ipv6_address_;
+ return ipv6_address_ == rhs.ipv6_address_ &&
+ interface_index_ == rhs.interface_index_;
}
bool AAAARecordRdata::operator!=(const AAAARecordRdata& rhs) const {
@@ -261,23 +332,39 @@ size_t PtrRecordRdata::MaxWireSize() const {
return sizeof(uint16_t) + ptr_domain_.MaxWireSize();
}
-TxtRecordRdata::TxtRecordRdata() = default;
-
-TxtRecordRdata::TxtRecordRdata(std::vector<Entry> texts) {
+// static
+ErrorOr<TxtRecordRdata> TxtRecordRdata::TryCreate(std::vector<Entry> texts) {
+ std::vector<std::string> str_texts;
+ size_t max_wire_size = 3;
if (texts.size() > 0) {
- texts_.reserve(texts.size());
+ str_texts.reserve(texts.size());
// max_wire_size includes uint16_t record length field.
- max_wire_size_ = sizeof(uint16_t);
+ max_wire_size = sizeof(uint16_t);
for (const auto& text : texts) {
- OSP_DCHECK(!text.empty());
- texts_.push_back(
+ if (text.empty()) {
+ return Error::Code::kParameterInvalid;
+ }
+ str_texts.push_back(
std::string(reinterpret_cast<const char*>(text.data()), text.size()));
// Include the length byte in the size calculation.
- max_wire_size_ += text.size() + 1;
+ max_wire_size += text.size() + 1;
}
}
+ return TxtRecordRdata(std::move(str_texts), max_wire_size);
}
+TxtRecordRdata::TxtRecordRdata() = default;
+
+TxtRecordRdata::TxtRecordRdata(std::vector<Entry> texts) {
+ ErrorOr<TxtRecordRdata> rdata = TxtRecordRdata::TryCreate(std::move(texts));
+ OSP_DCHECK(rdata.is_value());
+ *this = std::move(rdata.value());
+}
+
+TxtRecordRdata::TxtRecordRdata(std::vector<std::string> texts,
+ size_t max_wire_size)
+ : max_wire_size_(max_wire_size), texts_(std::move(texts)) {}
+
TxtRecordRdata::TxtRecordRdata(const TxtRecordRdata& other) = default;
TxtRecordRdata::TxtRecordRdata(TxtRecordRdata&& other) = default;
@@ -298,6 +385,91 @@ size_t TxtRecordRdata::MaxWireSize() const {
return max_wire_size_;
}
+NsecRecordRdata::NsecRecordRdata() = default;
+
+NsecRecordRdata::NsecRecordRdata(DomainName next_domain_name,
+ std::vector<DnsType> types)
+ : types_(std::move(types)), next_domain_name_(std::move(next_domain_name)) {
+ // Sort the types_ array for easier comparison later.
+ std::sort(types_.begin(), types_.end());
+
+ // Calculate the bitmaps as described in RFC 4034 Section 4.1.2.
+ std::vector<uint8_t> block_contents;
+ uint8_t current_block = 0;
+ for (auto type : types_) {
+ const uint16_t type_int = static_cast<uint16_t>(type);
+ const uint8_t block = static_cast<uint8_t>(type_int >> 8);
+ const uint8_t block_position = static_cast<uint8_t>(type_int & 0xFF);
+ const uint8_t byte_bit_is_at = block_position >> 3; // First 5 bits.
+ const uint8_t byte_mask = 0x80 >> (block_position & 0x07); // Last 3 bits.
+
+ // If the block has changed, write the previous block's info and all of its
+ // contents to the |encoded_types_| vector.
+ if (block > current_block) {
+ if (!block_contents.empty()) {
+ encoded_types_.push_back(current_block);
+ encoded_types_.push_back(static_cast<uint8_t>(block_contents.size()));
+ encoded_types_.insert(encoded_types_.end(), block_contents.begin(),
+ block_contents.end());
+ }
+ block_contents = std::vector<uint8_t>();
+ current_block = block;
+ }
+
+ // Make sure |block_contents| is large enough to hold the bit representing
+ // the new type , then set it.
+ if (block_contents.size() <= byte_bit_is_at) {
+ block_contents.insert(block_contents.end(),
+ byte_bit_is_at - block_contents.size() + 1, 0x00);
+ }
+
+ block_contents[byte_bit_is_at] |= byte_mask;
+ }
+
+ if (!block_contents.empty()) {
+ encoded_types_.push_back(current_block);
+ encoded_types_.push_back(static_cast<uint8_t>(block_contents.size()));
+ encoded_types_.insert(encoded_types_.end(), block_contents.begin(),
+ block_contents.end());
+ }
+}
+
+NsecRecordRdata::NsecRecordRdata(const NsecRecordRdata& other) = default;
+
+NsecRecordRdata::NsecRecordRdata(NsecRecordRdata&& other) = default;
+
+NsecRecordRdata& NsecRecordRdata::operator=(const NsecRecordRdata& rhs) =
+ default;
+
+NsecRecordRdata& NsecRecordRdata::operator=(NsecRecordRdata&& rhs) = default;
+
+bool NsecRecordRdata::operator==(const NsecRecordRdata& rhs) const {
+ return types_ == rhs.types_ && next_domain_name_ == rhs.next_domain_name_;
+}
+
+bool NsecRecordRdata::operator!=(const NsecRecordRdata& rhs) const {
+ return !(*this == rhs);
+}
+
+size_t NsecRecordRdata::MaxWireSize() const {
+ return next_domain_name_.MaxWireSize() + encoded_types_.size();
+}
+
+// static
+ErrorOr<MdnsRecord> MdnsRecord::TryCreate(DomainName name,
+ DnsType dns_type,
+ DnsClass dns_class,
+ RecordType record_type,
+ std::chrono::seconds ttl,
+ Rdata rdata) {
+ if (!IsValidConfig(name, dns_type, ttl, rdata)) {
+ return Error::Code::kParameterInvalid;
+ } else {
+ return MdnsRecord(std::move(name), dns_type, dns_class, record_type, ttl,
+ std::move(rdata));
+ }
+}
+
MdnsRecord::MdnsRecord() = default;
MdnsRecord::MdnsRecord(DomainName name,
@@ -312,19 +484,7 @@ MdnsRecord::MdnsRecord(DomainName name,
record_type_(record_type),
ttl_(ttl),
rdata_(std::move(rdata)) {
- OSP_DCHECK(!name_.empty());
- OSP_DCHECK_LE(ttl_.count(), std::numeric_limits<uint32_t>::max());
- OSP_DCHECK((dns_type == DnsType::kSRV &&
- absl::holds_alternative<SrvRecordRdata>(rdata_)) ||
- (dns_type == DnsType::kA &&
- absl::holds_alternative<ARecordRdata>(rdata_)) ||
- (dns_type == DnsType::kAAAA &&
- absl::holds_alternative<AAAARecordRdata>(rdata_)) ||
- (dns_type == DnsType::kPTR &&
- absl::holds_alternative<PtrRecordRdata>(rdata_)) ||
- (dns_type == DnsType::kTXT &&
- absl::holds_alternative<TxtRecordRdata>(rdata_)) ||
- absl::holds_alternative<RawRecordRdata>(rdata_));
+ OSP_DCHECK(IsValidConfig(name_, dns_type, ttl_, rdata_));
}
MdnsRecord::MdnsRecord(const MdnsRecord& other) = default;
@@ -335,6 +495,27 @@ MdnsRecord& MdnsRecord::operator=(const MdnsRecord& rhs) = default;
MdnsRecord& MdnsRecord::operator=(MdnsRecord&& rhs) = default;
+// static
+bool MdnsRecord::IsValidConfig(const DomainName& name,
+ DnsType dns_type,
+ std::chrono::seconds ttl,
+ const Rdata& rdata) {
+ return !name.empty() && ttl.count() <= std::numeric_limits<uint32_t>::max() &&
+ ((dns_type == DnsType::kSRV &&
+ absl::holds_alternative<SrvRecordRdata>(rdata)) ||
+ (dns_type == DnsType::kA &&
+ absl::holds_alternative<ARecordRdata>(rdata)) ||
+ (dns_type == DnsType::kAAAA &&
+ absl::holds_alternative<AAAARecordRdata>(rdata)) ||
+ (dns_type == DnsType::kPTR &&
+ absl::holds_alternative<PtrRecordRdata>(rdata)) ||
+ (dns_type == DnsType::kTXT &&
+ absl::holds_alternative<TxtRecordRdata>(rdata)) ||
+ (dns_type == DnsType::kNSEC &&
+ absl::holds_alternative<NsecRecordRdata>(rdata)) ||
+ absl::holds_alternative<RawRecordRdata>(rdata));
+}
+
bool MdnsRecord::operator==(const MdnsRecord& rhs) const {
return dns_type_ == rhs.dns_type_ && dns_class_ == rhs.dns_class_ &&
record_type_ == rhs.record_type_ && ttl_ == rhs.ttl_ &&
@@ -345,12 +526,72 @@ bool MdnsRecord::operator!=(const MdnsRecord& rhs) const {
return !(*this == rhs);
}
+bool MdnsRecord::operator>(const MdnsRecord& rhs) const {
+ // Returns the record which is lexicographically later. The determination of
+ // "lexicographically later" is performed by first comparing the record class,
+ // then the record type, then raw comparison of the binary content of the
+ // rdata without regard for meaning or structure.
+ if (dns_class() != rhs.dns_class()) {
+ return dns_class() > rhs.dns_class();
+ }
+
+ uint16_t this_type = static_cast<uint16_t>(dns_type()) & kClassMask;
+ uint16_t other_type = static_cast<uint16_t>(rhs.dns_type()) & kClassMask;
+ if (this_type != other_type) {
+ return this_type > other_type;
+ }
+
+ return IsGreaterThan(dns_type(), rdata(), rhs.rdata());
+}
+
+bool MdnsRecord::operator<(const MdnsRecord& rhs) const {
+ return rhs > *this;
+}
+
+bool MdnsRecord::operator<=(const MdnsRecord& rhs) const {
+ return !(*this > rhs);
+}
+
+bool MdnsRecord::operator>=(const MdnsRecord& rhs) const {
+ return !(*this < rhs);
+}
+
size_t MdnsRecord::MaxWireSize() const {
auto wire_size_visitor = [](auto&& arg) { return arg.MaxWireSize(); };
// NAME size, 2-byte TYPE, 2-byte CLASS, 4-byte TTL, RDATA size
return name_.MaxWireSize() + absl::visit(wire_size_visitor, rdata_) + 8;
}
+MdnsRecord CreateAddressRecord(DomainName name, const IPAddress& address) {
+ Rdata rdata;
+ DnsType type;
+ std::chrono::seconds ttl;
+ if (address.IsV4()) {
+ type = DnsType::kA;
+ rdata = ARecordRdata(address);
+ ttl = kARecordTtl;
+ } else {
+ type = DnsType::kAAAA;
+ rdata = AAAARecordRdata(address);
+ ttl = kAAAARecordTtl;
+ }
+
+ return MdnsRecord(std::move(name), type, DnsClass::kIN, RecordType::kUnique,
+ ttl, std::move(rdata));
+}
+
+// static
+ErrorOr<MdnsQuestion> MdnsQuestion::TryCreate(DomainName name,
+ DnsType dns_type,
+ DnsClass dns_class,
+ ResponseType response_type) {
+ if (name.empty()) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ return MdnsQuestion(std::move(name), dns_type, dns_class, response_type);
+}
+
MdnsQuestion::MdnsQuestion(DomainName name,
DnsType dns_type,
DnsClass dns_class,
@@ -376,6 +617,26 @@ size_t MdnsQuestion::MaxWireSize() const {
return name_.MaxWireSize() + 4;
}
+// static
+ErrorOr<MdnsMessage> MdnsMessage::TryCreate(
+ uint16_t id,
+ MessageType type,
+ std::vector<MdnsQuestion> questions,
+ std::vector<MdnsRecord> answers,
+ std::vector<MdnsRecord> authority_records,
+ std::vector<MdnsRecord> additional_records) {
+ if (questions.size() >= kMaxMessageFieldEntryCount ||
+ answers.size() >= kMaxMessageFieldEntryCount ||
+ authority_records.size() >= kMaxMessageFieldEntryCount ||
+ additional_records.size() >= kMaxMessageFieldEntryCount) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ return MdnsMessage(id, type, std::move(questions), std::move(answers),
+ std::move(authority_records),
+ std::move(additional_records));
+}
+
MdnsMessage::MdnsMessage(uint16_t id, MessageType type)
: id_(id), type_(type) {}
@@ -391,10 +652,10 @@ MdnsMessage::MdnsMessage(uint16_t id,
answers_(std::move(answers)),
authority_records_(std::move(authority_records)),
additional_records_(std::move(additional_records)) {
- OSP_DCHECK(questions_.size() < std::numeric_limits<uint16_t>::max());
- OSP_DCHECK(answers_.size() < std::numeric_limits<uint16_t>::max());
- OSP_DCHECK(authority_records_.size() < std::numeric_limits<uint16_t>::max());
- OSP_DCHECK(additional_records_.size() < std::numeric_limits<uint16_t>::max());
+ OSP_DCHECK(questions_.size() < kMaxMessageFieldEntryCount);
+ OSP_DCHECK(answers_.size() < kMaxMessageFieldEntryCount);
+ OSP_DCHECK(authority_records_.size() < kMaxMessageFieldEntryCount);
+ OSP_DCHECK(additional_records_.size() < kMaxMessageFieldEntryCount);
for (const MdnsQuestion& question : questions_) {
max_wire_size_ += question.MaxWireSize();
@@ -421,36 +682,62 @@ bool MdnsMessage::operator!=(const MdnsMessage& rhs) const {
return !(*this == rhs);
}
+bool MdnsMessage::IsProbeQuery() const {
+ // A message is a probe query if it contains records in the authority section
+ // which answer the question being asked.
+ if (questions().empty() || authority_records().empty()) {
+ return false;
+ }
+
+ for (const MdnsQuestion& question : questions_) {
+ for (const MdnsRecord& record : authority_records_) {
+ if (question.name() == record.name() &&
+ ((question.dns_type() == record.dns_type()) ||
+ (question.dns_type() == DnsType::kANY)) &&
+ ((question.dns_class() == record.dns_class()) ||
+ (question.dns_class() == DnsClass::kANY))) {
+ return true;
+ }
+ }
+ }
+
+ return false;
+}
+
size_t MdnsMessage::MaxWireSize() const {
return max_wire_size_;
}
void MdnsMessage::AddQuestion(MdnsQuestion question) {
- OSP_DCHECK(questions_.size() < std::numeric_limits<uint16_t>::max());
+ OSP_DCHECK(questions_.size() < kMaxMessageFieldEntryCount);
max_wire_size_ += question.MaxWireSize();
questions_.emplace_back(std::move(question));
}
void MdnsMessage::AddAnswer(MdnsRecord record) {
- OSP_DCHECK(answers_.size() < std::numeric_limits<uint16_t>::max());
+ OSP_DCHECK(answers_.size() < kMaxMessageFieldEntryCount);
max_wire_size_ += record.MaxWireSize();
answers_.emplace_back(std::move(record));
}
void MdnsMessage::AddAuthorityRecord(MdnsRecord record) {
- OSP_DCHECK(authority_records_.size() < std::numeric_limits<uint16_t>::max());
+ OSP_DCHECK(authority_records_.size() < kMaxMessageFieldEntryCount);
max_wire_size_ += record.MaxWireSize();
authority_records_.emplace_back(std::move(record));
}
void MdnsMessage::AddAdditionalRecord(MdnsRecord record) {
- OSP_DCHECK(additional_records_.size() < std::numeric_limits<uint16_t>::max());
+ OSP_DCHECK(additional_records_.size() < kMaxMessageFieldEntryCount);
max_wire_size_ += record.MaxWireSize();
additional_records_.emplace_back(std::move(record));
}
+bool MdnsMessage::CanAddRecord(const MdnsRecord& record) {
+ return (max_wire_size_ + record.MaxWireSize()) < kMaxMulticastMessageSize;
+}
+
uint16_t CreateMessageId() {
- static std::atomic<uint16_t> id(0);
+ static uint16_t id(0);
return id++;
}
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.h
index d81d073e754..9cacace6cb9 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.h
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.h
@@ -5,15 +5,19 @@
#ifndef DISCOVERY_MDNS_MDNS_RECORDS_H_
#define DISCOVERY_MDNS_MDNS_RECORDS_H_
-#include <chrono>
+#include <algorithm>
+#include <chrono> // NOLINT
#include <functional>
#include <initializer_list>
#include <string>
+#include <utility>
#include <vector>
#include "absl/strings/string_view.h"
#include "absl/types/variant.h"
#include "discovery/mdns/public/mdns_constants.h"
+#include "platform/base/error.h"
+#include "platform/base/interface_info.h"
#include "platform/base/ip_address.h"
#include "util/logging.h"
@@ -29,15 +33,31 @@ class DomainName {
DomainName();
template <typename IteratorType>
- DomainName(IteratorType first, IteratorType last) {
- labels_.reserve(std::distance(first, last));
+ static ErrorOr<DomainName> TryCreate(IteratorType first, IteratorType last) {
+ std::vector<std::string> labels;
+ size_t max_wire_size = 1;
+ labels.reserve(std::distance(first, last));
for (IteratorType entry = first; entry != last; ++entry) {
- OSP_DCHECK(IsValidDomainLabel(*entry));
- labels_.emplace_back(*entry);
+ if (!IsValidDomainLabel(*entry)) {
+ return Error::Code::kParameterInvalid;
+ }
+ labels.emplace_back(*entry);
// Include the length byte in the size calculation.
- max_wire_size_ += entry->size() + 1;
+ max_wire_size += entry->size() + 1;
+ }
+
+ if (max_wire_size > kMaxDomainNameLength) {
+ return Error::Code::kIndexOutOfBounds;
+ } else {
+ return DomainName(std::move(labels), max_wire_size);
}
- OSP_DCHECK(max_wire_size_ <= kMaxDomainNameLength);
+ }
+
+ template <typename IteratorType>
+ DomainName(IteratorType first, IteratorType last) {
+ ErrorOr<DomainName> domain = TryCreate(first, last);
+ OSP_DCHECK(domain.is_value());
+ *this = std::move(domain.value());
}
explicit DomainName(std::vector<std::string> labels);
explicit DomainName(const std::vector<absl::string_view>& labels);
@@ -70,6 +90,8 @@ class DomainName {
}
private:
+ DomainName(std::vector<std::string> labels, size_t max_wire_size);
+
// max_wire_size_ starts at 1 for the terminating character length.
size_t max_wire_size_ = 1;
std::vector<std::string> labels_;
@@ -80,6 +102,8 @@ class DomainName {
// distinguish a raw record type that we do not know the identity of.
class RawRecordRdata {
public:
+ static ErrorOr<RawRecordRdata> TryCreate(std::vector<uint8_t> rdata);
+
RawRecordRdata();
explicit RawRecordRdata(std::vector<uint8_t> rdata);
RawRecordRdata(const uint8_t* begin, size_t size);
@@ -148,7 +172,8 @@ class SrvRecordRdata {
class ARecordRdata {
public:
ARecordRdata();
- explicit ARecordRdata(IPAddress ipv4_address);
+ explicit ARecordRdata(IPAddress ipv4_address,
+ NetworkInterfaceIndex interface_index = 0);
ARecordRdata(const ARecordRdata& other);
ARecordRdata(ARecordRdata&& other);
@@ -159,6 +184,7 @@ class ARecordRdata {
size_t MaxWireSize() const;
const IPAddress& ipv4_address() const { return ipv4_address_; }
+ NetworkInterfaceIndex interface_index() const { return interface_index_; }
template <typename H>
friend H AbslHashValue(H h, const ARecordRdata& rdata) {
@@ -167,6 +193,7 @@ class ARecordRdata {
private:
IPAddress ipv4_address_{0, 0, 0, 0};
+ NetworkInterfaceIndex interface_index_;
};
// AAAA Record format (http://www.ietf.org/rfc/rfc1035.txt):
@@ -174,7 +201,8 @@ class ARecordRdata {
class AAAARecordRdata {
public:
AAAARecordRdata();
- explicit AAAARecordRdata(IPAddress ipv6_address);
+ explicit AAAARecordRdata(IPAddress ipv6_address,
+ NetworkInterfaceIndex interface_index = 0);
AAAARecordRdata(const AAAARecordRdata& other);
AAAARecordRdata(AAAARecordRdata&& other);
@@ -185,6 +213,7 @@ class AAAARecordRdata {
size_t MaxWireSize() const;
const IPAddress& ipv6_address() const { return ipv6_address_; }
+ NetworkInterfaceIndex interface_index() const { return interface_index_; }
template <typename H>
friend H AbslHashValue(H h, const AAAARecordRdata& rdata) {
@@ -192,7 +221,9 @@ class AAAARecordRdata {
}
private:
- IPAddress ipv6_address_{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+ IPAddress ipv6_address_{0x0000, 0x0000, 0x0000, 0x0000,
+ 0x0000, 0x0000, 0x0000, 0x0000};
+ NetworkInterfaceIndex interface_index_;
};
// PTR record format (http://www.ietf.org/rfc/rfc1035.txt):
@@ -230,6 +261,9 @@ class PtrRecordRdata {
class TxtRecordRdata {
public:
using Entry = std::vector<uint8_t>;
+
+ static ErrorOr<TxtRecordRdata> TryCreate(std::vector<Entry> texts);
+
TxtRecordRdata();
explicit TxtRecordRdata(std::vector<Entry> texts);
TxtRecordRdata(const TxtRecordRdata& other);
@@ -250,6 +284,8 @@ class TxtRecordRdata {
}
private:
+ TxtRecordRdata(std::vector<std::string> texts, size_t max_wire_size);
+
// max_wire_size_ is at least 3, uint16_t record length and at the
// minimum a NULL byte character string is present.
size_t max_wire_size_ = 3;
@@ -258,12 +294,71 @@ class TxtRecordRdata {
std::vector<std::string> texts_;
};
+// NSEC record format (https://tools.ietf.org/html/rfc4034#section-4).
+// In mDNS, this record type is used for representing negative responses to
+// queries.
+//
+// next_domain_name: The next domain to process. In mDNS, this value is expected
+// to match the record-level domain name in a negative response.
+//
+// An example of how the |types_| vector is serialized is as follows:
+// When encoding the following DNS types:
+// - A (value 1)
+// - MX (value 15)
+// - RRSIG (value 46)
+// - NSEC (value 47)
+// - TYPE1234 (value 1234)
+// The result would be:
+// 0x00 0x06 0x40 0x01 0x00 0x00 0x00 0x03
+// 0x04 0x1b 0x00 0x00 0x00 0x00 0x00 0x00
+// 0x00 0x00 0x00 0x00 0x00 0x00 0x00 0x00
+// 0x00 0x00 0x00 0x00 0x00 0x00 0x00 0x00
+// 0x00 0x00 0x00 0x00 0x20
+class NsecRecordRdata {
+ public:
+ NsecRecordRdata();
+
+ // Constructor that takes an arbitrary number of DnsType parameters.
+ // NOTE: If `types...` provide a valid set of parameters for an
+ // std::vector<DnsType> ctor call, this will compile. Do not use this ctor
+ // except to provide multiple DnsType parameters.
+ template <typename... Types>
+ NsecRecordRdata(DomainName next_domain_name, Types... types)
+ : NsecRecordRdata(std::move(next_domain_name),
+ std::vector<DnsType>{types...}) {}
+ NsecRecordRdata(DomainName next_domain_name, std::vector<DnsType> types);
+ NsecRecordRdata(const NsecRecordRdata& other);
+ NsecRecordRdata(NsecRecordRdata&& other);
+
+ NsecRecordRdata& operator=(const NsecRecordRdata& rhs);
+ NsecRecordRdata& operator=(NsecRecordRdata&& rhs);
+ bool operator==(const NsecRecordRdata& rhs) const;
+ bool operator!=(const NsecRecordRdata& rhs) const;
+
+ size_t MaxWireSize() const;
+
+ const DomainName& next_domain_name() const { return next_domain_name_; }
+ const std::vector<DnsType>& types() const { return types_; }
+ const std::vector<uint8_t>& encoded_types() const { return encoded_types_; }
+
+ template <typename H>
+ friend H AbslHashValue(H h, const NsecRecordRdata& rdata) {
+ return H::combine(std::move(h), rdata.types_, rdata.next_domain_name_);
+ }
+
+ private:
+ std::vector<uint8_t> encoded_types_;
+ std::vector<DnsType> types_;
+ DomainName next_domain_name_;
+};
+
using Rdata = absl::variant<RawRecordRdata,
SrvRecordRdata,
ARecordRdata,
AAAARecordRdata,
PtrRecordRdata,
- TxtRecordRdata>;
+ TxtRecordRdata,
+ NsecRecordRdata>;
// Resource record top level format (http://www.ietf.org/rfc/rfc1035.txt):
// name: the name of the node to which this resource record pertains.
@@ -276,6 +371,13 @@ class MdnsRecord {
public:
using ConstRef = std::reference_wrapper<const MdnsRecord>;
+ static ErrorOr<MdnsRecord> TryCreate(DomainName name,
+ DnsType dns_type,
+ DnsClass dns_class,
+ RecordType record_type,
+ std::chrono::seconds ttl,
+ Rdata rdata);
+
MdnsRecord();
MdnsRecord(DomainName name,
DnsType dns_type,
@@ -290,6 +392,10 @@ class MdnsRecord {
MdnsRecord& operator=(MdnsRecord&& rhs);
bool operator==(const MdnsRecord& other) const;
bool operator!=(const MdnsRecord& other) const;
+ bool operator<(const MdnsRecord& other) const;
+ bool operator>(const MdnsRecord& other) const;
+ bool operator<=(const MdnsRecord& other) const;
+ bool operator>=(const MdnsRecord& other) const;
size_t MaxWireSize() const;
const DomainName& name() const { return name_; }
@@ -307,6 +413,11 @@ class MdnsRecord {
}
private:
+ static bool IsValidConfig(const DomainName& name,
+ DnsType dns_type,
+ std::chrono::seconds ttl,
+ const Rdata& rdata);
+
DomainName name_;
DnsType dns_type_ = static_cast<DnsType>(0);
DnsClass dns_class_ = static_cast<DnsClass>(0);
@@ -317,12 +428,20 @@ class MdnsRecord {
Rdata rdata_;
};
+// Creates an A or AAAA record as appropriate for the provided parameters.
+MdnsRecord CreateAddressRecord(DomainName name, const IPAddress& address);
+
// Question top level format (http://www.ietf.org/rfc/rfc1035.txt):
// name: a domain name which identifies the target resource set.
// type: 2 bytes network-order RR TYPE code.
// class: 2 bytes network-order RR CLASS code.
class MdnsQuestion {
public:
+ static ErrorOr<MdnsQuestion> TryCreate(DomainName name,
+ DnsType dns_type,
+ DnsClass dns_class,
+ ResponseType response_type);
+
MdnsQuestion() = default;
MdnsQuestion(DomainName name,
DnsType dns_type,
@@ -365,6 +484,14 @@ class MdnsQuestion {
// query
class MdnsMessage {
public:
+ static ErrorOr<MdnsMessage> TryCreate(
+ uint16_t id,
+ MessageType type,
+ std::vector<MdnsQuestion> questions,
+ std::vector<MdnsRecord> answers,
+ std::vector<MdnsRecord> authority_records,
+ std::vector<MdnsRecord> additional_records);
+
MdnsMessage() = default;
// Constructs a message with ID, flags and empty question, answer, authority
// and additional record collections.
@@ -384,9 +511,23 @@ class MdnsMessage {
void AddAuthorityRecord(MdnsRecord record);
void AddAdditionalRecord(MdnsRecord record);
+ // Returns false if adding a new record would push the size of this message
+ // beyond kMaxMulticastMessageSize, and true otherwise.
+ bool CanAddRecord(const MdnsRecord& record);
+
+ // Sets the truncated bit (TC), as specified in RFC 1035 Section 4.1.1.
+ void set_truncated() { is_truncated_ = true; }
+
+ // Returns true if the provided message is an mDNS probe query as described in
+ // RFC 6762 section 8.1. Specifically, it examines whether any question in
+ // the 'questions' section is a query for which answers are present in the
+ // 'authority records' section of the same message.
+ bool IsProbeQuery() const;
+
size_t MaxWireSize() const;
uint16_t id() const { return id_; }
MessageType type() const { return type_; }
+ bool is_truncated() const { return is_truncated_; }
const std::vector<MdnsQuestion>& questions() const { return questions_; }
const std::vector<MdnsRecord>& answers() const { return answers_; }
const std::vector<MdnsRecord>& authority_records() const {
@@ -407,6 +548,7 @@ class MdnsMessage {
// The mDNS header is 12 bytes long
size_t max_wire_size_ = sizeof(Header);
uint16_t id_ = 0;
+ bool is_truncated_ = false;
MessageType type_ = MessageType::Query;
std::vector<MdnsQuestion> questions_;
std::vector<MdnsRecord> answers_;
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records_unittest.cc
index 1c8680165ef..0fab37ded24 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records_unittest.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records_unittest.cc
@@ -215,38 +215,34 @@ TEST(MdnsARecordRdataTest, CopyAndMove) {
}
TEST(MdnsAAAARecordRdataTest, Construct) {
- constexpr uint8_t kIPv6AddressBytes1[] = {
- 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ constexpr uint16_t kIPv6AddressHextets1[] = {
+ 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
};
- constexpr uint8_t kIPv6AddressBytes2[] = {
- 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
- 0x02, 0x02, 0xb3, 0xff, 0xfe, 0x1e, 0x83, 0x29,
+ constexpr uint16_t kIPv6AddressHextets2[] = {
+ 0xfe80, 0x0000, 0x0000, 0x0000, 0x0202, 0xb3ff, 0xfe1e, 0x8329,
};
- IPAddress address1(kIPv6AddressBytes1);
+ IPAddress address1(kIPv6AddressHextets1);
AAAARecordRdata rdata1;
EXPECT_EQ(rdata1.MaxWireSize(), UINT64_C(18));
EXPECT_EQ(rdata1.ipv6_address(), address1);
- IPAddress address2(kIPv6AddressBytes2);
+ IPAddress address2(kIPv6AddressHextets2);
AAAARecordRdata rdata2(address2);
EXPECT_EQ(rdata2.MaxWireSize(), UINT64_C(18));
EXPECT_EQ(rdata2.ipv6_address(), address2);
}
TEST(MdnsAAAARecordRdataTest, Compare) {
- constexpr uint8_t kIPv6AddressBytes1[] = {
- 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
- 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
+ constexpr uint16_t kIPv6AddressHextets1[] = {
+ 0x0001, 0x0203, 0x0405, 0x0607, 0x0809, 0x0A0B, 0x0C0D, 0x0E0F,
};
- constexpr uint8_t kIPv6AddressBytes2[] = {
- 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
- 0x02, 0x02, 0xb3, 0xff, 0xfe, 0x1e, 0x83, 0x29,
+ constexpr uint16_t kIPv6AddressHextets2[] = {
+ 0xfe80, 0x0000, 0x0000, 0x0000, 0x0202, 0xb3ff, 0xfe1e, 0x8329,
};
- IPAddress address1(kIPv6AddressBytes1);
- IPAddress address2(kIPv6AddressBytes2);
+ IPAddress address1(kIPv6AddressHextets1);
+ IPAddress address2(kIPv6AddressHextets2);
AAAARecordRdata rdata1(address1);
AAAARecordRdata rdata2(address1);
AAAARecordRdata rdata3(address2);
@@ -256,11 +252,10 @@ TEST(MdnsAAAARecordRdataTest, Compare) {
}
TEST(MdnsAAAARecordRdataTest, CopyAndMove) {
- constexpr uint8_t kIPv6AddressBytes[] = {
- 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
- 0x02, 0x02, 0xb3, 0xff, 0xfe, 0x1e, 0x83, 0x29,
+ constexpr uint16_t kIPv6AddressHextets[] = {
+ 0xfe80, 0x0000, 0x0000, 0x0000, 0x0202, 0xb3ff, 0xfe1e, 0x8329,
};
- TestCopyAndMove(AAAARecordRdata(IPAddress(kIPv6AddressBytes)));
+ TestCopyAndMove(AAAARecordRdata(IPAddress(kIPv6AddressHextets)));
}
TEST(MdnsPtrRecordRdataTest, Construct) {
@@ -311,6 +306,135 @@ TEST(MdnsTxtRecordRdataTest, CopyAndMove) {
TestCopyAndMove(MakeTxtRecord({"foo=1", "bar=2"}));
}
+TEST(MdnsNsecRecordRdataTest, Construct) {
+ const DomainName domain{"testing", "local"};
+ NsecRecordRdata rdata(domain);
+ EXPECT_EQ(rdata.MaxWireSize(), domain.MaxWireSize());
+ EXPECT_EQ(rdata.next_domain_name(), domain);
+
+ rdata = NsecRecordRdata(domain, DnsType::kA);
+ // It takes 3 bytes to encode the kA and kSRV records because:
+ // - Both record types have value less than 256, so they are both in window
+ // block 1.
+ // - The bitmap length for this block is always a single byte
+ // - DnsType kA = 1 (encoded in byte 1)
+ // So the full encoded version is:
+ // 00000000 00000001 01000000
+ // |window| | size | | 0-7 |
+ // For a total of 3 bytes.
+ EXPECT_EQ(rdata.encoded_types(), (std::vector<uint8_t>{0x00, 0x01, 0x40}));
+ EXPECT_EQ(rdata.MaxWireSize(), domain.MaxWireSize() + 3);
+ EXPECT_EQ(rdata.next_domain_name(), domain);
+
+ // It takes 8 bytes to encode the kA and kSRV records because:
+ // - Both record types have value less than 256, so they are both in window
+ // block 1.
+ // - The bitmap length for this block is always a single byte
+ // - DnsTypes kTXT = 16 (encoded in byte 3)
+ // So the full encoded version is:
+ // 00000000 00000011 00000000 00000000 10000000
+ // |window| | size | | 0-7 | | 8-15 | |16-23 |
+ // For a total of 5 bytes.
+ rdata = NsecRecordRdata(domain, DnsType::kTXT);
+ EXPECT_EQ(rdata.encoded_types(),
+ (std::vector<uint8_t>{0x00, 0x03, 0x00, 0x00, 0x80}));
+ EXPECT_EQ(rdata.MaxWireSize(), domain.MaxWireSize() + 5);
+ EXPECT_EQ(rdata.next_domain_name(), domain);
+
+ // It takes 8 bytes to encode the kA and kSRV records because:
+ // - Both record types have value less than 256, so they are both in window
+ // block 1.
+ // - The bitmap length for this block is always a single byte
+ // - DnsTypes kSRV = 33 (encoded in byte 5)
+ // So the full encoded version is:
+ // 00000000 00000101 00000000 00000000 00000000 00000000 01000000
+ // |window| | size | | 0-7 | | 8-15 | |16-23 | |24-31 | |32-39 |
+ // For a total of 7 bytes.
+ rdata = NsecRecordRdata(domain, DnsType::kSRV);
+ EXPECT_EQ(rdata.encoded_types(),
+ (std::vector<uint8_t>{0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x40}));
+ EXPECT_EQ(rdata.MaxWireSize(), domain.MaxWireSize() + 7);
+ EXPECT_EQ(rdata.next_domain_name(), domain);
+
+ // It takes 8 bytes to encode the kA and kSRV records because:
+ // - Both record types have value less than 256, so they are both in window
+ // block 1.
+ // - The bitmap length for this block is always a single byte
+ // - DnsTypes kNSEC = 47
+ // So the full encoded version is:
+ // 00000000 00000110 00000000 00000000 00000000 00000000 0000000 00000001
+ // |window| | size | | 0-7 | | 8-15 | |16-23 | |24-31 | |32-39 | |40-47 |
+ // For a total of 8 bytes.
+ rdata = NsecRecordRdata(domain, DnsType::kNSEC);
+ EXPECT_EQ(
+ rdata.encoded_types(),
+ (std::vector<uint8_t>{0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}));
+ EXPECT_EQ(rdata.MaxWireSize(), domain.MaxWireSize() + 8);
+ EXPECT_EQ(rdata.next_domain_name(), domain);
+
+ // It takes 8 bytes to encode the kA and kSRV records because:
+ // - Both record types have value less than 256, so they are both in window
+ // block 1.
+ // - The bitmap length for this block is always a single byte
+ // - DnsTypes kNSEC = 255
+ // So 32 bits are required for the bitfield, for a total of 34 bits.
+ rdata = NsecRecordRdata(domain, DnsType::kANY);
+ std::vector<uint8_t> results{0x00, 32};
+ for (int i = 1; i < 32; i++) {
+ results.push_back(0x00);
+ }
+ results.push_back(0x01);
+ EXPECT_EQ(rdata.encoded_types(), results);
+ EXPECT_EQ(rdata.MaxWireSize(), domain.MaxWireSize() + 34);
+ EXPECT_EQ(rdata.next_domain_name(), domain);
+
+ // It takes 8 bytes to encode the kA and kSRV records because:
+ // - Both record types have value less than 256, so they are both in window
+ // block 1.
+ // - The bitmap length for this block is always a single byte
+ // - DnsTypes have the following values:
+ // - kA = 1 (encoded in byte 1)
+ // kTXT = 16 (encoded in byte 3)
+ // - kSRV = 33 (encoded in byte 5)
+ // - kNSEC = 47 (encoded in 6 bytes)
+ // - The largest of these is 47, so 6 bytes are needed to encode this data.
+ // So the full encoded version is:
+ // 00000000 00000110 01000000 00000000 10000000 00000000 0100000 00000001
+ // |window| | size | | 0-7 | | 8-15 | |16-23 | |24-31 | |32-39 | |40-47 |
+ // For a total of 8 bytes.
+ rdata = NsecRecordRdata(domain, DnsType::kA, DnsType::kTXT, DnsType::kSRV,
+ DnsType::kNSEC);
+ EXPECT_EQ(
+ rdata.encoded_types(),
+ (std::vector<uint8_t>{0x00, 0x06, 0x40, 0x00, 0x80, 0x00, 0x40, 0x01}));
+ EXPECT_EQ(rdata.MaxWireSize(), domain.MaxWireSize() + 8);
+ EXPECT_EQ(rdata.next_domain_name(), domain);
+}
+
+TEST(MdnsNsecRecordRdataTest, Compare) {
+ const DomainName domain{"testing", "local"};
+ const NsecRecordRdata rdata1(domain, DnsType::kA, DnsType::kSRV);
+ const NsecRecordRdata rdata2(domain, DnsType::kSRV, DnsType::kA);
+ const NsecRecordRdata rdata3(domain, DnsType::kSRV, DnsType::kA,
+ DnsType::kAAAA);
+ const NsecRecordRdata rdata4(domain, DnsType::kSRV, DnsType::kAAAA);
+
+ // Ensure equal Rdata values are evaluated as equal.
+ EXPECT_EQ(rdata1, rdata1);
+ EXPECT_EQ(rdata1, rdata2);
+ EXPECT_EQ(rdata2, rdata1);
+
+ // Ensure different Rdata values are not.
+ EXPECT_NE(rdata1, rdata3);
+ EXPECT_NE(rdata1, rdata4);
+ EXPECT_NE(rdata3, rdata4);
+}
+
+TEST(MdnsNsecRecordRdataTest, CopyAndMove) {
+ TestCopyAndMove(NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA,
+ DnsType::kSRV));
+}
+
TEST(MdnsRecordTest, Construct) {
MdnsRecord record1;
EXPECT_EQ(record1.MaxWireSize(), UINT64_C(11));
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.cc
index ce3dd4be510..969010aa64c 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.cc
@@ -4,6 +4,9 @@
#include "discovery/mdns/mdns_responder.h"
+#include <utility>
+
+#include "discovery/mdns/mdns_probe_manager.h"
#include "discovery/mdns/mdns_publisher.h"
#include "discovery/mdns/mdns_querier.h"
#include "discovery/mdns/mdns_random.h"
@@ -13,27 +16,280 @@
namespace openscreen {
namespace discovery {
+namespace {
+
+const std::array<std::string, 3> kServiceEnumerationDomainLabels{
+ "_services", "_dns-sd", "_udp"};
+
+enum AddResult { kNonePresent = 0, kAdded, kAlreadyKnown };
+
+std::chrono::seconds GetTtlForRecordType(DnsType type) {
+ switch (type) {
+ case DnsType::kA:
+ return kARecordTtl;
+ case DnsType::kAAAA:
+ return kAAAARecordTtl;
+ case DnsType::kPTR:
+ return kPtrRecordTtl;
+ case DnsType::kSRV:
+ return kSrvRecordTtl;
+ case DnsType::kTXT:
+ return kTXTRecordTtl;
+ case DnsType::kANY:
+ // If no records are present, re-querying should happen at the minimum
+ // of any record that might be retrieved at that time.
+ return kSrvRecordTtl;
+ default:
+ OSP_NOTREACHED();
+ return std::chrono::seconds{0};
+ }
+}
+
+MdnsRecord CreateNsecRecord(DomainName target_name,
+ DnsType target_type,
+ DnsClass target_class) {
+ auto rdata = NsecRecordRdata(target_name, target_type);
+ std::chrono::seconds ttl = GetTtlForRecordType(target_type);
+ return MdnsRecord(std::move(target_name), DnsType::kNSEC, target_class,
+ RecordType::kUnique, ttl, std::move(rdata));
+}
+
+inline bool IsValidAdditionalRecordType(DnsType type) {
+ return type == DnsType::kSRV || type == DnsType::kTXT ||
+ type == DnsType::kA || type == DnsType::kAAAA;
+}
+
+AddResult AddRecords(std::function<void(MdnsRecord record)> add_func,
+ MdnsResponder::RecordHandler* record_handler,
+ const DomainName& domain,
+ const std::vector<MdnsRecord>& known_answers,
+ DnsType type,
+ DnsClass clazz,
+ bool add_negative_on_unknown) {
+ auto records = record_handler->GetRecords(domain, type, clazz);
+ if (records.empty()) {
+ if (add_negative_on_unknown) {
+ // TODO(rwkeane): Aggregate all NSEC records together into a single NSEC
+ // record to reduce traffic.
+ add_func(CreateNsecRecord(domain, type, clazz));
+ }
+ return AddResult::kNonePresent;
+ } else {
+ bool added_any_records = false;
+ for (auto it = records.begin(); it != records.end(); it++) {
+ if (std::find(known_answers.begin(), known_answers.end(), *it) ==
+ known_answers.end()) {
+ added_any_records = true;
+ add_func(std::move(*it));
+ }
+ }
+ return added_any_records ? AddResult::kAdded : AddResult::kAlreadyKnown;
+ }
+}
+
+inline AddResult AddAdditionalRecords(
+ MdnsMessage* message,
+ MdnsResponder::RecordHandler* record_handler,
+ const DomainName& domain,
+ const std::vector<MdnsRecord>& known_answers,
+ DnsType type,
+ DnsClass clazz,
+ bool add_negative_on_unknown) {
+ OSP_DCHECK(IsValidAdditionalRecordType(type));
+
+ auto add_func = [message](MdnsRecord record) {
+ message->AddAdditionalRecord(std::move(record));
+ };
+ return AddRecords(std::move(add_func), record_handler, domain, known_answers,
+ type, clazz, add_negative_on_unknown);
+}
+
+inline AddResult AddResponseRecords(
+ MdnsMessage* message,
+ MdnsResponder::RecordHandler* record_handler,
+ const DomainName& domain,
+ const std::vector<MdnsRecord>& known_answers,
+ DnsType type,
+ DnsClass clazz,
+ bool add_negative_on_unknown) {
+ auto add_func = [message](MdnsRecord record) {
+ message->AddAnswer(std::move(record));
+ };
+ return AddRecords(std::move(add_func), record_handler, domain, known_answers,
+ type, clazz, add_negative_on_unknown);
+}
+
+void ApplyQueryResults(MdnsMessage* message,
+ MdnsResponder::RecordHandler* record_handler,
+ const DomainName& domain,
+ const std::vector<MdnsRecord>& known_answers,
+ DnsType type,
+ DnsClass clazz,
+ bool is_exclusive_owner) {
+ OSP_DCHECK(type != DnsType::kNSEC);
+
+ // All records matching the provided query which have been published by this
+ // host should be added to the response message per RFC 6762 section 6. If
+ // this host is the exclusive owner of the queried domain name, then a
+ // negative response NSEC record should be added in the case where the queried
+ // record does not exist, per RFC 6762 section 6.1.
+ if (AddResponseRecords(message, record_handler, domain, known_answers, type,
+ clazz, is_exclusive_owner) != AddResult::kAdded) {
+ return;
+ }
+
+ // Per RFC 6763 section 12.1, when querying for a PTR record, all SRV records
+ // and TXT records named in the PTR record's rdata should be added to the
+ // messages additional records, as well as the address records of types A and
+ // AAAA associated with the added SRV records. Per RFC 6762 section 6.1,
+ // records with names matching those of reverse address mappings for PTR
+ // records may be added as negative response NSEC records if they do not
+ // exist.
+ if (type == DnsType::kPTR) {
+ // Add all SRV and TXT records to the additional records section.
+ for (const MdnsRecord& record : message->answers()) {
+ OSP_DCHECK(record.dns_type() == DnsType::kPTR);
+
+ const DomainName& target =
+ absl::get<PtrRecordRdata>(record.rdata()).ptr_domain();
+ AddAdditionalRecords(message, record_handler, target, known_answers,
+ DnsType::kSRV, clazz, true);
+ AddAdditionalRecords(message, record_handler, target, known_answers,
+ DnsType::kTXT, clazz, true);
+ }
+
+ // Add A and AAAA records associated with an added SRV record to the
+ // additional records section.
+ const int max = message->additional_records().size();
+ for (int i = 0; i < max; i++) {
+ if (message->additional_records()[i].dns_type() != DnsType::kSRV) {
+ continue;
+ }
+
+ {
+ const MdnsRecord& srv_record = message->additional_records()[i];
+ const DomainName& target =
+ absl::get<SrvRecordRdata>(srv_record.rdata()).target();
+ AddAdditionalRecords(message, record_handler, target, known_answers,
+ DnsType::kA, clazz, target == domain);
+ }
+
+ // Must re-calculate the |srv_record|, |target| refs in case a resize of
+ // the additional_records() vector has invalidated them.
+ {
+ const MdnsRecord& srv_record = message->additional_records()[i];
+ const DomainName& target =
+ absl::get<SrvRecordRdata>(srv_record.rdata()).target();
+ AddAdditionalRecords(message, record_handler, target, known_answers,
+ DnsType::kAAAA, clazz, target == domain);
+ }
+ }
+ }
+
+ // Per RFC 6763 section 12.2, when querying for an SRV record, all address
+ // records of type A and AAAA should be added to the additional records
+ // section. Per RFC 6762 section 6.1, if these records are not present and
+ // their name and class match that which is being queried for, a negative
+ // response NSEC record may be added to show their non-existence.
+ else if (type == DnsType::kSRV) {
+ for (const auto& srv_record : message->answers()) {
+ OSP_DCHECK(srv_record.dns_type() == DnsType::kSRV);
+
+ const DomainName& target =
+ absl::get<SrvRecordRdata>(srv_record.rdata()).target();
+ AddAdditionalRecords(message, record_handler, target, known_answers,
+ DnsType::kA, clazz, target == domain);
+ AddAdditionalRecords(message, record_handler, target, known_answers,
+ DnsType::kAAAA, clazz, target == domain);
+ }
+ }
+
+ // Per RFC 6762 section 6.2, when querying for an address record of type A or
+ // AAAA, the record of the opposite type should be added to the additional
+ // records section if present. Else, a negative response NSEC record should be
+ // added to show its non-existence.
+ else if (type == DnsType::kA) {
+ AddAdditionalRecords(message, record_handler, domain, known_answers,
+ DnsType::kAAAA, clazz, true);
+ } else if (type == DnsType::kAAAA) {
+ AddAdditionalRecords(message, record_handler, domain, known_answers,
+ DnsType::kA, clazz, true);
+ }
+
+ // The remaining supported records types are TXT, NSEC, and ANY. RFCs 6762 and
+ // 6763 do not recommend sending any records in the additional records section
+ // for queries of types TXT or ANY, and NSEC records are not supported for
+ // queries.
+}
+
+// Determines if the provided query is a type enumeration query as described in
+// RFC 6763 section 9.
+bool IsServiceTypeEnumerationQuery(const MdnsQuestion& question) {
+ if (question.dns_type() != DnsType::kPTR) {
+ return false;
+ }
+
+ if (question.name().labels().size() <
+ kServiceEnumerationDomainLabels.size()) {
+ return false;
+ }
+
+ const auto question_it = question.name().labels().begin();
+ return std::equal(question_it,
+ question_it + kServiceEnumerationDomainLabels.size(),
+ kServiceEnumerationDomainLabels.begin(),
+ kServiceEnumerationDomainLabels.end());
+}
+
+// Creates the expected response to a type enumeration query as described in RFC
+// 6763 section 9.
+void ApplyServiceTypeEnumerationResults(
+ MdnsMessage* message,
+ MdnsResponder::RecordHandler* record_handler,
+ const DomainName& name,
+ DnsClass clazz) {
+ if (name.labels().size() < kServiceEnumerationDomainLabels.size()) {
+ return;
+ }
+
+ std::vector<MdnsRecord::ConstRef> records =
+ record_handler->GetPtrRecords(clazz);
+
+ // skip "_services._dns-sd._udp." which was already checked for in above
+ // method and just use the domain.
+ const auto domain_it =
+ name.labels().begin() + kServiceEnumerationDomainLabels.size();
+ for (const MdnsRecord& record : records) {
+ // Skip the 2 label service name in the PTR record's name.
+ const auto record_it = record.name().labels().begin() + 2;
+ if (std::equal(domain_it, name.labels().end(), record_it,
+ record.name().labels().end())) {
+ message->AddAnswer(MdnsRecord(name, DnsType::kPTR, record.dns_class(),
+ RecordType::kShared, record.ttl(),
+ PtrRecordRdata(record.name())));
+ }
+ }
+}
+
+} // namespace
MdnsResponder::MdnsResponder(RecordHandler* record_handler,
+ MdnsProbeManager* ownership_handler,
MdnsSender* sender,
MdnsReceiver* receiver,
- MdnsQuerier* querier,
- platform::TaskRunner* task_runner,
- platform::ClockNowFunctionPtr now_function,
+ TaskRunner* task_runner,
MdnsRandom* random_delay)
: record_handler_(record_handler),
+ ownership_handler_(ownership_handler),
sender_(sender),
receiver_(receiver),
- querier_(querier),
task_runner_(task_runner),
- now_function_(now_function),
random_delay_(random_delay) {
OSP_DCHECK(record_handler_);
+ OSP_DCHECK(ownership_handler_);
OSP_DCHECK(sender_);
OSP_DCHECK(receiver_);
- OSP_DCHECK(querier_);
OSP_DCHECK(task_runner_);
- OSP_DCHECK(now_function_);
OSP_DCHECK(random_delay_);
auto func = [this](const MdnsMessage& message, const IPEndpoint& src) {
@@ -46,12 +302,117 @@ MdnsResponder::~MdnsResponder() {
receiver_->SetQueryCallback(nullptr);
}
+MdnsResponder::RecordHandler::~RecordHandler() = default;
+
void MdnsResponder::OnMessageReceived(const MdnsMessage& message,
const IPEndpoint& src) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DCHECK(message.type() == MessageType::Query);
- // TODO(rwkeane): implement responding to the query
+ if (message.questions().empty()) {
+ // TODO(rwkeane): Support multi-packet known answer suppression.
+ return;
+ }
+
+ if (message.IsProbeQuery()) {
+ ownership_handler_->RespondToProbeQuery(message, src);
+ return;
+ }
+
+ OSP_DVLOG << "Received mDNS Query with " << message.questions().size()
+ << " questions. Processing...";
+
+ const std::vector<MdnsRecord>& known_answers = message.answers();
+
+ for (const auto& question : message.questions()) {
+ OSP_DVLOG << "\tProcessing mDNS Query for domain: '"
+ << question.name().ToString() << "', type: '"
+ << question.dns_type() << "'";
+
+ // NSEC records should not be queried for.
+ if (question.dns_type() == DnsType::kNSEC) {
+ continue;
+ }
+
+ // Only respond to queries for which one of the following is true:
+ // - This host is the sole owner of that domain.
+ // - A record corresponding to this question has been published.
+ // - The query is a service enumeration query.
+ const bool is_service_enumeration = IsServiceTypeEnumerationQuery(question);
+ const bool is_exclusive_owner =
+ ownership_handler_->IsDomainClaimed(question.name());
+ if (!is_service_enumeration && !is_exclusive_owner &&
+ !record_handler_->HasRecords(question.name(), question.dns_type(),
+ question.dns_class())) {
+ OSP_DVLOG << "\tmDNS Query processed and no relevant records found!";
+ continue;
+ } else if (is_service_enumeration) {
+ OSP_DVLOG << "\tmDNS Query is for service type enumeration!";
+ }
+
+ // Relevant records are published, so send them out using the response type
+ // dictated in the question.
+ std::function<void(const MdnsMessage&)> send_response;
+ if (question.response_type() == ResponseType::kMulticast) {
+ send_response = [this](const MdnsMessage& message) {
+ sender_->SendMulticast(message);
+ };
+ } else {
+ OSP_DCHECK(question.response_type() == ResponseType::kUnicast);
+ send_response = [this, src](const MdnsMessage& message) {
+ sender_->SendMessage(message, src);
+ };
+ }
+
+ // If this host is the exclusive owner, respond immediately. Else, there may
+ // be network contention if all hosts respond simultaneously, so delay the
+ // response as dictated by RFC 6762.
+ if (is_exclusive_owner) {
+ SendResponse(question, known_answers, send_response, is_exclusive_owner);
+ } else {
+ const auto delay = random_delay_->GetSharedRecordResponseDelay();
+ std::function<void()> response = [this, question, known_answers,
+ send_response, is_exclusive_owner]() {
+ SendResponse(question, known_answers, send_response,
+ is_exclusive_owner);
+ };
+ task_runner_->PostTaskWithDelay(response, delay);
+ }
+ }
+}
+
+void MdnsResponder::SendResponse(
+ const MdnsQuestion& question,
+ const std::vector<MdnsRecord>& known_answers,
+ std::function<void(const MdnsMessage&)> send_response,
+ bool is_exclusive_owner) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ MdnsMessage message(CreateMessageId(), MessageType::Response);
+
+ if (IsServiceTypeEnumerationQuery(question)) {
+ // This is a special case defined in RFC 6763 section 9, so handle it
+ // separately.
+ ApplyServiceTypeEnumerationResults(&message, record_handler_,
+ question.name(), question.dns_class());
+ } else {
+ // NOTE: The exclusive ownership of this record cannot change before this
+ // method is called. Exclusive ownership cannot be gained for a record which
+ // has previously been published, and if this host is the exclusive owner
+ // then this method will have been called without any delay on the task
+ // runner
+ ApplyQueryResults(&message, record_handler_, question.name(), known_answers,
+ question.dns_type(), question.dns_class(),
+ is_exclusive_owner);
+ }
+
+ // Send the response only if it contains answers to the query.
+ if (!message.answers().empty()) {
+ OSP_DVLOG << "\tmDNS Query processed and response sent!";
+ send_response(message);
+ } else {
+ OSP_DVLOG << "\tmDNS Query processed and no response sent!";
+ }
}
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.h
index cde4518e3dc..8ac7d1f3d4c 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.h
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.h
@@ -12,27 +12,35 @@
#include "platform/base/macros.h"
namespace openscreen {
-struct IPEndpoint;
-namespace platform {
+struct IPEndpoint;
class TaskRunner;
-} // namespace platform
namespace discovery {
class MdnsMessage;
+class MdnsProbeManager;
class MdnsRandom;
class MdnsReceiver;
class MdnsRecordChangedCallback;
class MdnsSender;
class MdnsQuerier;
+// This class is responsible for responding to any incoming mDNS Queries
+// received via the OnMessageReceived() method. When responding, the generated
+// MdnsMessage will contain the requested record(s) in the answers section, or
+// an NSEC record to specify that the requested record was not found in the case
+// of a query with DnsType aside from ANY. In the case where records are found,
+// the additional records field may be populated with additional records, as
+// specified in RFCs 6762 and 6763.
+// TODO(rwkeane): Handle known answers, and waiting when the truncated (TC) bit
+// is set.
class MdnsResponder {
public:
// Class to handle querying for existing records.
class RecordHandler {
- // Returns whether the provided name is exclusively owned by this endpoint.
- virtual bool IsExclusiveOwner(const DomainName& name) = 0;
+ public:
+ virtual ~RecordHandler();
// Returns whether this service has one or more records matching the
// provided name, type, and class.
@@ -45,17 +53,18 @@ class MdnsResponder {
virtual std::vector<MdnsRecord::ConstRef> GetRecords(const DomainName& name,
DnsType type,
DnsClass clazz) = 0;
+
+ // Enumerates all PTR records owned by this service.
+ virtual std::vector<MdnsRecord::ConstRef> GetPtrRecords(DnsClass clazz) = 0;
};
- // |record_handler|, |sender|, |receiver|, |querier|, |task_runner|, and
- // |random_delay| are expected to persist for the duration of this instance's
- // lifetime.
+ // |record_handler|, |sender|, |receiver|, |task_runner|, and |random_delay|
+ // are expected to persist for the duration of this instance's lifetime.
MdnsResponder(RecordHandler* record_handler,
+ MdnsProbeManager* ownership_handler,
MdnsSender* sender,
MdnsReceiver* receiver,
- MdnsQuerier* querier,
- platform::TaskRunner* task_runner,
- platform::ClockNowFunctionPtr now_function,
+ TaskRunner* task_runner,
MdnsRandom* random_delay);
~MdnsResponder();
@@ -64,13 +73,19 @@ class MdnsResponder {
private:
void OnMessageReceived(const MdnsMessage& message, const IPEndpoint& src);
+ void SendResponse(const MdnsQuestion& question,
+ const std::vector<MdnsRecord>& known_answers,
+ std::function<void(const MdnsMessage&)> send_response,
+ bool is_exclusive_owner);
+
RecordHandler* const record_handler_;
+ MdnsProbeManager* const ownership_handler_;
MdnsSender* const sender_;
MdnsReceiver* const receiver_;
- MdnsQuerier* const querier_;
- platform::TaskRunner* const task_runner_;
- const platform::ClockNowFunctionPtr now_function_;
+ TaskRunner* const task_runner_;
MdnsRandom* const random_delay_;
+
+ friend class MdnsResponderTest;
};
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder_unittest.cc
new file mode 100644
index 00000000000..037a71fec94
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder_unittest.cc
@@ -0,0 +1,762 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "discovery/mdns/mdns_responder.h"
+
+#include <utility>
+
+#include "discovery/mdns/mdns_probe_manager.h"
+#include "discovery/mdns/mdns_random.h"
+#include "discovery/mdns/mdns_receiver.h"
+#include "discovery/mdns/mdns_records.h"
+#include "discovery/mdns/mdns_sender.h"
+#include "platform/test/fake_clock.h"
+#include "platform/test/fake_task_runner.h"
+#include "platform/test/fake_udp_socket.h"
+
+namespace openscreen {
+namespace discovery {
+namespace {
+
+constexpr Clock::duration kMaximumSharedRecordResponseDelayMs(120 * 1000);
+
+bool ContainsRecordType(const std::vector<MdnsRecord>& records, DnsType type) {
+ return std::find_if(records.begin(), records.end(),
+ [type](const MdnsRecord& record) {
+ return record.dns_type() == type;
+ }) != records.end();
+}
+
+void CheckSingleNsecRecordType(const MdnsMessage& message, DnsType type) {
+ ASSERT_EQ(message.answers().size(), size_t{1});
+ const MdnsRecord record = message.answers()[0];
+
+ ASSERT_EQ(record.dns_type(), DnsType::kNSEC);
+ const NsecRecordRdata& rdata = absl::get<NsecRecordRdata>(record.rdata());
+
+ ASSERT_EQ(rdata.types().size(), size_t{1});
+ EXPECT_EQ(rdata.types()[0], type);
+}
+
+void CheckPtrDomain(const MdnsRecord& record, const DomainName& domain) {
+ ASSERT_EQ(record.dns_type(), DnsType::kPTR);
+ const PtrRecordRdata& rdata = absl::get<PtrRecordRdata>(record.rdata());
+
+ EXPECT_EQ(rdata.ptr_domain(), domain);
+}
+
+void ExpectContainsNsecRecordType(const std::vector<MdnsRecord>& records,
+ DnsType type) {
+ auto it = std::find_if(
+ records.begin(), records.end(), [type](const MdnsRecord& record) {
+ if (record.dns_type() != DnsType::kNSEC) {
+ return false;
+ }
+
+ const NsecRecordRdata& rdata =
+ absl::get<NsecRecordRdata>(record.rdata());
+ return rdata.types().size() == 1 && rdata.types()[0] == type;
+ });
+ EXPECT_TRUE(it != records.end());
+}
+
+} // namespace
+
+using testing::_;
+using testing::Args;
+using testing::Invoke;
+using testing::Return;
+using testing::StrictMock;
+
+class MockRecordHandler : public MdnsResponder::RecordHandler {
+ public:
+ void AddRecord(MdnsRecord record) { records_.push_back(record); }
+
+ MOCK_METHOD3(HasRecords, bool(const DomainName&, DnsType, DnsClass));
+
+ std::vector<MdnsRecord::ConstRef> GetRecords(const DomainName& name,
+ DnsType type,
+ DnsClass clazz) override {
+ std::vector<MdnsRecord::ConstRef> records;
+ for (const auto& record : records_) {
+ if (type == DnsType::kANY || record.dns_type() == type) {
+ records.push_back(record);
+ }
+ }
+
+ return records;
+ }
+
+ std::vector<MdnsRecord::ConstRef> GetPtrRecords(DnsClass clazz) override {
+ std::vector<MdnsRecord::ConstRef> records;
+ for (const auto& record : records_) {
+ if (record.dns_type() == DnsType::kPTR) {
+ records.push_back(record);
+ }
+ }
+
+ return records;
+ }
+
+ private:
+ std::vector<MdnsRecord> records_;
+};
+
+class MockMdnsSender : public MdnsSender {
+ public:
+ explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {}
+
+ MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message));
+ MOCK_METHOD2(SendMessage,
+ Error(const MdnsMessage& message, const IPEndpoint& endpoint));
+};
+
+class MockProbeManager : public MdnsProbeManager {
+ public:
+ MOCK_CONST_METHOD1(IsDomainClaimed, bool(const DomainName&));
+ MOCK_METHOD2(RespondToProbeQuery,
+ void(const MdnsMessage&, const IPEndpoint&));
+};
+
+class MdnsResponderTest : public testing::Test {
+ public:
+ MdnsResponderTest()
+ : clock_(Clock::now()),
+ task_runner_(&clock_),
+ socket_(&task_runner_),
+ sender_(&socket_),
+ responder_(&record_handler_,
+ &probe_manager_,
+ &sender_,
+ &receiver_,
+ &task_runner_,
+ &random_) {}
+
+ protected:
+ MdnsRecord GetFakePtrRecord(const DomainName& target) {
+ DomainName name(++target.labels().begin(), target.labels().end());
+ PtrRecordRdata rdata(target);
+ return MdnsRecord(std::move(name), DnsType::kPTR, DnsClass::kIN,
+ RecordType::kUnique, std::chrono::seconds(0), rdata);
+ }
+
+ MdnsRecord GetFakeSrvRecord(const DomainName& name) {
+ SrvRecordRdata rdata(0, 0, 80, name);
+ return MdnsRecord(name, DnsType::kSRV, DnsClass::kIN, RecordType::kUnique,
+ std::chrono::seconds(0), rdata);
+ }
+
+ MdnsRecord GetFakeTxtRecord(const DomainName& name) {
+ TxtRecordRdata rdata;
+ return MdnsRecord(name, DnsType::kTXT, DnsClass::kIN, RecordType::kUnique,
+ std::chrono::seconds(0), rdata);
+ }
+
+ MdnsRecord GetFakeARecord(const DomainName& name) {
+ ARecordRdata rdata(IPAddress(192, 168, 0, 0));
+ return MdnsRecord(name, DnsType::kA, DnsClass::kIN, RecordType::kUnique,
+ std::chrono::seconds(0), rdata);
+ }
+
+ MdnsRecord GetFakeAAAARecord(const DomainName& name) {
+ AAAARecordRdata rdata(IPAddress(1, 2, 3, 4, 5, 6, 7, 8));
+ return MdnsRecord(name, DnsType::kAAAA, DnsClass::kIN, RecordType::kUnique,
+ std::chrono::seconds(0), rdata);
+ }
+
+ void OnMessageReceived(const MdnsMessage& message, const IPEndpoint& src) {
+ responder_.OnMessageReceived(message, src);
+ }
+
+ void QueryForRecordTypeWhenNonePresent(DnsType type) {
+ MdnsQuestion question(domain_, type, DnsClass::kANY,
+ ResponseType::kMulticast);
+ MdnsMessage message(0, MessageType::Query);
+ message.AddQuestion(question);
+
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([type](const MdnsMessage& msg) -> Error {
+ CheckSingleNsecRecordType(msg, type);
+ return Error::None();
+ });
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ OnMessageReceived(message, endpoint_);
+ }
+
+ MdnsMessage CreateMulticastMdnsQuery(DnsType type) {
+ MdnsQuestion question(domain_, type, DnsClass::kANY,
+ ResponseType::kMulticast);
+ MdnsMessage message(0, MessageType::Query);
+ message.AddQuestion(std::move(question));
+
+ return message;
+ }
+
+ MdnsMessage CreateTypeEnumerationQuery() {
+ MdnsQuestion question(type_enumeration_domain_, DnsType::kPTR,
+ DnsClass::kANY, ResponseType::kMulticast);
+ MdnsMessage message(0, MessageType::Query);
+ message.AddQuestion(std::move(question));
+
+ return message;
+ }
+
+ FakeClock clock_;
+ FakeTaskRunner task_runner_;
+ FakeUdpSocket socket_;
+ StrictMock<MockMdnsSender> sender_;
+ StrictMock<MockRecordHandler> record_handler_;
+ StrictMock<MockProbeManager> probe_manager_;
+ MdnsReceiver receiver_;
+ MdnsRandom random_;
+ MdnsResponder responder_;
+
+ DomainName domain_{"instance", "_googlecast", "_tcp", "local"};
+ DomainName type_enumeration_domain_{"_services", "_dns-sd", "_udp", "local"};
+ IPEndpoint endpoint_{IPAddress(192, 168, 0, 0), 80};
+};
+
+// Validate that when records may be sent from multiple receivers, the broadcast
+// is delayed and it is not delayed otherwise.
+TEST_F(MdnsResponderTest, OwnedRecordsSentImmediately) {
+ MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
+ OnMessageReceived(message, endpoint_);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ testing::Mock::VerifyAndClearExpectations(&record_handler_);
+ testing::Mock::VerifyAndClearExpectations(&probe_manager_);
+
+ EXPECT_CALL(sender_, SendMulticast(_)).Times(0);
+ clock_.Advance(Clock::duration(kMaximumSharedRecordResponseDelayMs));
+}
+
+TEST_F(MdnsResponderTest, NonOwnedRecordsDelayed) {
+ MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ EXPECT_CALL(sender_, SendMulticast(_)).Times(0);
+ OnMessageReceived(message, endpoint_);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ testing::Mock::VerifyAndClearExpectations(&record_handler_);
+ testing::Mock::VerifyAndClearExpectations(&probe_manager_);
+
+ EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
+ clock_.Advance(Clock::duration(kMaximumSharedRecordResponseDelayMs));
+}
+
+TEST_F(MdnsResponderTest, MultipleQuestionsProcessed) {
+ MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY);
+ MdnsQuestion question2(domain_, DnsType::kANY, DnsClass::kANY,
+ ResponseType::kMulticast);
+ message.AddQuestion(std::move(question2));
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_))
+ .WillOnce(Return(true))
+ .WillOnce(Return(false));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
+ OnMessageReceived(message, endpoint_);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+ testing::Mock::VerifyAndClearExpectations(&record_handler_);
+ testing::Mock::VerifyAndClearExpectations(&probe_manager_);
+
+ EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
+ clock_.Advance(Clock::duration(kMaximumSharedRecordResponseDelayMs));
+}
+
+// Validate that the correct messaging scheme (unicast vs multicast) is used.
+TEST_F(MdnsResponderTest, UnicastMessageSentOverUnicast) {
+ MdnsQuestion question(domain_, DnsType::kANY, DnsClass::kANY,
+ ResponseType::kUnicast);
+ MdnsMessage message(0, MessageType::Query);
+ message.AddQuestion(question);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ EXPECT_CALL(sender_, SendMessage(_, endpoint_)).Times(1);
+ OnMessageReceived(message, endpoint_);
+}
+
+TEST_F(MdnsResponderTest, MulticastMessageSentOverMulticast) {
+ MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
+ OnMessageReceived(message, endpoint_);
+}
+
+// Validate that records are added as expected based on the query type, and that
+// additional records are populated as specified in RFC 6762 and 6763.
+TEST_F(MdnsResponderTest, AnyQueryResultsAllApplied) {
+ MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ record_handler_.AddRecord(GetFakeTxtRecord(domain_));
+ record_handler_.AddRecord(GetFakeARecord(domain_));
+ record_handler_.AddRecord(GetFakeAAAARecord(domain_));
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([](const MdnsMessage& message) -> Error {
+ EXPECT_EQ(message.questions().size(), size_t{0});
+ EXPECT_EQ(message.authority_records().size(), size_t{0});
+ EXPECT_EQ(message.additional_records().size(), size_t{0});
+
+ EXPECT_EQ(message.answers().size(), size_t{4});
+ EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kSRV));
+ EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kTXT));
+ EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kA));
+ EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kAAAA));
+ EXPECT_FALSE(ContainsRecordType(message.answers(), DnsType::kPTR));
+ return Error::None();
+ });
+
+ OnMessageReceived(message, endpoint_);
+}
+
+TEST_F(MdnsResponderTest, PtrQueryResultsApplied) {
+ DomainName ptr_domain{"_googlecast", "_tcp", "local"};
+ MdnsQuestion question(ptr_domain, DnsType::kPTR, DnsClass::kANY,
+ ResponseType::kMulticast);
+ MdnsMessage message(0, MessageType::Query);
+ message.AddQuestion(question);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakePtrRecord(domain_));
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ record_handler_.AddRecord(GetFakeTxtRecord(domain_));
+ record_handler_.AddRecord(GetFakeARecord(domain_));
+ record_handler_.AddRecord(GetFakeAAAARecord(domain_));
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([](const MdnsMessage& message) -> Error {
+ EXPECT_EQ(message.questions().size(), size_t{0});
+ EXPECT_EQ(message.authority_records().size(), size_t{0});
+ EXPECT_EQ(message.additional_records().size(), size_t{4});
+
+ EXPECT_EQ(message.answers().size(), size_t{1});
+ EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kPTR));
+
+ const auto& records = message.additional_records();
+ EXPECT_EQ(records.size(), size_t{4});
+ EXPECT_TRUE(ContainsRecordType(records, DnsType::kSRV));
+ EXPECT_TRUE(ContainsRecordType(records, DnsType::kTXT));
+ EXPECT_TRUE(ContainsRecordType(records, DnsType::kA));
+ EXPECT_TRUE(ContainsRecordType(records, DnsType::kAAAA));
+ EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR));
+
+ return Error::None();
+ });
+
+ OnMessageReceived(message, endpoint_);
+}
+
+TEST_F(MdnsResponderTest, SrvQueryResultsApplied) {
+ MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kSRV);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakePtrRecord(domain_));
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ record_handler_.AddRecord(GetFakeTxtRecord(domain_));
+ record_handler_.AddRecord(GetFakeARecord(domain_));
+ record_handler_.AddRecord(GetFakeAAAARecord(domain_));
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([](const MdnsMessage& message) -> Error {
+ EXPECT_EQ(message.questions().size(), size_t{0});
+ EXPECT_EQ(message.authority_records().size(), size_t{0});
+ EXPECT_EQ(message.additional_records().size(), size_t{2});
+
+ EXPECT_EQ(message.answers().size(), size_t{1});
+ EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kSRV));
+
+ const auto& records = message.additional_records();
+ EXPECT_EQ(records.size(), size_t{2});
+ EXPECT_TRUE(ContainsRecordType(records, DnsType::kA));
+ EXPECT_TRUE(ContainsRecordType(records, DnsType::kAAAA));
+ EXPECT_FALSE(ContainsRecordType(records, DnsType::kSRV));
+ EXPECT_FALSE(ContainsRecordType(records, DnsType::kTXT));
+ EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR));
+
+ return Error::None();
+ });
+
+ OnMessageReceived(message, endpoint_);
+}
+
+TEST_F(MdnsResponderTest, AQueryResultsApplied) {
+ MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kA);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakePtrRecord(domain_));
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ record_handler_.AddRecord(GetFakeTxtRecord(domain_));
+ record_handler_.AddRecord(GetFakeARecord(domain_));
+ record_handler_.AddRecord(GetFakeAAAARecord(domain_));
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([](const MdnsMessage& message) -> Error {
+ EXPECT_EQ(message.questions().size(), size_t{0});
+ EXPECT_EQ(message.authority_records().size(), size_t{0});
+ EXPECT_EQ(message.additional_records().size(), size_t{1});
+
+ EXPECT_EQ(message.answers().size(), size_t{1});
+ EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kA));
+
+ const auto& records = message.additional_records();
+ EXPECT_EQ(records.size(), size_t{1});
+ EXPECT_TRUE(ContainsRecordType(records, DnsType::kAAAA));
+ EXPECT_FALSE(ContainsRecordType(records, DnsType::kA));
+ EXPECT_FALSE(ContainsRecordType(records, DnsType::kSRV));
+ EXPECT_FALSE(ContainsRecordType(records, DnsType::kTXT));
+ EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR));
+
+ return Error::None();
+ });
+
+ OnMessageReceived(message, endpoint_);
+}
+
+TEST_F(MdnsResponderTest, AAAAQueryResultsApplied) {
+ MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kAAAA);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakePtrRecord(domain_));
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ record_handler_.AddRecord(GetFakeTxtRecord(domain_));
+ record_handler_.AddRecord(GetFakeARecord(domain_));
+ record_handler_.AddRecord(GetFakeAAAARecord(domain_));
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([](const MdnsMessage& message) -> Error {
+ EXPECT_EQ(message.questions().size(), size_t{0});
+ EXPECT_EQ(message.authority_records().size(), size_t{0});
+ EXPECT_EQ(message.additional_records().size(), size_t{1});
+
+ EXPECT_EQ(message.answers().size(), size_t{1});
+ EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kAAAA));
+
+ const auto& records = message.additional_records();
+ EXPECT_EQ(records.size(), size_t{1});
+ EXPECT_TRUE(ContainsRecordType(records, DnsType::kA));
+ EXPECT_FALSE(ContainsRecordType(records, DnsType::kAAAA));
+ EXPECT_FALSE(ContainsRecordType(records, DnsType::kSRV));
+ EXPECT_FALSE(ContainsRecordType(records, DnsType::kTXT));
+ EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR));
+
+ return Error::None();
+ });
+
+ OnMessageReceived(message, endpoint_);
+}
+
+TEST_F(MdnsResponderTest, MessageOnlySentIfAnswerNotKnown) {
+ MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kAAAA);
+ MdnsRecord aaaa_record = GetFakeAAAARecord(domain_);
+ message.AddAnswer(aaaa_record);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakePtrRecord(domain_));
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ record_handler_.AddRecord(GetFakeTxtRecord(domain_));
+ record_handler_.AddRecord(GetFakeARecord(domain_));
+ record_handler_.AddRecord(aaaa_record);
+
+ OnMessageReceived(message, endpoint_);
+}
+
+TEST_F(MdnsResponderTest, RecordOnlySentIfNotKnown) {
+ MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY);
+ MdnsRecord aaaa_record = GetFakeAAAARecord(domain_);
+ message.AddAnswer(aaaa_record);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakeARecord(domain_));
+ record_handler_.AddRecord(aaaa_record);
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([](const MdnsMessage& message) -> Error {
+ EXPECT_EQ(message.questions().size(), size_t{0});
+ EXPECT_EQ(message.authority_records().size(), size_t{0});
+ EXPECT_EQ(message.additional_records().size(), size_t{0});
+
+ EXPECT_EQ(message.answers().size(), size_t{1});
+ EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kA));
+ return Error::None();
+ });
+
+ OnMessageReceived(message, endpoint_);
+}
+
+// Validate NSEC records are used correctly.
+TEST_F(MdnsResponderTest, QueryForRecordTypesWhenNonePresent) {
+ QueryForRecordTypeWhenNonePresent(DnsType::kANY);
+ QueryForRecordTypeWhenNonePresent(DnsType::kSRV);
+ QueryForRecordTypeWhenNonePresent(DnsType::kTXT);
+ QueryForRecordTypeWhenNonePresent(DnsType::kA);
+ QueryForRecordTypeWhenNonePresent(DnsType::kAAAA);
+}
+
+TEST_F(MdnsResponderTest, AAAAQueryGiveANsec) {
+ MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kAAAA);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakePtrRecord(domain_));
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ record_handler_.AddRecord(GetFakeTxtRecord(domain_));
+ record_handler_.AddRecord(GetFakeAAAARecord(domain_));
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([](const MdnsMessage& message) -> Error {
+ EXPECT_EQ(message.questions().size(), size_t{0});
+ EXPECT_EQ(message.authority_records().size(), size_t{0});
+
+ EXPECT_EQ(message.answers().size(), size_t{1});
+ EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kAAAA));
+
+ EXPECT_EQ(message.additional_records().size(), size_t{1});
+ ExpectContainsNsecRecordType(message.additional_records(), DnsType::kA);
+
+ return Error::None();
+ });
+
+ OnMessageReceived(message, endpoint_);
+}
+
+TEST_F(MdnsResponderTest, AQueryGiveAAAANsec) {
+ MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kA);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakePtrRecord(domain_));
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ record_handler_.AddRecord(GetFakeTxtRecord(domain_));
+ record_handler_.AddRecord(GetFakeARecord(domain_));
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([](const MdnsMessage& message) -> Error {
+ EXPECT_EQ(message.questions().size(), size_t{0});
+ EXPECT_EQ(message.authority_records().size(), size_t{0});
+
+ EXPECT_EQ(message.answers().size(), size_t{1});
+ EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kA));
+
+ EXPECT_EQ(message.additional_records().size(), size_t{1});
+ ExpectContainsNsecRecordType(message.additional_records(),
+ DnsType::kAAAA);
+
+ return Error::None();
+ });
+
+ OnMessageReceived(message, endpoint_);
+}
+
+TEST_F(MdnsResponderTest, SrvQueryGiveCorrectNsecForNoAOrAAAA) {
+ MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kSRV);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakePtrRecord(domain_));
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ record_handler_.AddRecord(GetFakeTxtRecord(domain_));
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([](const MdnsMessage& message) -> Error {
+ EXPECT_EQ(message.questions().size(), size_t{0});
+ EXPECT_EQ(message.authority_records().size(), size_t{0});
+
+ EXPECT_EQ(message.answers().size(), size_t{1});
+ EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kSRV));
+
+ EXPECT_EQ(message.additional_records().size(), size_t{2});
+ ExpectContainsNsecRecordType(message.additional_records(), DnsType::kA);
+ ExpectContainsNsecRecordType(message.additional_records(),
+ DnsType::kAAAA);
+
+ return Error::None();
+ });
+ OnMessageReceived(message, endpoint_);
+}
+
+TEST_F(MdnsResponderTest, SrvQueryGiveCorrectNsec) {
+ MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kSRV);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakePtrRecord(domain_));
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ record_handler_.AddRecord(GetFakeTxtRecord(domain_));
+ record_handler_.AddRecord(GetFakeARecord(domain_));
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([](const MdnsMessage& message) -> Error {
+ EXPECT_EQ(message.questions().size(), size_t{0});
+ EXPECT_EQ(message.authority_records().size(), size_t{0});
+
+ EXPECT_EQ(message.answers().size(), size_t{1});
+ EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kSRV));
+
+ EXPECT_EQ(message.additional_records().size(), size_t{2});
+ EXPECT_TRUE(
+ ContainsRecordType(message.additional_records(), DnsType::kA));
+ ExpectContainsNsecRecordType(message.additional_records(),
+ DnsType::kAAAA);
+
+ return Error::None();
+ });
+ OnMessageReceived(message, endpoint_);
+}
+
+TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForNoPtrOrSrv) {
+ MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kPTR);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakePtrRecord(domain_));
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([](const MdnsMessage& message) -> Error {
+ EXPECT_EQ(message.questions().size(), size_t{0});
+ EXPECT_EQ(message.authority_records().size(), size_t{0});
+
+ EXPECT_EQ(message.answers().size(), size_t{1});
+ EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kPTR));
+
+ EXPECT_EQ(message.additional_records().size(), size_t{2});
+ ExpectContainsNsecRecordType(message.additional_records(),
+ DnsType::kTXT);
+ ExpectContainsNsecRecordType(message.additional_records(),
+ DnsType::kSRV);
+
+ return Error::None();
+ });
+ OnMessageReceived(message, endpoint_);
+}
+
+TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForOnlyPtr) {
+ MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kPTR);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakePtrRecord(domain_));
+ record_handler_.AddRecord(GetFakeTxtRecord(domain_));
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([](const MdnsMessage& message) -> Error {
+ EXPECT_EQ(message.questions().size(), size_t{0});
+ EXPECT_EQ(message.authority_records().size(), size_t{0});
+
+ EXPECT_EQ(message.answers().size(), size_t{1});
+ EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kPTR));
+
+ EXPECT_EQ(message.additional_records().size(), size_t{2});
+ EXPECT_TRUE(
+ ContainsRecordType(message.additional_records(), DnsType::kTXT));
+ ExpectContainsNsecRecordType(message.additional_records(),
+ DnsType::kSRV);
+
+ return Error::None();
+ });
+ OnMessageReceived(message, endpoint_);
+}
+
+TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForOnlySrv) {
+ MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kPTR);
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ record_handler_.AddRecord(GetFakePtrRecord(domain_));
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([](const MdnsMessage& message) -> Error {
+ EXPECT_EQ(message.questions().size(), size_t{0});
+ EXPECT_EQ(message.authority_records().size(), size_t{0});
+
+ EXPECT_EQ(message.answers().size(), size_t{1});
+ EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kPTR));
+
+ EXPECT_EQ(message.additional_records().size(), size_t{4});
+ EXPECT_TRUE(
+ ContainsRecordType(message.additional_records(), DnsType::kSRV));
+ ExpectContainsNsecRecordType(message.additional_records(),
+ DnsType::kTXT);
+ ExpectContainsNsecRecordType(message.additional_records(), DnsType::kA);
+ ExpectContainsNsecRecordType(message.additional_records(),
+ DnsType::kAAAA);
+
+ return Error::None();
+ });
+ OnMessageReceived(message, endpoint_);
+}
+
+TEST_F(MdnsResponderTest, EnumerateAllQuery) {
+ MdnsMessage message = CreateTypeEnumerationQuery();
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ const auto ptr = GetFakePtrRecord(domain_);
+ record_handler_.AddRecord(ptr);
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ record_handler_.AddRecord(GetFakeTxtRecord(domain_));
+ record_handler_.AddRecord(GetFakeARecord(domain_));
+ OnMessageReceived(message, endpoint_);
+
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce([this, &ptr](const MdnsMessage& message) -> Error {
+ EXPECT_EQ(message.questions().size(), size_t{0});
+ EXPECT_EQ(message.authority_records().size(), size_t{0});
+
+ EXPECT_EQ(message.answers().size(), size_t{1});
+ EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kPTR));
+ EXPECT_EQ(message.answers()[0].name(), type_enumeration_domain_);
+ CheckPtrDomain(message.answers()[0], ptr.name());
+ return Error::None();
+ });
+ clock_.Advance(Clock::duration(kMaximumSharedRecordResponseDelayMs));
+}
+
+TEST_F(MdnsResponderTest, EnumerateAllQueryNoResults) {
+ MdnsMessage message = CreateTypeEnumerationQuery();
+
+ EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false));
+ EXPECT_CALL(record_handler_, HasRecords(_, _, _))
+ .WillRepeatedly(Return(true));
+ const auto ptr = GetFakePtrRecord(domain_);
+ record_handler_.AddRecord(GetFakeSrvRecord(domain_));
+ record_handler_.AddRecord(GetFakeTxtRecord(domain_));
+ record_handler_.AddRecord(GetFakeARecord(domain_));
+ OnMessageReceived(message, endpoint_);
+ clock_.Advance(Clock::duration(kMaximumSharedRecordResponseDelayMs));
+}
+
+} // namespace discovery
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.cc
index ac6305feba4..607c4227a37 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.cc
@@ -4,39 +4,28 @@
#include "discovery/mdns/mdns_sender.h"
+#include <iostream>
+
#include "discovery/mdns/mdns_writer.h"
+#include "platform/api/udp_socket.h"
namespace openscreen {
namespace discovery {
-namespace {
-
-const IPEndpoint& GetIPv6MdnsMulticastEndpoint() {
- static IPEndpoint endpoint{.address = IPAddress(kDefaultMulticastGroupIPv6),
- .port = kDefaultMulticastPort};
- return endpoint;
-}
-
-const IPEndpoint& GetIPv4MdnsMulticastEndpoint() {
- static IPEndpoint endpoint{.address = IPAddress(kDefaultMulticastGroupIPv4),
- .port = kDefaultMulticastPort};
- return endpoint;
-}
-
-} // namespace
-
MdnsSender::MdnsSender(UdpSocket* socket) : socket_(socket) {
OSP_DCHECK(socket_ != nullptr);
}
+MdnsSender::~MdnsSender() = default;
+
Error MdnsSender::SendMulticast(const MdnsMessage& message) {
const IPEndpoint& endpoint = socket_->IsIPv6()
- ? GetIPv6MdnsMulticastEndpoint()
- : GetIPv4MdnsMulticastEndpoint();
- return SendUnicast(message, endpoint);
+ ? kDefaultMulticastGroupIPv6Endpoint
+ : kDefaultMulticastGroupIPv4Endpoint;
+ return SendMessage(message, endpoint);
}
-Error MdnsSender::SendUnicast(const MdnsMessage& message,
+Error MdnsSender::SendMessage(const MdnsMessage& message,
const IPEndpoint& endpoint) {
// Always try to write the message into the buffer even if MaxWireSize is
// greater than maximum message size. Domain name compression might reduce the
@@ -52,5 +41,9 @@ Error MdnsSender::SendUnicast(const MdnsMessage& message,
return Error::Code::kNone;
}
+void MdnsSender::OnSendError(UdpSocket* socket, Error error) {
+ OSP_LOG_ERROR << "Error sending packet";
+}
+
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.h
index b51eb12228e..97356ced617 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.h
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender.h
@@ -16,19 +16,21 @@ class MdnsMessage;
class MdnsSender {
public:
- using UdpSocket = openscreen::platform::UdpSocket;
-
// MdnsSender does not own |socket| and expects that its lifetime exceeds the
// lifetime of MdnsSender.
explicit MdnsSender(UdpSocket* socket);
MdnsSender(const MdnsSender& other) = delete;
MdnsSender(MdnsSender&& other) noexcept = delete;
+ virtual ~MdnsSender();
+
MdnsSender& operator=(const MdnsSender& other) = delete;
MdnsSender& operator=(MdnsSender&& other) noexcept = delete;
- ~MdnsSender() = default;
- Error SendMulticast(const MdnsMessage& message);
- Error SendUnicast(const MdnsMessage& message, const IPEndpoint& endpoint);
+ virtual Error SendMulticast(const MdnsMessage& message);
+ virtual Error SendMessage(const MdnsMessage& message,
+ const IPEndpoint& endpoint);
+
+ void OnSendError(UdpSocket* socket, Error error);
private:
UdpSocket* const socket_;
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender_unittest.cc
index da67ef02bea..b125b46fd5a 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender_unittest.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_sender_unittest.cc
@@ -4,32 +4,31 @@
#include "discovery/mdns/mdns_sender.h"
+#include <memory>
+#include <vector>
+
#include "discovery/mdns/mdns_records.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "platform/test/fake_udp_socket.h"
+#include "platform/test/mock_udp_socket.h"
namespace openscreen {
namespace discovery {
-using openscreen::platform::FakeUdpSocket;
using testing::_;
using testing::Args;
using testing::Return;
+using testing::StrictMock;
+using testing::WithArgs;
namespace {
-MATCHER_P(
- VoidPointerMatchesBytes,
- expected_data,
- "Matches data at the pointer against the provided C-style byte array.") {
- const uint8_t* actual_data = static_cast<const uint8_t*>(arg);
+ACTION_P(VoidPointerMatchesBytes, expected_data) {
+ const uint8_t* actual_data = static_cast<const uint8_t*>(arg0);
for (size_t i = 0; i < expected_data.size(); ++i) {
- if (actual_data[i] != expected_data[i]) {
- return false;
- }
+ EXPECT_EQ(actual_data[i], expected_data[i]);
}
- return true;
}
} // namespace
@@ -103,54 +102,39 @@ class MdnsSenderTest : public testing::Test {
IPEndpoint ipv6_multicast_endpoint_;
};
-TEST_F(MdnsSenderTest, SendMulticastIPv4) {
- std::unique_ptr<FakeUdpSocket> socket_info =
- FakeUdpSocket::CreateDefault(IPAddress::Version::kV4);
- MdnsSender sender(socket_info.get());
- socket_info->EnqueueSendResult(Error::Code::kNone);
- EXPECT_CALL(*socket_info->client_mock(), OnSendError(_, _)).Times(0);
- EXPECT_EQ(sender.SendMulticast(query_message_), Error::Code::kNone);
- EXPECT_EQ(socket_info->send_queue_size(), size_t{0});
-}
-
-TEST_F(MdnsSenderTest, SendMulticastIPv6) {
- std::unique_ptr<FakeUdpSocket> socket_info =
- FakeUdpSocket::CreateDefault(IPAddress::Version::kV6);
- MdnsSender sender(socket_info.get());
- socket_info->EnqueueSendResult(Error::Code::kNone);
- EXPECT_CALL(*socket_info->client_mock(), OnSendError(_, _)).Times(0);
+TEST_F(MdnsSenderTest, SendMulticast) {
+ StrictMock<MockUdpSocket> socket;
+ EXPECT_CALL(socket, IsIPv4()).WillRepeatedly(Return(true));
+ EXPECT_CALL(socket, IsIPv6()).WillRepeatedly(Return(true));
+ MdnsSender sender(&socket);
+ EXPECT_CALL(socket, SendMessage(_, kQueryBytes.size(), _))
+ .WillOnce(WithArgs<0>(VoidPointerMatchesBytes(kQueryBytes)));
EXPECT_EQ(sender.SendMulticast(query_message_), Error::Code::kNone);
- EXPECT_EQ(socket_info->send_queue_size(), size_t{0});
}
TEST_F(MdnsSenderTest, SendUnicastIPv4) {
IPEndpoint endpoint{.address = IPAddress{192, 168, 1, 1}, .port = 31337};
- std::unique_ptr<FakeUdpSocket> socket_info =
- FakeUdpSocket::CreateDefault(IPAddress::Version::kV4);
- MdnsSender sender(socket_info.get());
- socket_info->EnqueueSendResult(Error::Code::kNone);
- EXPECT_CALL(*socket_info->client_mock(), OnSendError(_, _)).Times(0);
- EXPECT_EQ(sender.SendUnicast(response_message_, endpoint),
+ StrictMock<MockUdpSocket> socket;
+ MdnsSender sender(&socket);
+ EXPECT_CALL(socket, SendMessage(_, kResponseBytes.size(), _))
+ .WillOnce(WithArgs<0>(VoidPointerMatchesBytes(kResponseBytes)));
+ EXPECT_EQ(sender.SendMessage(response_message_, endpoint),
Error::Code::kNone);
- EXPECT_EQ(socket_info->send_queue_size(), size_t{0});
}
TEST_F(MdnsSenderTest, SendUnicastIPv6) {
- constexpr uint8_t kIPv6AddressBytes[] = {
- 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
- 0x02, 0x02, 0xb3, 0xff, 0xfe, 0x1e, 0x83, 0x29,
+ constexpr uint16_t kIPv6AddressHextets[] = {
+ 0xfe80, 0x0000, 0x0000, 0x0000, 0x0202, 0xb3ff, 0xfe1e, 0x8329,
};
- IPEndpoint endpoint{.address = IPAddress(kIPv6AddressBytes), .port = 31337};
-
- std::unique_ptr<FakeUdpSocket> socket_info =
- FakeUdpSocket::CreateDefault(IPAddress::Version::kV6);
- MdnsSender sender(socket_info.get());
- socket_info->EnqueueSendResult(Error::Code::kNone);
- EXPECT_CALL(*socket_info->client_mock(), OnSendError(_, _)).Times(0);
- EXPECT_EQ(sender.SendUnicast(response_message_, endpoint),
+ IPEndpoint endpoint{.address = IPAddress(kIPv6AddressHextets), .port = 31337};
+
+ StrictMock<MockUdpSocket> socket;
+ MdnsSender sender(&socket);
+ EXPECT_CALL(socket, SendMessage(_, kResponseBytes.size(), _))
+ .WillOnce(WithArgs<0>(VoidPointerMatchesBytes(kResponseBytes)));
+ EXPECT_EQ(sender.SendMessage(response_message_, endpoint),
Error::Code::kNone);
- EXPECT_EQ(socket_info->send_queue_size(), size_t{0});
}
TEST_F(MdnsSenderTest, MessageTooBig) {
@@ -159,25 +143,24 @@ TEST_F(MdnsSenderTest, MessageTooBig) {
big_message_.AddQuestion(a_question_);
big_message_.AddAnswer(a_record_);
}
- std::unique_ptr<FakeUdpSocket> socket_info =
- FakeUdpSocket::CreateDefault(IPAddress::Version::kV4);
- MdnsSender sender(socket_info.get());
- socket_info->EnqueueSendResult(Error::Code::kNone);
- EXPECT_CALL(*socket_info->client_mock(), OnSendError(_, _)).Times(0);
+
+ StrictMock<MockUdpSocket> socket;
+ EXPECT_CALL(socket, IsIPv4()).WillRepeatedly(Return(true));
+ EXPECT_CALL(socket, IsIPv6()).WillRepeatedly(Return(true));
+ MdnsSender sender(&socket);
EXPECT_EQ(sender.SendMulticast(big_message_),
Error::Code::kInsufficientBuffer);
- EXPECT_EQ(socket_info->send_queue_size(), size_t{1});
}
TEST_F(MdnsSenderTest, ReturnsErrorOnSocketFailure) {
- std::unique_ptr<FakeUdpSocket> socket_info =
- FakeUdpSocket::CreateDefault(IPAddress::Version::kV4);
- MdnsSender sender(socket_info.get());
+ FakeUdpSocket::MockClient socket_client;
+ FakeUdpSocket socket(nullptr, &socket_client);
+ MdnsSender sender(&socket);
Error error = Error(Error::Code::kConnectionFailed, "error message");
- socket_info->EnqueueSendResult(error);
- EXPECT_CALL(*socket_info->client_mock(), OnSendError(_, error)).Times(1);
+ socket.EnqueueSendResult(error);
+ EXPECT_CALL(socket_client, OnSendError(_, error)).Times(1);
EXPECT_EQ(sender.SendMulticast(query_message_), Error::Code::kNone);
- EXPECT_EQ(socket_info->send_queue_size(), size_t{0});
+ EXPECT_EQ(socket.send_queue_size(), size_t{0});
}
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.cc
index df0fdb9f977..61762510dbb 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.cc
@@ -4,34 +4,160 @@
#include "discovery/mdns/mdns_service_impl.h"
+#include <memory>
+
+#include "discovery/common/config.h"
+#include "discovery/common/reporting_client.h"
+#include "discovery/mdns/mdns_records.h"
+#include "discovery/mdns/public/mdns_constants.h"
+
namespace openscreen {
namespace discovery {
// static
-std::unique_ptr<MdnsService> MdnsService::Create(TaskRunner* task_runner) {
- return std::make_unique<MdnsServiceImpl>();
+std::unique_ptr<MdnsService> MdnsService::Create(
+ TaskRunner* task_runner,
+ ReportingClient* reporting_client,
+ const Config& config,
+ NetworkInterfaceIndex network_interface,
+ SupportedNetworkAddressFamily supported_address_types) {
+ return std::make_unique<MdnsServiceImpl>(
+ task_runner, Clock::now, reporting_client, config, network_interface,
+ supported_address_types);
+}
+
+MdnsServiceImpl::MdnsServiceImpl(
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ ReportingClient* reporting_client,
+ const Config& config,
+ NetworkInterfaceIndex network_interface,
+ SupportedNetworkAddressFamily supported_address_types)
+ : task_runner_(task_runner),
+ now_function_(now_function),
+ reporting_client_(reporting_client) {
+ OSP_DCHECK(task_runner_);
+ OSP_DCHECK(reporting_client_);
+ OSP_DCHECK(supported_address_types);
+
+ // Create all UDP sockets needed for this object. They should not yet be bound
+ // so that they do not send or receive data until the objects on which their
+ // callback depends is initialized.
+ if (supported_address_types & kUseIpV4Multicast) {
+ ErrorOr<std::unique_ptr<UdpSocket>> socket = UdpSocket::Create(
+ task_runner, this, kDefaultMulticastGroupIPv4Endpoint);
+ OSP_DCHECK(!socket.is_error());
+ OSP_DCHECK(socket.value().get());
+ OSP_DCHECK(socket.value()->IsIPv4());
+
+ socket_v4_ = std::move(socket.value());
+ socket_v4_->SetMulticastOutboundInterface(network_interface);
+ socket_v4_->JoinMulticastGroup(kDefaultMulticastGroupIPv4,
+ network_interface);
+ socket_v4_->JoinMulticastGroup(kDefaultSiteLocalGroupIPv4,
+ network_interface);
+ }
+
+ if (supported_address_types & kUseIpV6Multicast) {
+ ErrorOr<std::unique_ptr<UdpSocket>> socket = UdpSocket::Create(
+ task_runner, this, kDefaultMulticastGroupIPv6Endpoint);
+ OSP_DCHECK(!socket.is_error());
+ OSP_DCHECK(socket.value().get());
+ OSP_DCHECK(socket.value()->IsIPv6());
+
+ socket_v6_ = std::move(socket.value());
+ socket_v6_->SetMulticastOutboundInterface(network_interface);
+ socket_v6_->JoinMulticastGroup(kDefaultMulticastGroupIPv6,
+ network_interface);
+ socket_v6_->JoinMulticastGroup(kDefaultSiteLocalGroupIPv6,
+ network_interface);
+ }
+
+ // Initialize objects which depend on the above sockets.
+ UdpSocket* socket_ptr =
+ socket_v4_.get() ? socket_v4_.get() : socket_v6_.get();
+ OSP_DCHECK(socket_ptr);
+ sender_ = std::make_unique<MdnsSender>(socket_ptr);
+ if (config.enable_querying) {
+ querier_ = std::make_unique<MdnsQuerier>(
+ sender_.get(), &receiver_, task_runner_, now_function_, &random_delay_,
+ reporting_client_, config);
+ }
+ if (config.enable_publication) {
+ probe_manager_ = std::make_unique<MdnsProbeManagerImpl>(
+ sender_.get(), &receiver_, &random_delay_, task_runner_, now_function_);
+ publisher_ =
+ std::make_unique<MdnsPublisher>(sender_.get(), probe_manager_.get(),
+ task_runner_, now_function_, config);
+ responder_ = std::make_unique<MdnsResponder>(
+ publisher_.get(), probe_manager_.get(), sender_.get(), &receiver_,
+ task_runner_, &random_delay_);
+ }
+
+ receiver_.Start();
+
+ // Initialize all sockets to start sending/receiving data. Now that the above
+ // objects have all been created, it they should be able to safely do so.
+ // NOTE: Although only one of these sockets is used for sending, both will be
+ // used for reading on the mDNS v4 and v6 addresses and ports.
+ if (socket_v4_.get()) {
+ socket_v4_->Bind();
+ }
+ if (socket_v6_.get()) {
+ socket_v6_->Bind();
+ }
}
+MdnsServiceImpl::~MdnsServiceImpl() = default;
+
void MdnsServiceImpl::StartQuery(const DomainName& name,
DnsType dns_type,
DnsClass dns_class,
MdnsRecordChangedCallback* callback) {
- // TODO(yakimakha): Implement this method
+ return querier_->StartQuery(name, dns_type, dns_class, callback);
}
void MdnsServiceImpl::StopQuery(const DomainName& name,
DnsType dns_type,
DnsClass dns_class,
MdnsRecordChangedCallback* callback) {
- // TODO(yakimakha): Implement this method
+ return querier_->StopQuery(name, dns_type, dns_class, callback);
+}
+
+void MdnsServiceImpl::ReinitializeQueries(const DomainName& name) {
+ querier_->ReinitializeQueries(name);
+}
+
+Error MdnsServiceImpl::StartProbe(MdnsDomainConfirmedProvider* callback,
+ DomainName requested_name,
+ IPAddress address) {
+ return probe_manager_->StartProbe(callback, std::move(requested_name),
+ std::move(address));
+}
+
+Error MdnsServiceImpl::RegisterRecord(const MdnsRecord& record) {
+ return publisher_->RegisterRecord(record);
+}
+
+Error MdnsServiceImpl::UpdateRegisteredRecord(const MdnsRecord& old_record,
+ const MdnsRecord& new_record) {
+ return publisher_->UpdateRegisteredRecord(old_record, new_record);
+}
+
+Error MdnsServiceImpl::UnregisterRecord(const MdnsRecord& record) {
+ return publisher_->UnregisterRecord(record);
+}
+
+void MdnsServiceImpl::OnError(UdpSocket* socket, Error error) {
+ reporting_client_->OnFatalError(error);
}
-void MdnsServiceImpl::RegisterRecord(const MdnsRecord& record) {
- // TODO(yakimakha): Implement this method
+void MdnsServiceImpl::OnSendError(UdpSocket* socket, Error error) {
+ sender_->OnSendError(socket, error);
}
-void MdnsServiceImpl::DeregisterRecord(const MdnsRecord& record) {
- // TODO(yakimakha): Implement this method
+void MdnsServiceImpl::OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) {
+ receiver_.OnRead(socket, std::move(packet));
}
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.h
index 072add12e23..7da1c1a9c24 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.h
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.h
@@ -5,26 +5,86 @@
#ifndef DISCOVERY_MDNS_MDNS_SERVICE_IMPL_H_
#define DISCOVERY_MDNS_MDNS_SERVICE_IMPL_H_
+#include "discovery/mdns/mdns_domain_confirmed_provider.h"
+#include "discovery/mdns/mdns_probe_manager.h"
+#include "discovery/mdns/mdns_publisher.h"
+#include "discovery/mdns/mdns_querier.h"
+#include "discovery/mdns/mdns_random.h"
+#include "discovery/mdns/mdns_reader.h"
+#include "discovery/mdns/mdns_receiver.h"
+#include "discovery/mdns/mdns_records.h"
+#include "discovery/mdns/mdns_responder.h"
+#include "discovery/mdns/mdns_sender.h"
+#include "discovery/mdns/mdns_writer.h"
+#include "discovery/mdns/public/mdns_constants.h"
#include "discovery/mdns/public/mdns_service.h"
+#include "platform/api/udp_socket.h"
namespace openscreen {
+
+class TaskRunner;
+
namespace discovery {
-class MdnsServiceImpl : public MdnsService {
+struct Config;
+class NetworkConfig;
+class ReportingClient;
+
+class MdnsServiceImpl : public MdnsService, public UdpSocket::Client {
public:
+ // |task_runner|, |reporting_client|, and |config| must exist for the duration
+ // of this instance's life.
+ MdnsServiceImpl(TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ ReportingClient* reporting_client,
+ const Config& config,
+ NetworkInterfaceIndex network_interface,
+ SupportedNetworkAddressFamily supported_address_types);
+ ~MdnsServiceImpl() override;
+
+ // MdnsService Overrides.
void StartQuery(const DomainName& name,
DnsType dns_type,
DnsClass dns_class,
MdnsRecordChangedCallback* callback) override;
-
void StopQuery(const DomainName& name,
DnsType dns_type,
DnsClass dns_class,
MdnsRecordChangedCallback* callback) override;
+ void ReinitializeQueries(const DomainName& name) override;
+ Error StartProbe(MdnsDomainConfirmedProvider* callback,
+ DomainName requested_name,
+ IPAddress address) override;
+
+ Error RegisterRecord(const MdnsRecord& record) override;
+ Error UpdateRegisteredRecord(const MdnsRecord& old_record,
+ const MdnsRecord& new_record) override;
+ Error UnregisterRecord(const MdnsRecord& record) override;
+
+ // UdpSocket::Client overrides.
+ void OnError(UdpSocket* socket, Error error) override;
+ void OnSendError(UdpSocket* socket, Error error) override;
+ void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override;
+
+ private:
+ TaskRunner* const task_runner_;
+ ClockNowFunctionPtr now_function_;
+ ReportingClient* const reporting_client_;
+
+ MdnsRandom random_delay_;
+ MdnsReceiver receiver_;
- void RegisterRecord(const MdnsRecord& record) override;
+ // Sockets to send and receive mDNS Data according to RFC 6762.
+ std::unique_ptr<UdpSocket> socket_v4_;
+ std::unique_ptr<UdpSocket> socket_v6_;
- void DeregisterRecord(const MdnsRecord& record) override;
+ // unique_ptrs are used for the below objects so that they can be initialized
+ // in the body of the ctor, after send_socket is initialized.
+ std::unique_ptr<MdnsSender> sender_;
+ std::unique_ptr<MdnsQuerier> querier_;
+ std::unique_ptr<MdnsProbeManagerImpl> probe_manager_;
+ std::unique_ptr<MdnsPublisher> publisher_;
+ std::unique_ptr<MdnsResponder> responder_;
};
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.cc
index 28b05b7fa73..e250582197b 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.cc
@@ -5,7 +5,9 @@
#include "discovery/mdns/mdns_trackers.h"
#include <array>
+#include <limits>
+#include "discovery/common/config.h"
#include "discovery/mdns/mdns_random.h"
#include "discovery/mdns/mdns_record_changed_callback.h"
#include "discovery/mdns/mdns_sender.h"
@@ -39,6 +41,18 @@ bool IsGoodbyeRecord(const MdnsRecord& record) {
return record.ttl() == std::chrono::seconds{0};
}
+bool IsNegativeResponseForType(const MdnsRecord& record, DnsType dns_type) {
+ if (record.dns_type() != DnsType::kNSEC) {
+ return false;
+ }
+
+ const auto& nsec_types = absl::get<NsecRecordRdata>(record.rdata()).types();
+ return std::find_if(nsec_types.begin(), nsec_types.end(),
+ [dns_type](DnsType type) {
+ return type == dns_type || type == DnsType::kANY;
+ }) != nsec_types.end();
+}
+
// RFC 6762 Section 10.1
// https://tools.ietf.org/html/rfc6762#section-10.1
// In case of a goodbye record, the querier should set TTL to 1 second
@@ -49,46 +63,131 @@ constexpr std::chrono::seconds kGoodbyeRecordTtl{1};
MdnsTracker::MdnsTracker(MdnsSender* sender,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
- MdnsRandom* random_delay)
+ MdnsRandom* random_delay,
+ TrackerType tracker_type)
: sender_(sender),
task_runner_(task_runner),
now_function_(now_function),
send_alarm_(now_function, task_runner),
- random_delay_(random_delay) {
- OSP_DCHECK(task_runner);
- OSP_DCHECK(now_function);
- OSP_DCHECK(random_delay);
- OSP_DCHECK(sender);
+ random_delay_(random_delay),
+ tracker_type_(tracker_type) {
+ OSP_DCHECK(task_runner_);
+ OSP_DCHECK(now_function_);
+ OSP_DCHECK(random_delay_);
+ OSP_DCHECK(sender_);
+}
+
+MdnsTracker::~MdnsTracker() {
+ send_alarm_.Cancel();
+
+ for (const MdnsTracker* node : adjacent_nodes_) {
+ node->RemovedReverseAdjacency(this);
+ }
+}
+
+bool MdnsTracker::AddAdjacentNode(const MdnsTracker* node) const {
+ OSP_DCHECK(node);
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ auto it = std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node);
+ if (it != adjacent_nodes_.end()) {
+ return false;
+ }
+
+ adjacent_nodes_.push_back(node);
+ node->AddReverseAdjacency(this);
+ return true;
+}
+
+bool MdnsTracker::RemoveAdjacentNode(const MdnsTracker* node) const {
+ OSP_DCHECK(node);
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+
+ auto it = std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node);
+ if (it == adjacent_nodes_.end()) {
+ return false;
+ }
+
+ adjacent_nodes_.erase(it);
+ node->RemovedReverseAdjacency(this);
+ return true;
+}
+
+void MdnsTracker::AddReverseAdjacency(const MdnsTracker* node) const {
+ OSP_DCHECK(std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node) ==
+ adjacent_nodes_.end());
+
+ adjacent_nodes_.push_back(node);
+}
+
+void MdnsTracker::RemovedReverseAdjacency(const MdnsTracker* node) const {
+ auto it = std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node);
+ OSP_DCHECK(it != adjacent_nodes_.end());
+
+ adjacent_nodes_.erase(it);
}
MdnsRecordTracker::MdnsRecordTracker(
MdnsRecord record,
+ DnsType dns_type,
MdnsSender* sender,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
MdnsRandom* random_delay,
- std::function<void(const MdnsRecord&)> record_expired_callback)
- : MdnsTracker(sender, task_runner, now_function, random_delay),
+ RecordExpiredCallback record_expired_callback)
+ : MdnsTracker(sender,
+ task_runner,
+ now_function,
+ random_delay,
+ TrackerType::kRecordTracker),
record_(std::move(record)),
+ dns_type_(dns_type),
start_time_(now_function_()),
- record_expired_callback_(record_expired_callback) {
- OSP_DCHECK(record_expired_callback);
+ record_expired_callback_(std::move(record_expired_callback)) {
+ OSP_DCHECK(record_expired_callback_);
+
+ // RecordTrackers cannot be created for tracking NSEC types or ANY types.
+ OSP_DCHECK(dns_type_ != DnsType::kNSEC);
+ OSP_DCHECK(dns_type_ != DnsType::kANY);
- send_alarm_.Schedule([this] { SendQuery(); }, GetNextSendTime());
+ // Validate that, if the provided |record| is an NSEC record, then it provides
+ // a negative response for |dns_type|.
+ OSP_DCHECK(record_.dns_type() != DnsType::kNSEC ||
+ IsNegativeResponseForType(record_, dns_type_));
+
+ ScheduleFollowUpQuery();
}
+MdnsRecordTracker::~MdnsRecordTracker() = default;
+
ErrorOr<MdnsRecordTracker::UpdateType> MdnsRecordTracker::Update(
const MdnsRecord& new_record) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
- bool has_same_rdata = record_.rdata() == new_record.rdata();
+ const bool has_same_rdata = record_.dns_type() == new_record.dns_type() &&
+ record_.rdata() == new_record.rdata();
+ const bool new_is_negative_response = new_record.dns_type() == DnsType::kNSEC;
+ const bool current_is_negative_response =
+ record_.dns_type() == DnsType::kNSEC;
+
+ if ((record_.dns_class() != new_record.dns_class()) ||
+ (record_.name() != new_record.name())) {
+ // The new record has been passed to a wrong tracker.
+ return Error::Code::kParameterInvalid;
+ }
+
+ // New response record must correspond to the correct type.
+ if ((!new_is_negative_response && new_record.dns_type() != dns_type_) ||
+ (new_is_negative_response &&
+ !IsNegativeResponseForType(new_record, dns_type_))) {
+ // The new record has been passed to a wrong tracker.
+ return Error::Code::kParameterInvalid;
+ }
// Goodbye records must have the same RDATA but TTL of 0.
- // RFC 6762 Section 10.1
+ // RFC 6762 Section 10.1.
// https://tools.ietf.org/html/rfc6762#section-10.1
- if ((record_.dns_type() != new_record.dns_type()) ||
- (record_.dns_class() != new_record.dns_class()) ||
- (record_.name() != new_record.name()) ||
- (IsGoodbyeRecord(new_record) && !has_same_rdata)) {
+ if (!new_is_negative_response && !current_is_negative_response &&
+ IsGoodbyeRecord(new_record) && !has_same_rdata) {
// The new record has been passed to a wrong tracker.
return Error::Code::kParameterInvalid;
}
@@ -99,9 +198,9 @@ ErrorOr<MdnsRecordTracker::UpdateType> MdnsRecordTracker::Update(
new_record.dns_class(), new_record.record_type(),
kGoodbyeRecordTtl, new_record.rdata());
- // Goodbye records do not need to be requeried, set the attempt count to the
- // last item, which is 100% of TTL, i.e. record expiration.
- attempt_count_ = openscreen::countof(kTtlFractions) - 1;
+ // Goodbye records do not need to be re-queried, set the attempt count to
+ // the last item, which is 100% of TTL, i.e. record expiration.
+ attempt_count_ = countof(kTtlFractions) - 1;
} else {
record_ = new_record;
attempt_count_ = 0;
@@ -109,11 +208,21 @@ ErrorOr<MdnsRecordTracker::UpdateType> MdnsRecordTracker::Update(
}
start_time_ = now_function_();
- send_alarm_.Schedule([this] { SendQuery(); }, GetNextSendTime());
+ ScheduleFollowUpQuery();
return result;
}
+bool MdnsRecordTracker::AddAssociatedQuery(
+ const MdnsQuestionTracker* question_tracker) const {
+ return AddAdjacentNode(question_tracker);
+}
+
+bool MdnsRecordTracker::RemoveAssociatedQuery(
+ const MdnsQuestionTracker* question_tracker) const {
+ return RemoveAdjacentNode(question_tracker);
+}
+
void MdnsRecordTracker::ExpireSoon() {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
@@ -122,35 +231,55 @@ void MdnsRecordTracker::ExpireSoon() {
record_.record_type(), kGoodbyeRecordTtl, record_.rdata());
// Set the attempt count to the last item, which is 100% of TTL, i.e. record
- // expiration, to prevent any requeries
- attempt_count_ = openscreen::countof(kTtlFractions) - 1;
+ // expiration, to prevent any re-queries
+ attempt_count_ = countof(kTtlFractions) - 1;
start_time_ = now_function_();
- send_alarm_.Schedule([this] { SendQuery(); }, GetNextSendTime());
+ ScheduleFollowUpQuery();
+}
+
+void MdnsRecordTracker::ExpireNow() {
+ record_expired_callback_(this, record_);
+}
+
+bool MdnsRecordTracker::IsNearingExpiry() const {
+ return (now_function_() - start_time_) > record_.ttl() / 2;
}
-void MdnsRecordTracker::SendQuery() {
+bool MdnsRecordTracker::SendQuery() const {
const Clock::time_point expiration_time = start_time_ + record_.ttl();
- const bool is_expired = (now_function_() >= expiration_time);
+ bool is_expired = (now_function_() >= expiration_time);
if (!is_expired) {
- MdnsQuestion question(record_.name(), record_.dns_type(),
- record_.dns_class(), ResponseType::kMulticast);
- MdnsMessage message(CreateMessageId(), MessageType::Query);
- message.AddQuestion(std::move(question));
- sender_->SendMulticast(message);
- send_alarm_.Schedule([this] { MdnsRecordTracker::SendQuery(); },
- GetNextSendTime());
+ for (const MdnsTracker* tracker : adjacent_nodes()) {
+ tracker->SendQuery();
+ }
} else {
- record_expired_callback_(record_);
+ record_expired_callback_(this, record_);
}
+
+ return !is_expired;
+}
+
+void MdnsRecordTracker::ScheduleFollowUpQuery() {
+ send_alarm_.Schedule(
+ [this] {
+ if (SendQuery()) {
+ ScheduleFollowUpQuery();
+ }
+ },
+ GetNextSendTime());
+}
+
+std::vector<MdnsRecord> MdnsRecordTracker::GetRecords() const {
+ return {record_};
}
-openscreen::platform::Clock::time_point MdnsRecordTracker::GetNextSendTime() {
- OSP_DCHECK(attempt_count_ < openscreen::countof(kTtlFractions));
+Clock::time_point MdnsRecordTracker::GetNextSendTime() {
+ OSP_DCHECK(attempt_count_ < countof(kTtlFractions));
double ttl_fraction = kTtlFractions[attempt_count_++];
// Do not add random variation to the expiration time (last fraction of TTL)
- if (attempt_count_ != openscreen::countof(kTtlFractions)) {
+ if (attempt_count_ != countof(kTtlFractions)) {
ttl_fraction += random_delay_->GetRecordTtlVariation();
}
@@ -163,25 +292,134 @@ MdnsQuestionTracker::MdnsQuestionTracker(MdnsQuestion question,
MdnsSender* sender,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
- MdnsRandom* random_delay)
- : MdnsTracker(sender, task_runner, now_function, random_delay),
+ MdnsRandom* random_delay,
+ const Config& config,
+ QueryType query_type)
+ : MdnsTracker(sender,
+ task_runner,
+ now_function,
+ random_delay,
+ TrackerType::kQuestionTracker),
question_(std::move(question)),
- send_delay_(kMinimumQueryInterval) {
+ send_delay_(kMinimumQueryInterval),
+ query_type_(query_type),
+ maximum_announcement_count_(config.new_query_announcement_count < 0
+ ? INT_MAX
+ : config.new_query_announcement_count) {
+ // Initialize the last send time to time_point::min() so that the next call to
+ // SendQuery() is guaranteed to query the network.
+ last_send_time_ = TrivialClockTraits::time_point::min();
+
// The initial query has to be sent after a random delay of 20-120
// milliseconds.
- const Clock::duration delay = random_delay_->GetInitialQueryDelay();
- send_alarm_.Schedule([this] { MdnsQuestionTracker::SendQuery(); },
- now_function_() + delay);
+ if (announcements_so_far_ < maximum_announcement_count_) {
+ announcements_so_far_++;
+
+ if (query_type_ == QueryType::kOneShot) {
+ task_runner_->PostTask([this] { MdnsQuestionTracker::SendQuery(); });
+ } else {
+ OSP_DCHECK(query_type_ == QueryType::kContinuous);
+ send_alarm_.ScheduleFromNow(
+ [this]() {
+ MdnsQuestionTracker::SendQuery();
+ ScheduleFollowUpQuery();
+ },
+ random_delay_->GetInitialQueryDelay());
+ }
+ }
+}
+
+MdnsQuestionTracker::~MdnsQuestionTracker() = default;
+
+bool MdnsQuestionTracker::AddAssociatedRecord(
+ const MdnsRecordTracker* record_tracker) const {
+ return AddAdjacentNode(record_tracker);
}
-void MdnsQuestionTracker::SendQuery() {
+bool MdnsQuestionTracker::RemoveAssociatedRecord(
+ const MdnsRecordTracker* record_tracker) const {
+ return RemoveAdjacentNode(record_tracker);
+}
+
+std::vector<MdnsRecord> MdnsQuestionTracker::GetRecords() const {
+ std::vector<MdnsRecord> records;
+ for (const MdnsTracker* tracker : adjacent_nodes()) {
+ OSP_DCHECK(tracker->tracker_type() == TrackerType::kRecordTracker);
+
+ // This call cannot result in an infinite loop because MdnsRecordTracker
+ // instances only return a single record from this call.
+ std::vector<MdnsRecord> node_records = tracker->GetRecords();
+ OSP_DCHECK(node_records.size() == 1);
+
+ records.push_back(std::move(node_records[0]));
+ }
+
+ return records;
+}
+
+bool MdnsQuestionTracker::SendQuery() const {
+ // NOTE: The RFC does not specify the minimum interval between queries for
+ // multiple records of the same query when initiated for different reasons
+ // (such as for different record refreshes or for one record refresh and the
+ // periodic re-querying for a continuous query). For this reason, a constant
+ // outside of scope of the RFC has been chosen.
+ TrivialClockTraits::time_point now = now_function_();
+ if (now < last_send_time_ + kMinimumQueryInterval) {
+ return true;
+ }
+ last_send_time_ = now;
+
MdnsMessage message(CreateMessageId(), MessageType::Query);
message.AddQuestion(question_);
- // TODO(yakimakha): Implement known-answer suppression by adding known
- // answers to the question
+
+ // Send the message and additional known answer packets as needed.
+ for (auto it = adjacent_nodes().begin(); it != adjacent_nodes().end();) {
+ OSP_DCHECK((*it)->tracker_type() == TrackerType::kRecordTracker);
+
+ const MdnsRecordTracker* record_tracker =
+ static_cast<const MdnsRecordTracker*>(*it);
+ if (record_tracker->IsNearingExpiry()) {
+ it++;
+ continue;
+ }
+
+ // A record tracker should only contain one record.
+ std::vector<MdnsRecord> node_records = (*it)->GetRecords();
+ OSP_DCHECK(node_records.size() == 1);
+ MdnsRecord node_record = std::move(node_records[0]);
+
+ if (message.CanAddRecord(node_record)) {
+ message.AddAnswer(std::move(node_record));
+ it++;
+ } else if (message.questions().empty() && message.answers().empty()) {
+ // This case should never happen, because it means a record is too large
+ // to fit into its own message.
+ OSP_LOG << "Encountered unreasonably large message in cache. Skipping "
+ << "known answer in suppressions...";
+ it++;
+ } else {
+ message.set_truncated();
+ sender_->SendMulticast(message);
+ message = MdnsMessage(CreateMessageId(), MessageType::Query);
+ }
+ }
sender_->SendMulticast(message);
- send_alarm_.Schedule([this] { MdnsQuestionTracker::SendQuery(); },
- now_function_() + send_delay_);
+ return true;
+}
+
+void MdnsQuestionTracker::ScheduleFollowUpQuery() {
+ if (announcements_so_far_ >= maximum_announcement_count_) {
+ return;
+ }
+ announcements_so_far_++;
+
+ send_alarm_.ScheduleFromNow(
+ [this] {
+ if (SendQuery()) {
+ ScheduleFollowUpQuery();
+ }
+ },
+ send_delay_);
send_delay_ = send_delay_ * kIntervalIncreaseFactor;
if (send_delay_ > kMaximumQueryInterval) {
send_delay_ = kMaximumQueryInterval;
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.h
index 75556f2c886..6cb863a94a1 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.h
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.h
@@ -6,27 +6,41 @@
#define DISCOVERY_MDNS_MDNS_TRACKERS_H_
#include <unordered_map>
+#include <vector>
#include "absl/hash/hash.h"
#include "discovery/mdns/mdns_records.h"
#include "platform/api/task_runner.h"
#include "platform/base/error.h"
+#include "platform/base/trivial_clock_traits.h"
#include "util/alarm.h"
namespace openscreen {
namespace discovery {
+struct Config;
class MdnsRandom;
class MdnsRecord;
class MdnsRecordChangedCallback;
class MdnsSender;
// MdnsTracker is a base class for MdnsRecordTracker and MdnsQuestionTracker for
-// the purposes of common code sharing only
+// the purposes of common code sharing only.
+//
+// Instances of this class represent nodes of a bidirectional graph, such that
+// if node A is adjacent to node B, B is also adjacent to A. In this class, the
+// adjacent nodes are stored in adjacency list |associated_tracker_|, and
+// exposed methods to add and remove nodes from this list also modify the added
+// or removed node to remove this instance from its adjacency list.
+//
+// Because MdnsQuestionTracker::AddAssocaitedRecord() can only called on
+// MdnsRecordTracker objects and MdnsRecordTracker::AddAssociatedQuery() is
+// only called on MdnsQuestionTracker objects, this created graph is bipartite.
+// This means that MdnsRecordTracker objects are only adjacent to
+// MdnsQuestionTracker objects and the opposite.
class MdnsTracker {
public:
- using ClockNowFunctionPtr = openscreen::platform::ClockNowFunctionPtr;
- using TaskRunner = openscreen::platform::TaskRunner;
+ enum class TrackerType { kRecordTracker, kQuestionTracker };
// MdnsTracker does not own |sender|, |task_runner| and |random_delay|
// and expects that the lifetime of these objects exceeds the lifetime of
@@ -34,34 +48,73 @@ class MdnsTracker {
MdnsTracker(MdnsSender* sender,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
- MdnsRandom* random_delay);
+ MdnsRandom* random_delay,
+ TrackerType tracker_type);
MdnsTracker(const MdnsTracker& other) = delete;
MdnsTracker(MdnsTracker&& other) noexcept = delete;
MdnsTracker& operator=(const MdnsTracker& other) = delete;
MdnsTracker& operator=(MdnsTracker&& other) noexcept = delete;
- ~MdnsTracker() = default;
+ virtual ~MdnsTracker();
+
+ // Returns the record type represented by this tracker.
+ TrackerType tracker_type() const { return tracker_type_; }
+
+ // Sends a query message via MdnsSender. Returns false if a follow up query
+ // should NOT be scheduled and true otherwise.
+ virtual bool SendQuery() const = 0;
+
+ // Returns the records currently associated with this tracker.
+ virtual std::vector<MdnsRecord> GetRecords() const = 0;
protected:
+ // Schedules a repeat query to be sent out.
+ virtual void ScheduleFollowUpQuery() = 0;
+
+ // These methods create a bidirectional adjacency with another node in the
+ // graph.
+ bool AddAdjacentNode(const MdnsTracker* tracker) const;
+ bool RemoveAdjacentNode(const MdnsTracker* tracker) const;
+
+ const std::vector<const MdnsTracker*>& adjacent_nodes() const {
+ return adjacent_nodes_;
+ }
+
MdnsSender* const sender_;
TaskRunner* const task_runner_;
const ClockNowFunctionPtr now_function_;
Alarm send_alarm_; // TODO(yakimakha): Use cancelable task when available
MdnsRandom* const random_delay_;
+ TrackerType tracker_type_;
+
+ private:
+ // These methods are used to ensure the bidirectional-ness of this graph.
+ void AddReverseAdjacency(const MdnsTracker* tracker) const;
+ void RemovedReverseAdjacency(const MdnsTracker* tracker) const;
+
+ // Adjacency list for this graph node.
+ mutable std::vector<const MdnsTracker*> adjacent_nodes_;
};
+class MdnsQuestionTracker;
+
// MdnsRecordTracker manages automatic resending of mDNS queries for
// refreshing records as they reach their expiration time.
class MdnsRecordTracker : public MdnsTracker {
public:
- using Clock = openscreen::platform::Clock;
+ using RecordExpiredCallback =
+ std::function<void(const MdnsRecordTracker*, const MdnsRecord&)>;
+
+ // NOTE: In the case that |record| is of type NSEC, |dns_type| is expected to
+ // differ from |record|'s type.
+ MdnsRecordTracker(MdnsRecord record,
+ DnsType dns_type,
+ MdnsSender* sender,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ MdnsRandom* random_delay,
+ RecordExpiredCallback record_expired_callback);
- MdnsRecordTracker(
- MdnsRecord record,
- MdnsSender* sender,
- TaskRunner* task_runner,
- ClockNowFunctionPtr now_function,
- MdnsRandom* random_delay,
- std::function<void(const MdnsRecord&)> record_expired_callback);
+ ~MdnsRecordTracker() override;
// Possible outcomes from updating a tracked record.
enum class UpdateType {
@@ -77,45 +130,108 @@ class MdnsRecordTracker : public MdnsTracker {
// for the current tracked record.
ErrorOr<UpdateType> Update(const MdnsRecord& new_record);
+ // Adds or removed a question which this record answers.
+ bool AddAssociatedQuery(const MdnsQuestionTracker* question_tracker) const;
+ bool RemoveAssociatedQuery(const MdnsQuestionTracker* question_tracker) const;
+
// Sets record to expire after 1 seconds as per RFC 6762
void ExpireSoon();
- // Returns a reference to the tracked record.
- const MdnsRecord& record() const { return record_; }
+ // Expires the record now
+ void ExpireNow();
+
+ // Returns true if half of the record's TTL has passed, and false otherwise.
+ // Half is used due to specifications in RFC 6762 section 7.1.
+ bool IsNearingExpiry() const;
+
+ // Returns information about the stored record.
+ //
+ // NOTE: These methods are NOT all pass-through methods to |record_|.
+ // specifically, dns_type() returns the DNS Type associated with this record
+ // tracker, which may be different from the record type if |record_| is of
+ // type NSEC. To avoid this case, direct access to the underlying |record_|
+ // instance is not provided.
+ //
+ // In this case, creating an MdnsRecord with the below data will result in a
+ // runtime error due to DCHECKS and that Rdata's associated type will not
+ // match DnsType when |record_| is of type NSEC. Therefore, creating such
+ // records should be guarded by is_negative_response() checks.
+ const DomainName& name() const { return record_.name(); }
+ DnsType dns_type() const { return dns_type_; }
+ DnsClass dns_class() const { return record_.dns_class(); }
+ RecordType record_type() const { return record_.record_type(); }
+ std::chrono::seconds ttl() const { return record_.ttl(); }
+ const Rdata& rdata() const { return record_.rdata(); }
+
+ bool is_negative_response() const {
+ return record_.dns_type() == DnsType::kNSEC;
+ }
private:
- void SendQuery();
+ using MdnsTracker::tracker_type;
+
+ // Needed to provide the test class access to the record stored in this
+ // tracker.
+ friend class MdnsTrackerTest;
+
Clock::time_point GetNextSendTime();
+ // MdnsTracker overrides.
+ bool SendQuery() const override;
+ void ScheduleFollowUpQuery() override;
+ std::vector<MdnsRecord> GetRecords() const override;
+
// Stores MdnsRecord provided to Start method call.
MdnsRecord record_;
+
+ // DnsType this record tracker represents. This may not match the type of
+ // |record_| if it is an NSEC record.
+ const DnsType dns_type_;
+
// A point in time when the record was received and the tracking has started.
Clock::time_point start_time_;
+
// Number of times record refresh has been attempted.
size_t attempt_count_ = 0;
- std::function<void(const MdnsRecord&)> record_expired_callback_;
+ RecordExpiredCallback record_expired_callback_;
};
// MdnsQuestionTracker manages automatic resending of mDNS queries for
-// continuous monitoring with exponential back-off as described in RFC 6762
+// continuous monitoring with exponential back-off as described in RFC 6762.
class MdnsQuestionTracker : public MdnsTracker {
public:
- using Clock = openscreen::platform::Clock;
+ // Supported query types, per RFC 6762 section 5.
+ enum class QueryType { kOneShot, kContinuous };
MdnsQuestionTracker(MdnsQuestion question,
MdnsSender* sender,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
- MdnsRandom* random_delay);
+ MdnsRandom* random_delay,
+ const Config& config,
+ QueryType query_type = QueryType::kContinuous);
+
+ ~MdnsQuestionTracker() override;
+
+ // Adds or removed an answer to a the question posed by this tracker.
+ bool AddAssociatedRecord(const MdnsRecordTracker* record_tracker) const;
+ bool RemoveAssociatedRecord(const MdnsRecordTracker* record_tracker) const;
// Returns a reference to the tracked question.
const MdnsQuestion& question() const { return question_; }
private:
+ using MdnsTracker::tracker_type;
+
using RecordKey = std::tuple<DomainName, DnsType, DnsClass>;
- // Sends a query message via MdnsSender and schedules the next resend.
- void SendQuery();
+ // Determines if all answers to this query have been received.
+ bool HasReceivedAllResponses();
+
+ // MdnsTracker overrides.
+ bool SendQuery() const override;
+ void ScheduleFollowUpQuery() override;
+ std::vector<MdnsRecord> GetRecords() const override;
// Stores MdnsQuestion provided to Start method call.
MdnsQuestion question_;
@@ -123,15 +239,18 @@ class MdnsQuestionTracker : public MdnsTracker {
// A delay between the currently scheduled and the next queries.
Clock::duration send_delay_;
- // Active record trackers, uniquely identified by domain name, DNS record type
- // and DNS record class. MdnsRecordTracker instances are stored as unique_ptr
- // so they are not moved around in memory when the collection is modified.
- // This allows passing a pointer to MdnsRecordTracker to a task running on the
- // TaskRunner.
- std::unordered_map<RecordKey,
- std::unique_ptr<MdnsRecordTracker>,
- absl::Hash<RecordKey>>
- record_trackers_;
+ // Last time that this tracker's question was asked.
+ mutable TrivialClockTraits::time_point last_send_time_;
+
+ // Specifies whether this query is intended to be a one-shot query, as defined
+ // in RFC 6762 section 5.1.
+ const QueryType query_type_;
+
+ // Signifies the maximum number of times a record should be announced.
+ int maximum_announcement_count_;
+
+ // Number of times this query has been announced.
+ int announcements_so_far_ = 0;
};
} // namespace discovery
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers_unittest.cc
index 91b33678b91..399e0d46145 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers_unittest.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers_unittest.cc
@@ -4,6 +4,10 @@
#include "discovery/mdns/mdns_trackers.h"
+#include <memory>
+#include <utility>
+
+#include "discovery/common/config.h"
#include "discovery/mdns/mdns_random.h"
#include "discovery/mdns/mdns_record_changed_callback.h"
#include "discovery/mdns/mdns_sender.h"
@@ -11,51 +15,49 @@
#include "gtest/gtest.h"
#include "platform/test/fake_clock.h"
#include "platform/test/fake_task_runner.h"
+#include "platform/test/fake_udp_socket.h"
namespace openscreen {
namespace discovery {
+namespace {
+
+constexpr Clock::duration kOneSecond =
+ std::chrono::duration_cast<Clock::duration>(std::chrono::seconds(1));
+
+}
-using openscreen::platform::Clock;
-using openscreen::platform::FakeClock;
-using openscreen::platform::FakeTaskRunner;
-using openscreen::platform::NetworkInterfaceIndex;
-using openscreen::platform::TaskRunner;
-using openscreen::platform::UdpSocket;
using testing::_;
using testing::Args;
+using testing::DoAll;
using testing::Invoke;
using testing::Return;
+using testing::StrictMock;
using testing::WithArgs;
ACTION_P2(VerifyMessageBytesWithoutId, expected_data, expected_size) {
const uint8_t* actual_data = reinterpret_cast<const uint8_t*>(arg0);
const size_t actual_size = arg1;
- EXPECT_EQ(actual_size, expected_size);
+ ASSERT_EQ(actual_size, expected_size);
// Start at bytes[2] to skip a generated message ID.
for (size_t i = 2; i < actual_size; ++i) {
EXPECT_EQ(actual_data[i], expected_data[i]);
}
}
-class MockUdpSocket : public UdpSocket {
+ACTION_P(VerifyTruncated, is_truncated) {
+ EXPECT_EQ(arg0.is_truncated(), is_truncated);
+}
+
+ACTION_P(VerifyRecordCount, record_count) {
+ EXPECT_EQ(arg0.answers().size(), static_cast<size_t>(record_count));
+}
+
+class MockMdnsSender : public MdnsSender {
public:
- MOCK_METHOD(bool, IsIPv4, (), (const, override));
- MOCK_METHOD(bool, IsIPv6, (), (const, override));
- MOCK_METHOD(IPEndpoint, GetLocalEndpoint, (), (const, override));
- MOCK_METHOD(void, Bind, (), (override));
- MOCK_METHOD(void,
- SetMulticastOutboundInterface,
- (NetworkInterfaceIndex),
- (override));
- MOCK_METHOD(void,
- JoinMulticastGroup,
- (const IPAddress&, NetworkInterfaceIndex),
- (override));
- MOCK_METHOD(void,
- SendMessage,
- (const void*, size_t, const IPEndpoint&),
- (override));
- MOCK_METHOD(void, SetDscp, (DscpMode), (override));
+ explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {}
+
+ MOCK_METHOD1(SendMulticast, Error(const MdnsMessage&));
+ MOCK_METHOD2(SendMessage, Error(const MdnsMessage&, const IPEndpoint&));
};
class MockRecordChangedCallback : public MdnsRecordChangedCallback {
@@ -71,6 +73,7 @@ class MdnsTrackerTest : public testing::Test {
MdnsTrackerTest()
: clock_(Clock::now()),
task_runner_(&clock_),
+ socket_(&task_runner_),
sender_(&socket_),
a_question_(DomainName{"testing", "local"},
DnsType::kANY,
@@ -81,31 +84,64 @@ class MdnsTrackerTest : public testing::Test {
DnsClass::kIN,
RecordType::kShared,
std::chrono::seconds(120),
- ARecordRdata(IPAddress{172, 0, 0, 1})) {}
+ ARecordRdata(IPAddress{172, 0, 0, 1})),
+ nsec_record_(
+ DomainName{"testing", "local"},
+ DnsType::kNSEC,
+ DnsClass::kIN,
+ RecordType::kShared,
+ std::chrono::seconds(120),
+ NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA)) {}
template <class TrackerType>
void TrackerNoQueryAfterDestruction(TrackerType tracker) {
tracker.reset();
- EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0);
// Advance fake clock by a long time interval to make sure if there's a
// scheduled task, it will run.
clock_.Advance(std::chrono::hours(1));
}
std::unique_ptr<MdnsRecordTracker> CreateRecordTracker(
- const MdnsRecord& record) {
+ const MdnsRecord& record,
+ DnsType type) {
return std::make_unique<MdnsRecordTracker>(
- record, &sender_, &task_runner_, &FakeClock::now, &random_,
- [this](const MdnsRecord& record) { expiration_called_ = true; });
+ record, type, &sender_, &task_runner_, &FakeClock::now, &random_,
+ [this](const MdnsRecordTracker* tracker, const MdnsRecord& record) {
+ expiration_called_ = true;
+ });
+ }
+
+ std::unique_ptr<MdnsRecordTracker> CreateRecordTracker(
+ const MdnsRecord& record) {
+ return CreateRecordTracker(record, record.dns_type());
}
std::unique_ptr<MdnsQuestionTracker> CreateQuestionTracker(
- const MdnsQuestion& question) {
- return std::make_unique<MdnsQuestionTracker>(
- question, &sender_, &task_runner_, &FakeClock::now, &random_);
+ const MdnsQuestion& question,
+ MdnsQuestionTracker::QueryType query_type =
+ MdnsQuestionTracker::QueryType::kContinuous) {
+ return std::make_unique<MdnsQuestionTracker>(question, &sender_,
+ &task_runner_, &FakeClock::now,
+ &random_, config_, query_type);
}
protected:
+ void AdvanceThroughAllTtlFractions(std::chrono::seconds ttl) {
+ constexpr double kTtlFractions[] = {0.83, 0.88, 0.93, 0.98, 1.00};
+ Clock::duration time_passed{0};
+ for (double fraction : kTtlFractions) {
+ Clock::duration time_till_refresh =
+ std::chrono::duration_cast<Clock::duration>(ttl * fraction);
+ Clock::duration delta = time_till_refresh - time_passed;
+ time_passed = time_till_refresh;
+ clock_.Advance(delta);
+ }
+ }
+
+ const MdnsRecord& GetRecord(MdnsRecordTracker* tracker) {
+ return tracker->record_;
+ }
+
// clang-format off
const std::vector<uint8_t> kQuestionQueryBytes = {
0x00, 0x00, // ID = 0
@@ -138,14 +174,16 @@ class MdnsTrackerTest : public testing::Test {
};
// clang-format on
+ Config config_;
FakeClock clock_;
FakeTaskRunner task_runner_;
- MockUdpSocket socket_;
- MdnsSender sender_;
+ FakeUdpSocket socket_;
+ StrictMock<MockMdnsSender> sender_;
MdnsRandom random_;
MdnsQuestion a_question_;
MdnsRecord a_record_;
+ MdnsRecord nsec_record_;
bool expiration_called_ = false;
};
@@ -160,32 +198,58 @@ class MdnsTrackerTest : public testing::Test {
TEST_F(MdnsTrackerTest, RecordTrackerRecordAccessor) {
std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_);
- EXPECT_EQ(tracker->record(), a_record_);
+ EXPECT_EQ(GetRecord(tracker.get()), a_record_);
}
-TEST_F(MdnsTrackerTest, RecordTrackerQueryAfterDelay) {
+TEST_F(MdnsTrackerTest, RecordTrackerQueryAfterDelayPerQuestionTracker) {
+ std::unique_ptr<MdnsQuestionTracker> question = CreateQuestionTracker(
+ a_question_, MdnsQuestionTracker::QueryType::kOneShot);
+ std::unique_ptr<MdnsQuestionTracker> question2 = CreateQuestionTracker(
+ a_question_, MdnsQuestionTracker::QueryType::kOneShot);
+ EXPECT_CALL(sender_, SendMulticast(_)).Times(2);
+ clock_.Advance(kOneSecond);
+ clock_.Advance(kOneSecond);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+
std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_);
- // Only expect 4 queries being sent, when record reaches it's TTL it's
- // considered expired and another query is not sent
- constexpr double kTtlFractions[] = {0.83, 0.88, 0.93, 0.98, 1.00};
- EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(4);
-
- Clock::duration time_passed{0};
- for (double fraction : kTtlFractions) {
- Clock::duration time_till_refresh =
- std::chrono::duration_cast<Clock::duration>(a_record_.ttl() * fraction);
- Clock::duration delta = time_till_refresh - time_passed;
- time_passed = time_till_refresh;
- clock_.Advance(delta);
- }
+
+ // No queries without an associated tracker.
+ AdvanceThroughAllTtlFractions(a_record_.ttl());
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+
+ // 4 queries with one associated tracker.
+ tracker = CreateRecordTracker(a_record_);
+ tracker->AddAssociatedQuery(question.get());
+ EXPECT_CALL(sender_, SendMulticast(_)).Times(4);
+ AdvanceThroughAllTtlFractions(a_record_.ttl());
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+
+ // 8 queries with two associated trackers.
+ tracker = CreateRecordTracker(a_record_);
+ tracker->AddAssociatedQuery(question.get());
+ tracker->AddAssociatedQuery(question2.get());
+ EXPECT_CALL(sender_, SendMulticast(_)).Times(8);
+ AdvanceThroughAllTtlFractions(a_record_.ttl());
}
TEST_F(MdnsTrackerTest, RecordTrackerSendsMessage) {
+ std::unique_ptr<MdnsQuestionTracker> question = CreateQuestionTracker(
+ a_question_, MdnsQuestionTracker::QueryType::kOneShot);
+ EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
+ clock_.Advance(kOneSecond);
+ clock_.Advance(kOneSecond);
+ testing::Mock::VerifyAndClearExpectations(&sender_);
+
std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_);
+ tracker->AddAssociatedQuery(question.get());
- EXPECT_CALL(socket_, SendMessage(_, _, _))
- .WillOnce(WithArgs<0, 1>(VerifyMessageBytesWithoutId(
- kRecordQueryBytes.data(), kRecordQueryBytes.size())));
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .Times(1)
+ .WillRepeatedly([this](const MdnsMessage& message) -> Error {
+ EXPECT_EQ(message.questions().size(), size_t{1});
+ EXPECT_EQ(message.questions()[0], a_question_);
+ return Error::None();
+ });
clock_.Advance(
std::chrono::duration_cast<Clock::duration>(a_record_.ttl() * 0.83));
@@ -202,7 +266,6 @@ TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterLateTask) {
// no query and instead the record will expire.
// Check lower bound for task being late (TTL) and an arbitrarily long time
// interval to ensure the query is not sent a later time.
- EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0);
clock_.Advance(a_record_.ttl());
clock_.Advance(std::chrono::hours(1));
}
@@ -232,6 +295,16 @@ TEST_F(MdnsTrackerTest, RecordTrackerForceExpiration) {
EXPECT_TRUE(expiration_called_);
}
+TEST_F(MdnsTrackerTest, NsecRecordTrackerForceExpiration) {
+ expiration_called_ = false;
+ std::unique_ptr<MdnsRecordTracker> tracker =
+ CreateRecordTracker(nsec_record_, DnsType::kA);
+ tracker->ExpireSoon();
+ // Expire schedules expiration after 1 second.
+ clock_.Advance(std::chrono::seconds(1));
+ EXPECT_TRUE(expiration_called_);
+}
+
TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallback) {
expiration_called_ = false;
std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_);
@@ -250,9 +323,6 @@ TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallbackAfterGoodbye) {
EXPECT_EQ(tracker->Update(goodbye_record).value(),
MdnsRecordTracker::UpdateType::kGoodbye);
- // No refresh queries are sent after goodbye record is received.
- EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0);
-
// Advance clock to just before the expiration time of 1 second.
clock_.Advance(std::chrono::microseconds{999999});
EXPECT_FALSE(expiration_called_);
@@ -261,7 +331,7 @@ TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallbackAfterGoodbye) {
EXPECT_TRUE(expiration_called_);
}
-TEST_F(MdnsTrackerTest, RecordTrackerInvalidUpdate) {
+TEST_F(MdnsTrackerTest, RecordTrackerInvalidPositiveRecordUpdate) {
std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_);
MdnsRecord invalid_name(DomainName{"invalid"}, a_record_.dns_type(),
@@ -292,6 +362,71 @@ TEST_F(MdnsTrackerTest, RecordTrackerInvalidUpdate) {
Error::Code::kParameterInvalid);
}
+TEST_F(MdnsTrackerTest, RecordTrackerUpdatePositiveResponseWithNegative) {
+ // Check valid update.
+ std::unique_ptr<MdnsRecordTracker> tracker =
+ CreateRecordTracker(a_record_, DnsType::kA);
+ auto result = tracker->Update(nsec_record_);
+ ASSERT_TRUE(result.is_value());
+ EXPECT_EQ(result.value(), MdnsRecordTracker::UpdateType::kRdata);
+ EXPECT_EQ(GetRecord(tracker.get()), nsec_record_);
+
+ // Check invalid update.
+ MdnsRecord non_a_nsec_record(
+ nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(),
+ nsec_record_.record_type(), nsec_record_.ttl(),
+ NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kAAAA));
+ tracker = CreateRecordTracker(a_record_, DnsType::kA);
+ auto response = tracker->Update(non_a_nsec_record);
+ ASSERT_TRUE(response.is_error());
+ EXPECT_EQ(GetRecord(tracker.get()), a_record_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerUpdateNegativeResponseWithNegative) {
+ // Check valid update.
+ std::unique_ptr<MdnsRecordTracker> tracker =
+ CreateRecordTracker(nsec_record_, DnsType::kA);
+ MdnsRecord multiple_nsec_record(
+ nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(),
+ nsec_record_.record_type(), nsec_record_.ttl(),
+ NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA,
+ DnsType::kAAAA));
+ auto result = tracker->Update(multiple_nsec_record);
+ ASSERT_TRUE(result.is_value());
+ EXPECT_EQ(result.value(), MdnsRecordTracker::UpdateType::kRdata);
+ EXPECT_EQ(GetRecord(tracker.get()), multiple_nsec_record);
+
+ // Check invalid update.
+ tracker = CreateRecordTracker(nsec_record_, DnsType::kA);
+ MdnsRecord non_a_nsec_record(
+ nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(),
+ nsec_record_.record_type(), nsec_record_.ttl(),
+ NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kAAAA));
+ auto response = tracker->Update(non_a_nsec_record);
+ EXPECT_TRUE(response.is_error());
+ EXPECT_EQ(GetRecord(tracker.get()), nsec_record_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerUpdateNegativeResponseWithPositive) {
+ // Check valid update.
+ std::unique_ptr<MdnsRecordTracker> tracker =
+ CreateRecordTracker(nsec_record_, DnsType::kA);
+ auto result = tracker->Update(a_record_);
+ ASSERT_TRUE(result.is_value());
+ EXPECT_EQ(result.value(), MdnsRecordTracker::UpdateType::kRdata);
+ EXPECT_EQ(GetRecord(tracker.get()), a_record_);
+
+ // Check invalid update.
+ tracker = CreateRecordTracker(nsec_record_, DnsType::kA);
+ MdnsRecord aaaa_record(a_record_.name(), DnsType::kAAAA,
+ a_record_.dns_class(), a_record_.record_type(),
+ std::chrono::seconds{0},
+ AAAARecordRdata(IPAddress{0, 0, 0, 0, 0, 0, 0, 1}));
+ result = tracker->Update(aaaa_record);
+ EXPECT_TRUE(result.is_error());
+ EXPECT_EQ(GetRecord(tracker.get()), nsec_record_);
+}
+
TEST_F(MdnsTrackerTest, RecordTrackerNoExpirationCallbackAfterDestruction) {
expiration_called_ = false;
std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_);
@@ -316,12 +451,16 @@ TEST_F(MdnsTrackerTest, QuestionTrackerQueryAfterDelay) {
std::unique_ptr<MdnsQuestionTracker> tracker =
CreateQuestionTracker(a_question_);
- EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(1);
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce(
+ DoAll(WithArgs<0>(VerifyTruncated(false)), Return(Error::None())));
clock_.Advance(std::chrono::milliseconds(120));
std::chrono::seconds interval{1};
while (interval < std::chrono::hours(1)) {
- EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(1);
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce(
+ DoAll(WithArgs<0>(VerifyTruncated(false)), Return(Error::None())));
clock_.Advance(interval);
interval *= 2;
}
@@ -331,9 +470,15 @@ TEST_F(MdnsTrackerTest, QuestionTrackerSendsMessage) {
std::unique_ptr<MdnsQuestionTracker> tracker =
CreateQuestionTracker(a_question_);
- EXPECT_CALL(socket_, SendMessage(_, _, _))
- .WillOnce(WithArgs<0, 1>(VerifyMessageBytesWithoutId(
- kQuestionQueryBytes.data(), kQuestionQueryBytes.size())));
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce(DoAll(
+ WithArgs<0>(VerifyTruncated(false)),
+ [this](const MdnsMessage& message) -> Error {
+ EXPECT_EQ(message.questions().size(), size_t{1});
+ EXPECT_EQ(message.questions()[0], a_question_);
+ return Error::None();
+ },
+ Return(Error::None())));
clock_.Advance(std::chrono::milliseconds(120));
}
@@ -344,5 +489,30 @@ TEST_F(MdnsTrackerTest, QuestionTrackerNoQueryAfterDestruction) {
TrackerNoQueryAfterDestruction(std::move(tracker));
}
+TEST_F(MdnsTrackerTest, QuestionTrackerSendsMultipleMessages) {
+ std::unique_ptr<MdnsQuestionTracker> tracker =
+ CreateQuestionTracker(a_question_);
+
+ std::vector<std::unique_ptr<MdnsRecordTracker>> answers;
+ for (int i = 0; i < 100; i++) {
+ auto record = CreateRecordTracker(a_record_);
+ tracker->AddAssociatedRecord(record.get());
+ answers.push_back(std::move(record));
+ }
+
+ EXPECT_CALL(sender_, SendMulticast(_))
+ .WillOnce(DoAll(WithArgs<0>(VerifyTruncated(true)),
+ WithArgs<0>(VerifyRecordCount(49)),
+ Return(Error::None())))
+ .WillOnce(DoAll(WithArgs<0>(VerifyTruncated(true)),
+ WithArgs<0>(VerifyRecordCount(50)),
+ Return(Error::None())))
+ .WillOnce(DoAll(WithArgs<0>(VerifyTruncated(false)),
+ WithArgs<0>(VerifyRecordCount(1)),
+ Return(Error::None())));
+
+ clock_.Advance(std::chrono::milliseconds(120));
+}
+
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.cc
index 58037f7e9a9..fa79dc5d44c 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.cc
@@ -6,6 +6,7 @@
#include "absl/hash/hash.h"
#include "absl/strings/ascii.h"
+#include "util/hashing.h"
#include "util/logging.h"
namespace openscreen {
@@ -14,25 +15,14 @@ namespace discovery {
namespace {
std::vector<uint64_t> ComputeDomainNameSubhashes(const DomainName& name) {
- // Based on absl Hash128to64 that combines two 64-bit hashes into one
- auto hash_combiner = [](uint64_t seed, const std::string& value) -> uint64_t {
- static const uint64_t kMultiplier = UINT64_C(0x9ddfea08eb382d69);
- const uint64_t hash_value = absl::Hash<std::string>{}(value);
- uint64_t a = (hash_value ^ seed) * kMultiplier;
- a ^= (a >> 47);
- uint64_t b = (seed ^ a) * kMultiplier;
- b ^= (b >> 47);
- b *= kMultiplier;
- return b;
- };
-
const std::vector<std::string>& labels = name.labels();
// Use a large prime between 2^63 and 2^64 as a starting value.
// This is taken from absl::Hash implementation.
uint64_t hash_value = UINT64_C(0xc3a5c85c97cb3127);
std::vector<uint64_t> subhashes(labels.size());
for (size_t i = labels.size(); i-- > 0;) {
- hash_value = hash_combiner(hash_value, absl::AsciiStrToLower(labels[i]));
+ hash_value =
+ ComputeAggregateHash(hash_value, absl::AsciiStrToLower(labels[i]));
subhashes[i] = hash_value;
}
return subhashes;
@@ -201,6 +191,17 @@ bool MdnsWriter::Write(const TxtRecordRdata& rdata) {
return true;
}
+bool MdnsWriter::Write(const NsecRecordRdata& rdata) {
+ Cursor cursor(this);
+ if (Skip(sizeof(uint16_t)) && Write(rdata.next_domain_name()) &&
+ Write(rdata.encoded_types()) &&
+ UpdateRecordLength(current(), cursor.origin())) {
+ cursor.Commit();
+ return true;
+ }
+ return false;
+}
+
bool MdnsWriter::Write(const MdnsRecord& record) {
Cursor cursor(this);
if (Write(record.name()) && Write(static_cast<uint16_t>(record.dns_type())) &&
@@ -229,7 +230,7 @@ bool MdnsWriter::Write(const MdnsMessage& message) {
Cursor cursor(this);
Header header;
header.id = message.id();
- header.flags = MakeFlags(message.type());
+ header.flags = MakeFlags(message.type(), message.is_truncated());
header.question_count = message.questions().size();
header.answer_count = message.answers().size();
header.authority_record_count = message.authority_records().size();
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.h
index 66c1b036126..3b8a3b08026 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.h
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.h
@@ -13,7 +13,7 @@
namespace openscreen {
namespace discovery {
-class MdnsWriter : public openscreen::BigEndianWriter {
+class MdnsWriter : public BigEndianWriter {
public:
using BigEndianWriter::BigEndianWriter;
using BigEndianWriter::Write;
@@ -31,6 +31,7 @@ class MdnsWriter : public openscreen::BigEndianWriter {
bool Write(const AAAARecordRdata& rdata);
bool Write(const PtrRecordRdata& rdata);
bool Write(const TxtRecordRdata& rdata);
+ bool Write(const NsecRecordRdata& rdata);
// Writes a DNS resource record with its RDATA.
// The correct type of RDATA to be written is contained in the type
// specified in the record.
@@ -59,7 +60,7 @@ class MdnsWriter : public openscreen::BigEndianWriter {
// Domain name compression dictionary.
// Maps hashes of previously written domain (sub)names
- // to the label pointers of the first occurences in the underlying buffer.
+ // to the label pointers of the first occurrences in the underlying buffer.
// Compression of multiple domain names is supported on the same instance of
// the MdnsWriter. Underlying buffer may contain other data in addition to the
// domain names. The compression dictionary persists between calls to
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer_unittest.cc
index ad0c2464060..03ddbe8e50f 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer_unittest.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer_unittest.cc
@@ -5,6 +5,7 @@
#include "discovery/mdns/mdns_writer.h"
#include <memory>
+#include <vector>
#include "discovery/mdns/testing/mdns_test_util.h"
#include "gmock/gmock.h"
@@ -244,15 +245,55 @@ TEST(MdnsWriterTest, WriteAAAARecordRdata) {
TEST(MdnsWriterTest, WriteAAAARecordRdata_InsufficientBuffer) {
// clang-format off
- constexpr uint8_t kAAAARdata[] = {
+ constexpr uint16_t kAAAARdata[] = {
// ADDRESS = FE80:0000:0000:0000:0202:B3FF:FE1E:8329
- 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
- 0x02, 0x02, 0xb3, 0xff, 0xfe, 0x1e, 0x83, 0x29,
+ 0xfe80, 0x0000, 0x0000, 0x0000,
+ 0x0202, 0xb3ff, 0xfe1e, 0x8329,
};
// clang-format on
TestWriteEntryInsufficientBuffer(AAAARecordRdata(IPAddress(kAAAARdata)));
}
+TEST(MdnsWriterTest, WriteNSECRecordRdata) {
+ const DomainName domain{"testing", "local"};
+ NsecRecordRdata(DomainName{"mydevice", "testing", "local"}, DnsType::kA,
+ DnsType::kTXT, DnsType::kSRV, DnsType::kNSEC);
+
+ // clang-format off
+ constexpr uint8_t kExpectedRdata[] = {
+ 0x00, 0x20, // RDLENGTH = 32
+ 0x08, 'm', 'y', 'd', 'e', 'v', 'i', 'c', 'e',
+ 0x07, 't', 'e', 's', 't', 'i', 'n', 'g',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ // It takes 8 bytes to encode the kA and kSRV records because:
+ // - Both record types have value less than 256, so they are both in window
+ // block 1.
+ // - The bitmap length for this block is always a single byte
+ // - DnsTypes have the following values:
+ // - kA = 1 (encoded in byte 1)
+ // kTXT = 16 (encoded in byte 3)
+ // - kSRV = 33 (encoded in byte 5)
+ // - kNSEC = 47 (encoded in 6 bytes)
+ // - The largest of these is 47, so 6 bytes are needed to encode this data.
+ // So the full encoded version is:
+ // 00000000 00000110 01000000 00000000 10000000 00000000 0100000 00000001
+ // |window| | size | | 0-7 | | 8-15 | |16-23 | |24-31 | |32-39 | |40-47 |
+ 0x00, 0x06, 0x40, 0x00, 0x80, 0x00, 0x40, 0x01
+ };
+ // clang-format on
+ TestWriteEntrySucceeds(
+ NsecRecordRdata(DomainName{"mydevice", "testing", "local"}, DnsType::kA,
+ DnsType::kTXT, DnsType::kSRV, DnsType::kNSEC),
+ kExpectedRdata, sizeof(kExpectedRdata));
+}
+
+TEST(MdnsWriterTest, WriteNSECRecordRdata_InsufficientBuffer) {
+ TestWriteEntryInsufficientBuffer(
+ NsecRecordRdata(DomainName{"mydevice", "testing", "local"}, DnsType::kA,
+ DnsType::kTXT, DnsType::kSRV, DnsType::kNSEC));
+}
+
TEST(MdnsWriterTest, WritePtrRecordRdata) {
// clang-format off
constexpr uint8_t kExpectedRdata[] = {
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_constants.h b/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_constants.h
index 30847e42643..e80e68268b8 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_constants.h
+++ b/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_constants.h
@@ -16,6 +16,13 @@
#include <stddef.h>
#include <stdint.h>
+#include <array>
+#include <chrono>
+
+#include "platform/api/time.h"
+#include "platform/base/ip_address.h"
+#include "util/logging.h"
+
namespace openscreen {
namespace discovery {
@@ -28,18 +35,25 @@ namespace discovery {
// RFC 5771: https://www.ietf.org/rfc/rfc5771.txt
// RFC 7346: https://www.ietf.org/rfc/rfc7346.txt
-// IPv4 group address for joining mDNS multicast group, given as byte array in
-// network-order. This is a link-local multicast address, so messages will not
-// be forwarded outside local network. See RFC 6762, section 3.
-constexpr uint8_t kDefaultMulticastGroupIPv4[4] = {224, 0, 0, 251};
+// Default multicast port used by mDNS protocol. On some systems there may be
+// multiple processes binding to same port, so prefer to allow address re-use.
+// See RFC 6762, Section 2
+constexpr uint16_t kDefaultMulticastPort = 5353;
-// IPv6 group address for joining mDNS multicast group, given as byte array in
+// IPv4 group address for joining mDNS multicast group, given as byte array in
// network-order. This is a link-local multicast address, so messages will not
// be forwarded outside local network. See RFC 6762, section 3.
-constexpr uint8_t kDefaultMulticastGroupIPv6[16] = {
- 0xFF, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFB,
+const IPAddress kDefaultMulticastGroupIPv4{224, 0, 0, 251};
+const IPEndpoint kDefaultMulticastGroupIPv4Endpoint{{}, kDefaultMulticastPort};
+
+// IPv6 group address for joining mDNS multicast group. This is a link-local
+// multicast address, so messages will not be forwarded outside local network.
+// See RFC 6762, section 3.
+const IPAddress kDefaultMulticastGroupIPv6{
+ 0xFF02, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x00FB,
};
+const IPEndpoint kDefaultMulticastGroupIPv6Endpoint{{0, 0, 0, 0, 0, 0, 0, 0},
+ kDefaultMulticastPort};
// IPv4 group address for joining cast-specific site-local mDNS multicast group,
// given as byte array in network-order. This is a site-local multicast address,
@@ -57,7 +71,9 @@ constexpr uint8_t kDefaultMulticastGroupIPv6[16] = {
// NOTE: For now the group address is the same group address used for SSDP
// discovery, albeit using the MDNS port rather than SSDP port.
-constexpr uint8_t kDefaultSiteLocalGroupIPv4[4] = {239, 255, 255, 250};
+const IPAddress kDefaultSiteLocalGroupIPv4{239, 255, 255, 250};
+const IPEndpoint kDefaultSiteLocalGroupIPv4Endpoint{kDefaultSiteLocalGroupIPv4,
+ kDefaultMulticastPort};
// IPv6 group address for joining cast-specific site-local mDNS multicast group,
// give as byte array in network-order. See comments for IPv4 group address for
@@ -65,15 +81,11 @@ constexpr uint8_t kDefaultSiteLocalGroupIPv4[4] = {239, 255, 255, 250};
// 0xFF05 is site-local. See RFC 7346.
// FF0X:0:0:0:0:0:0:C is variable scope multicast addresses for SSDP. See
// https://www.iana.org/assignments/ipv6-multicast-addresses
-constexpr uint8_t kDefaultSiteLocalGroupIPv6[16] = {
- 0xFF, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0C,
+const IPAddress kDefaultSiteLocalGroupIPv6{
+ 0xFF05, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x000C,
};
-
-// Default multicast port used by mDNS protocol. On some systems there may be
-// multiple processes binding to same port, so prefer to allow address re-use.
-// See RFC 6762, Section 2
-constexpr uint16_t kDefaultMulticastPort = 5353;
+const IPEndpoint kDefaultSiteLocalGroupIPv6Endpoint{kDefaultSiteLocalGroupIPv6,
+ kDefaultMulticastPort};
// Maximum MTU size (1500) minus the UDP header size (8) and IP header size
// (20). If any packets are larger than this size, the responder or sender
@@ -149,9 +161,18 @@ constexpr MessageType GetMessageType(uint16_t flags) {
return (flags & kFlagResponse) ? MessageType::Response : MessageType::Query;
}
-constexpr uint16_t MakeFlags(MessageType type) {
+constexpr bool IsMessageTruncated(uint16_t flags) {
+ return flags & kFlagTC;
+}
+
+constexpr uint16_t MakeFlags(MessageType type, bool is_truncated) {
// RFC 6762 Section 18.2 and Section 18.4
- return (type == MessageType::Response) ? (kFlagResponse | kFlagAA) : 0;
+ uint16_t flags =
+ (type == MessageType::Response) ? (kFlagResponse | kFlagAA) : 0;
+ if (is_truncated) {
+ flags |= kFlagTC;
+ }
+ return flags;
}
constexpr bool IsValidFlagsSection(uint16_t flags) {
@@ -282,9 +303,36 @@ enum class DnsType : uint16_t {
kTXT = 16,
kAAAA = 28,
kSRV = 33,
+ kNSEC = 47,
kANY = 255, // Only allowed for QTYPE
};
+inline std::ostream& operator<<(std::ostream& output, DnsType type) {
+ switch (type) {
+ case DnsType::kA:
+ return output << "A";
+ case DnsType::kPTR:
+ return output << "PTR";
+ case DnsType::kTXT:
+ return output << "TXT";
+ case DnsType::kAAAA:
+ return output << "AAAA";
+ case DnsType::kSRV:
+ return output << "SRV";
+ case DnsType::kNSEC:
+ return output << "NSEC";
+ case DnsType::kANY:
+ return output << "ANY";
+ }
+
+ OSP_NOTREACHED();
+ return output;
+}
+
+constexpr std::array<DnsType, 7> kSupportedDnsTypes = {
+ DnsType::kA, DnsType::kPTR, DnsType::kTXT, DnsType::kAAAA,
+ DnsType::kSRV, DnsType::kNSEC, DnsType::kANY};
+
enum class DnsClass : uint16_t {
kIN = 1,
kANY = 255, // Only allowed for QCLASS
@@ -305,6 +353,14 @@ enum class ResponseType {
kUnicast = 1,
};
+// These are the default TTL values for supported DNS Record types as specified
+// by RFC 6762 section 10.
+constexpr std::chrono::seconds kPtrRecordTtl(120);
+constexpr std::chrono::seconds kSrvRecordTtl(120);
+constexpr std::chrono::seconds kARecordTtl(120);
+constexpr std::chrono::seconds kAAAARecordTtl(120);
+constexpr std::chrono::seconds kTXTRecordTtl(120);
+
// DNS CLASS masks and values.
// In mDNS the most significant bit of the RRCLASS for response records is
// designated as the "cache-flush bit", as described in
@@ -366,6 +422,19 @@ constexpr size_t kTXTMaxEntrySize = 255;
// See RFC: https://tools.ietf.org/html/rfc6763#section-6.1
constexpr uint8_t kTXTEmptyRdata = 0;
+// ============================================================================
+// Probing Constants
+// ============================================================================
+
+// RFC 6762 section 8.1 specifies that a probe should wait 250 ms between
+// subsequent probe queries.
+constexpr Clock::duration kDelayBetweenProbeQueries =
+ std::chrono::duration_cast<Clock::duration>(std::chrono::milliseconds{250});
+
+// RFC 6762 section 8.1 specifies that the probing phase should send out probe
+// requests 3 times before treating the probe as completed.
+constexpr int kProbeIterationCountBeforeSuccess = 3;
+
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_service.cc b/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_service.cc
new file mode 100644
index 00000000000..89c079b0190
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_service.cc
@@ -0,0 +1,15 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "discovery/mdns/public/mdns_service.h"
+
+namespace openscreen {
+namespace discovery {
+
+MdnsService::MdnsService() = default;
+
+MdnsService::~MdnsService() = default;
+
+} // namespace discovery
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_service.h b/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_service.h
index 817e57dc87d..b32deb780e7 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_service.h
+++ b/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_service.h
@@ -5,44 +5,107 @@
#ifndef DISCOVERY_MDNS_PUBLIC_MDNS_SERVICE_H_
#define DISCOVERY_MDNS_PUBLIC_MDNS_SERVICE_H_
+#include <functional>
#include <memory>
#include "discovery/mdns/public/mdns_constants.h"
+#include "platform/base/error.h"
+#include "platform/base/interface_info.h"
+#include "platform/base/ip_address.h"
namespace openscreen {
-struct IPEndpoint;
class TaskRunner;
namespace discovery {
+struct Config;
class DomainName;
+class MdnsDomainConfirmedProvider;
class MdnsRecord;
class MdnsRecordChangedCallback;
+class ReportingClient;
class MdnsService {
public:
- virtual ~MdnsService() = default;
+ enum SupportedNetworkAddressFamily : uint8_t {
+ kNoAddressFamily = 0,
+ kUseIpV4Multicast = 0x01 << 0,
+ kUseIpV6Multicast = 0x01 << 1
+ };
+
+ MdnsService();
+ virtual ~MdnsService();
// Creates a new MdnsService instance, to be owned by the caller. On failure,
- // returns nullptr.
- static std::unique_ptr<MdnsService> Create(TaskRunner* task_runner);
+ // returns nullptr. |task_runner|, |reporting_client|, and |config| must exist
+ // for the duration of the resulting instance's life.
+ static std::unique_ptr<MdnsService> Create(
+ TaskRunner* task_runner,
+ ReportingClient* reporting_client,
+ const Config& config,
+ NetworkInterfaceIndex network_interface,
+ SupportedNetworkAddressFamily supported_address_types);
+ // Starts an mDNS query with the given properties. Updated records are passed
+ // to |callback|. The caller must ensure |callback| remains alive while it is
+ // registered with a query.
virtual void StartQuery(const DomainName& name,
DnsType dns_type,
DnsClass dns_class,
MdnsRecordChangedCallback* callback) = 0;
+ // Stops an mDNS query with the given properties. |callback| must be the same
+ // callback pointer that was previously passed to StartQuery.
virtual void StopQuery(const DomainName& name,
DnsType dns_type,
DnsClass dns_class,
MdnsRecordChangedCallback* callback) = 0;
- virtual void RegisterRecord(const MdnsRecord& record) = 0;
+ // Re-initializes the process of service discovery for the provided domain
+ // name. All ongoing queries for this domain are restarted and any previously
+ // received query results are discarded.
+ virtual void ReinitializeQueries(const DomainName& name) = 0;
+
+ // Starts probing for a valid domain name based on the given one. |callback|
+ // will be called once a valid domain is found, and the instance must persist
+ // until that call is received.
+ virtual Error StartProbe(MdnsDomainConfirmedProvider* callback,
+ DomainName requested_name,
+ IPAddress address) = 0;
+
+ // Registers a new mDNS record for advertisement by this service. For A, AAAA,
+ // SRV, and TXT records, the domain name must have already been claimed by the
+ // ClaimExclusiveOwnership() method and for PTR records the name being pointed
+ // to must have been claimed in the same fashion, but the domain name in the
+ // top-level MdnsRecord entity does not.
+ virtual Error RegisterRecord(const MdnsRecord& record) = 0;
- virtual void DeregisterRecord(const MdnsRecord& record) = 0;
+ // Updates the existing record with name matching the name of the new record.
+ // NOTE: This method is not valid for PTR records.
+ virtual Error UpdateRegisteredRecord(const MdnsRecord& old_record,
+ const MdnsRecord& new_record) = 0;
+
+ // Stops advertising the provided record. If no more records with the provided
+ // name are bing advertised after this call's completion, then ownership of
+ // the name is released.
+ virtual Error UnregisterRecord(const MdnsRecord& record) = 0;
};
+inline MdnsService::SupportedNetworkAddressFamily operator&(
+ MdnsService::SupportedNetworkAddressFamily lhs,
+ MdnsService::SupportedNetworkAddressFamily rhs) {
+ return static_cast<MdnsService::SupportedNetworkAddressFamily>(
+ static_cast<uint8_t>(lhs) & static_cast<uint8_t>(rhs));
+}
+
+inline MdnsService::SupportedNetworkAddressFamily operator|(
+ MdnsService::SupportedNetworkAddressFamily lhs,
+ MdnsService::SupportedNetworkAddressFamily rhs) {
+ return static_cast<MdnsService::SupportedNetworkAddressFamily>(
+ static_cast<uint8_t>(lhs) | static_cast<uint8_t>(rhs));
+}
+
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/testing/DEPS b/chromium/third_party/openscreen/src/discovery/mdns/testing/DEPS
new file mode 100644
index 00000000000..efbb593a6d5
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/mdns/testing/DEPS
@@ -0,0 +1,5 @@
+# -*- Mode: Python; -*-
+
+include_rules = [
+ '+discovery/mdns',
+]
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.cc b/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.cc
index 6d99fced114..6b154174a62 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.cc
+++ b/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.cc
@@ -18,5 +18,37 @@ TxtRecordRdata MakeTxtRecord(std::initializer_list<absl::string_view> strings) {
return TxtRecordRdata(std::move(texts));
}
+MdnsRecord GetFakePtrRecord(const DomainName& target,
+ std::chrono::seconds ttl) {
+ DomainName name(++target.labels().begin(), target.labels().end());
+ PtrRecordRdata rdata(target);
+ return MdnsRecord(std::move(name), DnsType::kPTR, DnsClass::kIN,
+ RecordType::kShared, ttl, rdata);
+}
+
+MdnsRecord GetFakeSrvRecord(const DomainName& name, std::chrono::seconds ttl) {
+ SrvRecordRdata rdata(0, 0, 80, name);
+ return MdnsRecord(name, DnsType::kSRV, DnsClass::kIN, RecordType::kUnique,
+ ttl, rdata);
+}
+
+MdnsRecord GetFakeTxtRecord(const DomainName& name, std::chrono::seconds ttl) {
+ TxtRecordRdata rdata;
+ return MdnsRecord(name, DnsType::kTXT, DnsClass::kIN, RecordType::kUnique,
+ ttl, rdata);
+}
+
+MdnsRecord GetFakeARecord(const DomainName& name, std::chrono::seconds ttl) {
+ ARecordRdata rdata(IPAddress(192, 168, 0, 0));
+ return MdnsRecord(name, DnsType::kA, DnsClass::kIN, RecordType::kUnique, ttl,
+ rdata);
+}
+
+MdnsRecord GetFakeAAAARecord(const DomainName& name, std::chrono::seconds ttl) {
+ AAAARecordRdata rdata(IPAddress(1, 2, 3, 4, 5, 6, 7, 8));
+ return MdnsRecord(name, DnsType::kAAAA, DnsClass::kIN, RecordType::kUnique,
+ ttl, rdata);
+}
+
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.h b/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.h
index 0ccf3eeba71..419e1f06d9b 100644
--- a/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.h
+++ b/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.h
@@ -15,6 +15,19 @@ namespace discovery {
TxtRecordRdata MakeTxtRecord(std::initializer_list<absl::string_view> strings);
+// Methods to create fake MdnsRecord entities for use in UnitTests.
+MdnsRecord GetFakePtrRecord(const DomainName& target,
+ std::chrono::seconds ttl = std::chrono::seconds(1));
+MdnsRecord GetFakeSrvRecord(const DomainName& name,
+ std::chrono::seconds ttl = std::chrono::seconds(1));
+MdnsRecord GetFakeTxtRecord(const DomainName& name,
+ std::chrono::seconds ttl = std::chrono::seconds(1));
+MdnsRecord GetFakeARecord(const DomainName& name,
+ std::chrono::seconds ttl = std::chrono::seconds(1));
+MdnsRecord GetFakeAAAARecord(
+ const DomainName& name,
+ std::chrono::seconds ttl = std::chrono::seconds(1));
+
} // namespace discovery
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/discovery/public/DEPS b/chromium/third_party/openscreen/src/discovery/public/DEPS
new file mode 100644
index 00000000000..eb84cf6bcff
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/public/DEPS
@@ -0,0 +1,5 @@
+# -*- Mode: Python; -*-
+
+include_rules = [
+ '+discovery/dnssd/public',
+]
diff --git a/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_factory.h b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_factory.h
new file mode 100644
index 00000000000..e2dcb8361ab
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_factory.h
@@ -0,0 +1,28 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef DISCOVERY_PUBLIC_DNS_SD_SERVICE_FACTORY_H_
+#define DISCOVERY_PUBLIC_DNS_SD_SERVICE_FACTORY_H_
+
+#include "discovery/dnssd/public/dns_sd_service.h"
+#include "util/serial_delete_ptr.h"
+
+namespace openscreen {
+
+class TaskRunner;
+
+namespace discovery {
+
+struct Config;
+class ReportingClient;
+
+SerialDeletePtr<DnsSdService> CreateDnsSdService(
+ TaskRunner* task_runner,
+ ReportingClient* reporting_client,
+ const Config& config);
+
+} // namespace discovery
+} // namespace openscreen
+
+#endif // DISCOVERY_PUBLIC_DNS_SD_SERVICE_FACTORY_H_
diff --git a/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_publisher.h b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_publisher.h
new file mode 100644
index 00000000000..f94d4e4bc4e
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_publisher.h
@@ -0,0 +1,93 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef DISCOVERY_PUBLIC_DNS_SD_SERVICE_PUBLISHER_H_
+#define DISCOVERY_PUBLIC_DNS_SD_SERVICE_PUBLISHER_H_
+
+#include <string>
+
+#include "discovery/dnssd/public/dns_sd_instance_record.h"
+#include "discovery/dnssd/public/dns_sd_publisher.h"
+#include "discovery/dnssd/public/dns_sd_service.h"
+#include "platform/base/error.h"
+#include "util/logging.h"
+
+namespace openscreen {
+namespace discovery {
+
+// This class represents a top-level discovery API which sits on top of DNS-SD.
+// The main purpose of this class is to hide DNS-SD internals from embedders who
+// do not care about the specific functionality and do not need to understand
+// DNS-SD Internals.
+// T is the service-specific type which stores information regarding a specific
+// service instance.
+// NOTE: This class is not thread-safe and calls will be made to DnsSdService in
+// the same sequence and on the same threads from which these methods are
+// called. This is to avoid forcing design decisions on embedders who write
+// their own implementations of the DNS-SD layer.
+template <typename T>
+class DnsSdServicePublisher : public DnsSdPublisher::Client {
+ public:
+ // This function type is responsible for converting from a T type to a
+ // DNS service instance (to publish to the network).
+ using ServiceInstanceConverter = std::function<DnsSdInstanceRecord(const T&)>;
+
+ DnsSdServicePublisher(DnsSdService* service,
+ std::string service_name,
+ ServiceInstanceConverter conversion)
+ : conversion_(conversion),
+ service_name_(std::move(service_name)),
+ publisher_(service ? service->GetPublisher() : nullptr) {
+ OSP_DCHECK(publisher_);
+ }
+
+ ~DnsSdServicePublisher() = default;
+
+ Error Register(const T& instance) {
+ if (!instance.IsValid()) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ DnsSdInstanceRecord record = conversion_(instance);
+ return publisher_->Register(record, this);
+ }
+
+ Error UpdateRegistration(const T& instance) {
+ if (!instance.IsValid()) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ DnsSdInstanceRecord record = conversion_(instance);
+ return publisher_->UpdateRegistration(record);
+ }
+
+ ErrorOr<int> DeregisterAll() {
+ return publisher_->DeregisterAll(service_name_);
+ }
+
+ protected:
+ // DnsSdPublisher::Client overrides.
+ //
+ // Embedders who care about the instance id with which the service was
+ // published may override this method.
+ void OnInstanceClaimed(const DnsSdInstanceRecord& requested_record,
+ const DnsSdInstanceRecord& claimed_record) {
+ OSP_DVLOG << "Instance ID '" << claimed_record.instance_id()
+ << "' claimed for requested ID '"
+ << requested_record.instance_id() << "'";
+ OnInstanceClaimed(requested_record.instance_id());
+ }
+
+ virtual void OnInstanceClaimed(const std::string& requested_instance_id) {}
+
+ private:
+ ServiceInstanceConverter conversion_;
+ std::string service_name_;
+ DnsSdPublisher* const publisher_;
+};
+
+} // namespace discovery
+} // namespace openscreen
+
+#endif // DISCOVERY_PUBLIC_DNS_SD_SERVICE_PUBLISHER_H_
diff --git a/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher.h b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher.h
new file mode 100644
index 00000000000..56c9a51b321
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher.h
@@ -0,0 +1,215 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef DISCOVERY_PUBLIC_DNS_SD_SERVICE_WATCHER_H_
+#define DISCOVERY_PUBLIC_DNS_SD_SERVICE_WATCHER_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "discovery/dnssd/public/dns_sd_instance_record.h"
+#include "discovery/dnssd/public/dns_sd_querier.h"
+#include "discovery/dnssd/public/dns_sd_service.h"
+#include "platform/base/error.h"
+#include "util/logging.h"
+
+namespace openscreen {
+namespace discovery {
+namespace {
+
+// NOTE: Must be inlined to avoid compilation failure for unused function when
+// DLOGs are disabled.
+template <typename T>
+inline std::string GetInstanceNames(
+ const std::unordered_map<std::string, T>& map) {
+ std::string s;
+ auto it = map.begin();
+ if (it == map.end()) {
+ return s;
+ }
+
+ s += it->first;
+ while (++it != map.end()) {
+ s += ", " + it->first;
+ }
+ return s;
+}
+
+} // namespace
+
+// This class represents a top-level discovery API which sits on top of DNS-SD.
+// T is the service-specific type which stores information regarding a specific
+// service instance.
+// TODO(rwkeane): Include reporting client as ctor parameter once parallel CLs
+// are in.
+// NOTE: This class is not thread-safe and calls will be made to DnsSdService in
+// the same sequence and on the same threads from which these methods are
+// called. This is to avoid forcing design decisions on embedders who write
+// their own implementations of the DNS-SD layer.
+template <typename T>
+class DnsSdServiceWatcher : public DnsSdQuerier::Callback {
+ public:
+ using ConstRefT = std::reference_wrapper<const T>;
+
+ // The method which will be called when any new service instance is
+ // discovered, a service instance changes its data (such as TXT or A data), or
+ // a previously discovered service instance ceases to be available. The vector
+ // is the set of all currently active service instances which have been
+ // discovered so far.
+ using ServicesUpdatedCallback =
+ std::function<void(std::vector<ConstRefT> services)>;
+
+ // This function type is responsible for converting from a DNS service
+ // instance (received from another mDNS endpoint) to a T type to be returned
+ // to the caller.
+ using ServiceConverter =
+ std::function<ErrorOr<T>(const DnsSdInstanceRecord&)>;
+
+ DnsSdServiceWatcher(DnsSdService* service,
+ std::string service_name,
+ ServiceConverter conversion,
+ ServicesUpdatedCallback callback)
+ : conversion_(conversion),
+ service_name_(std::move(service_name)),
+ callback_(std::move(callback)),
+ querier_(service ? service->GetQuerier() : nullptr) {
+ OSP_DCHECK(querier_);
+ }
+
+ ~DnsSdServiceWatcher() = default;
+
+ // Starts service discovery.
+ void StartDiscovery() {
+ OSP_DCHECK(!is_running_);
+ is_running_ = true;
+
+ querier_->StartQuery(service_name_, this);
+ }
+
+ // Stops service discovery.
+ void StopDiscovery() {
+ OSP_DCHECK(is_running_);
+ is_running_ = false;
+
+ querier_->StopQuery(service_name_, this);
+ }
+
+ // Returns whether or not discovery is currently ongoing.
+ bool is_running() const { return is_running_; }
+
+ // Re-initializes the process of service discovery, even if the underlying
+ // implementation would not normally do so at this time. All previously
+ // received service data is discarded.
+ // NOTE: This call will return an error if StartDiscovery has not yet been
+ // called.
+ Error ForceRefresh() {
+ if (!is_running_) {
+ return Error::Code::kOperationInvalid;
+ }
+
+ querier_->ReinitializeQueries(service_name_);
+ records_.clear();
+ return Error::None();
+ }
+
+ // Re-initializes the process of service discovery, even if the underlying
+ // implementation would not normally do so at this time. All previously
+ // received service data is persisted.
+ // NOTE: This call will return an error if StartDiscovery has not yet been
+ // called.
+ Error DiscoverNow() {
+ if (!is_running_) {
+ return Error::Code::kOperationInvalid;
+ }
+
+ querier_->ReinitializeQueries(service_name_);
+ return Error::None();
+ }
+
+ // Returns the set of services which have been received so far.
+ std::vector<ConstRefT> GetServices() const {
+ std::vector<ConstRefT> refs;
+ for (const auto& pair : records_) {
+ refs.push_back(*pair.second.get());
+ }
+
+ OSP_DVLOG << "Currently " << records_.size()
+ << " known service instances: [" << GetInstanceNames(records_)
+ << "]";
+
+ return refs;
+ }
+
+ private:
+ friend class TestServiceWatcher;
+
+ // DnsSdQuerier::Callback overrides.
+ void OnInstanceCreated(const DnsSdInstanceRecord& new_record) override {
+ // NOTE: existence is not checked because records may be overwritten after
+ // querier_->ReinitializeQueries() is called.
+ ErrorOr<T> record = conversion_(new_record);
+ if (record.is_error()) {
+ OSP_LOG << "Conversion of received record failed with error: "
+ << record.error();
+ return;
+ }
+ records_[new_record.instance_id()] =
+ std::make_unique<T>(std::move(record.value()));
+ callback_(GetServices());
+ }
+
+ void OnInstanceUpdated(const DnsSdInstanceRecord& modified_record) override {
+ auto it = records_.find(modified_record.instance_id());
+ if (it != records_.end()) {
+ ErrorOr<T> record = conversion_(modified_record);
+ if (record.is_error()) {
+ OSP_LOG << "Conversion of received record failed with error: "
+ << record.error();
+ return;
+ }
+ auto ptr = std::make_unique<T>(std::move(record.value()));
+ it->second.swap(ptr);
+
+ callback_(GetServices());
+ } else {
+ OSP_LOG << "Received modified record for non-existent DNS-SD Instance "
+ << modified_record.instance_id();
+ }
+ }
+
+ void OnInstanceDeleted(const DnsSdInstanceRecord& old_record) override {
+ if (records_.erase(old_record.instance_id())) {
+ callback_(GetServices());
+ } else {
+ OSP_LOG << "Received deletion of record for non-existent DNS-SD Instance "
+ << old_record.instance_id();
+ }
+ }
+
+ // Set of all instance ids found so far, mapped to the T type that it
+ // represents. unique_ptr<T> entities are used so that the const refs returned
+ // from GetServices() and the ServicesUpdatedCallback can persist even once
+ // this map is resized. NOTE: Unordered map is used because this set is in
+ // many cases expected to be large.
+ std::unordered_map<std::string, std::unique_ptr<T>> records_;
+
+ // Represents whether discovery is currently running or not.
+ bool is_running_ = false;
+
+ // Converts from the DNS-SD representation of a service to the outside
+ // representation.
+ ServiceConverter conversion_;
+
+ std::string service_name_;
+ ServicesUpdatedCallback callback_;
+ DnsSdQuerier* const querier_;
+};
+
+} // namespace discovery
+} // namespace openscreen
+
+#endif // DISCOVERY_PUBLIC_DNS_SD_SERVICE_WATCHER_H_
diff --git a/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher_unittest.cc b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher_unittest.cc
new file mode 100644
index 00000000000..a22558f5cd7
--- /dev/null
+++ b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher_unittest.cc
@@ -0,0 +1,335 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "discovery/public/dns_sd_service_watcher.h"
+
+#include <algorithm>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using testing::_;
+using testing::ContainerEq;
+using testing::IsSubsetOf;
+using testing::IsSupersetOf;
+using testing::StrictMock;
+
+namespace openscreen {
+namespace discovery {
+namespace {
+
+std::vector<std::string> ConvertRefs(
+ const std::vector<std::reference_wrapper<const std::string>>& value) {
+ std::vector<std::string> strings;
+
+ // This loop is required to unwrap reference_wrapper objects.
+ for (const std::string& val : value) {
+ strings.push_back(val);
+ }
+ return strings;
+}
+
+static const IPAddress kAddressV4(192, 168, 0, 0);
+static const IPEndpoint kEndpointV4{kAddressV4, 0};
+static const std::string kCastServiceId = "_googlecast._tcp";
+static const std::string kCastDomainId = "local";
+
+class MockDnsSdService : public DnsSdService {
+ public:
+ MockDnsSdService() : querier_(this) {}
+
+ DnsSdQuerier* GetQuerier() override { return &querier_; }
+ DnsSdPublisher* GetPublisher() override { return nullptr; }
+
+ MOCK_METHOD2(StartQuery,
+ void(const std::string& service, DnsSdQuerier::Callback* cb));
+ MOCK_METHOD2(StopQuery,
+ void(const std::string& service, DnsSdQuerier::Callback* cb));
+ MOCK_METHOD1(ReinitializeQueries, void(const std::string& service));
+
+ private:
+ class MockQuerier : public DnsSdQuerier {
+ public:
+ explicit MockQuerier(MockDnsSdService* service) : mock_service_(service) {
+ OSP_DCHECK(service);
+ }
+
+ void StartQuery(const std::string& service,
+ DnsSdQuerier::Callback* cb) override {
+ mock_service_->StartQuery(service, cb);
+ }
+
+ void StopQuery(const std::string& service,
+ DnsSdQuerier::Callback* cb) override {
+ mock_service_->StopQuery(service, cb);
+ }
+
+ void ReinitializeQueries(const std::string& service) override {
+ mock_service_->ReinitializeQueries(service);
+ }
+
+ private:
+ MockDnsSdService* const mock_service_;
+ };
+
+ MockQuerier querier_;
+};
+
+} // namespace
+
+class TestServiceWatcher : public DnsSdServiceWatcher<std::string> {
+ public:
+ using DnsSdServiceWatcher<std::string>::ConstRefT;
+
+ explicit TestServiceWatcher(MockDnsSdService* service)
+ : DnsSdServiceWatcher<std::string>(
+ service,
+ kCastServiceId,
+ [this](const DnsSdInstanceRecord& record) {
+ return Convert(record);
+ },
+ [this](std::vector<ConstRefT> ref) { Callback(std::move(ref)); }) {}
+
+ MOCK_METHOD1(Callback, void(std::vector<ConstRefT>));
+
+ using DnsSdServiceWatcher<std::string>::OnInstanceCreated;
+ using DnsSdServiceWatcher<std::string>::OnInstanceUpdated;
+ using DnsSdServiceWatcher<std::string>::OnInstanceDeleted;
+
+ private:
+ std::string Convert(const DnsSdInstanceRecord& record) {
+ return record.instance_id();
+ }
+};
+
+class DnsSdServiceWatcherTests : public testing::Test {
+ public:
+ DnsSdServiceWatcherTests() : watcher_(&service_) {
+ // Start service discovery, since all other tests need it
+ EXPECT_FALSE(watcher_.is_running());
+ EXPECT_CALL(service_, StartQuery(kCastServiceId, _));
+ watcher_.StartDiscovery();
+ testing::Mock::VerifyAndClearExpectations(&service_);
+ }
+
+ protected:
+ void CreateNewInstance(const DnsSdInstanceRecord& record) {
+ const std::vector<std::string> services_before =
+ ConvertRefs(watcher_.GetServices());
+ const size_t count = services_before.size();
+
+ std::vector<std::string> callbacked_services;
+ EXPECT_CALL(watcher_, Callback(_))
+ .WillOnce([services = &callbacked_services](
+ std::vector<TestServiceWatcher::ConstRefT> value) {
+ *services = ConvertRefs(value);
+ });
+ watcher_.OnInstanceCreated(record);
+ testing::Mock::VerifyAndClearExpectations(&watcher_);
+
+ std::vector<std::string> fetched_services =
+ ConvertRefs(watcher_.GetServices());
+ EXPECT_EQ(fetched_services.size(), count + 1);
+
+ EXPECT_THAT(fetched_services, ContainerEq(callbacked_services));
+ EXPECT_THAT(fetched_services, IsSupersetOf(services_before));
+ }
+
+ void CreateExistingInstance(const DnsSdInstanceRecord& record) {
+ const std::vector<std::string> services_before =
+ ConvertRefs(watcher_.GetServices());
+ const size_t count = services_before.size();
+
+ std::vector<std::string> callbacked_services;
+ EXPECT_CALL(watcher_, Callback(_))
+ .WillOnce([services = &callbacked_services](
+ std::vector<TestServiceWatcher::ConstRefT> value) {
+ *services = ConvertRefs(value);
+ });
+ watcher_.OnInstanceCreated(record);
+ testing::Mock::VerifyAndClearExpectations(&watcher_);
+
+ const std::vector<std::string> fetched_services =
+ ConvertRefs(watcher_.GetServices());
+ EXPECT_EQ(fetched_services.size(), count);
+
+ EXPECT_THAT(fetched_services, ContainerEq(callbacked_services));
+ EXPECT_THAT(fetched_services, ContainerEq(services_before));
+ }
+
+ void UpdateExistingInstance(const DnsSdInstanceRecord& record) {
+ const std::vector<std::string> services_before =
+ ConvertRefs(watcher_.GetServices());
+ const size_t count = services_before.size();
+
+ std::vector<std::string> callbacked_services;
+ EXPECT_CALL(watcher_, Callback(_))
+ .WillOnce([services = &callbacked_services](
+ std::vector<TestServiceWatcher::ConstRefT> value) {
+ *services = ConvertRefs(value);
+ });
+ watcher_.OnInstanceUpdated(record);
+ testing::Mock::VerifyAndClearExpectations(&watcher_);
+
+ const std::vector<std::string> fetched_services =
+ ConvertRefs(watcher_.GetServices());
+ EXPECT_EQ(fetched_services.size(), count);
+
+ EXPECT_THAT(fetched_services, ContainerEq(callbacked_services));
+ EXPECT_THAT(fetched_services, ContainerEq(services_before));
+ }
+
+ void DeleteExistingInstance(const DnsSdInstanceRecord& record) {
+ const std::vector<std::string> services_before =
+ ConvertRefs(watcher_.GetServices());
+ const size_t count = services_before.size();
+
+ std::vector<std::string> callbacked_services;
+ EXPECT_CALL(watcher_, Callback(_))
+ .WillOnce([services = &callbacked_services](
+ std::vector<TestServiceWatcher::ConstRefT> value) {
+ *services = ConvertRefs(value);
+ });
+ watcher_.OnInstanceDeleted(record);
+ testing::Mock::VerifyAndClearExpectations(&watcher_);
+
+ const std::vector<std::string> fetched_services =
+ ConvertRefs(watcher_.GetServices());
+ EXPECT_EQ(fetched_services.size(), count - 1);
+ }
+
+ void UpdateNonExistingInstance(const DnsSdInstanceRecord& record) {
+ const std::vector<std::string> services_before =
+ ConvertRefs(watcher_.GetServices());
+ const size_t count = services_before.size();
+
+ EXPECT_CALL(watcher_, Callback(_)).Times(0);
+ watcher_.OnInstanceUpdated(record);
+ testing::Mock::VerifyAndClearExpectations(&watcher_);
+
+ const std::vector<std::string> fetched_services =
+ ConvertRefs(watcher_.GetServices());
+ EXPECT_EQ(fetched_services.size(), count);
+
+ EXPECT_THAT(services_before, ContainerEq(fetched_services));
+ }
+
+ void DeleteNonExistingInstance(const DnsSdInstanceRecord& record) {
+ const std::vector<std::string> services_before =
+ ConvertRefs(watcher_.GetServices());
+ const size_t count = services_before.size();
+
+ EXPECT_CALL(watcher_, Callback(_)).Times(0);
+ watcher_.OnInstanceDeleted(record);
+ testing::Mock::VerifyAndClearExpectations(&watcher_);
+
+ const std::vector<std::string> fetched_services =
+ ConvertRefs(watcher_.GetServices());
+ EXPECT_EQ(fetched_services.size(), count);
+
+ EXPECT_THAT(services_before, ContainerEq(fetched_services));
+ }
+
+ bool ContainsService(const DnsSdInstanceRecord& record) {
+ const std::string& service = record.instance_id();
+ const std::vector<TestServiceWatcher::ConstRefT> services =
+ watcher_.GetServices();
+ return std::find_if(services.begin(), services.end(),
+ [&service](const std::string& ref) {
+ return service == ref;
+ }) != services.end();
+ }
+
+ StrictMock<MockDnsSdService> service_;
+ StrictMock<TestServiceWatcher> watcher_;
+ std::vector<std::string> fetched_services;
+};
+
+TEST_F(DnsSdServiceWatcherTests, StartStopDiscoveryWorks) {
+ EXPECT_TRUE(watcher_.is_running());
+ EXPECT_CALL(service_, StopQuery(kCastServiceId, _));
+ watcher_.StopDiscovery();
+ EXPECT_FALSE(watcher_.is_running());
+}
+
+TEST(DnsSdServiceWatcherTest, RefreshFailsBeforeDiscoveryStarts) {
+ StrictMock<MockDnsSdService> service;
+ StrictMock<TestServiceWatcher> watcher(&service);
+ EXPECT_FALSE(watcher.DiscoverNow().ok());
+ EXPECT_FALSE(watcher.ForceRefresh().ok());
+}
+
+TEST_F(DnsSdServiceWatcherTests, RefreshDiscoveryWorks) {
+ const DnsSdInstanceRecord record("Instance", kCastServiceId, kCastDomainId,
+ kEndpointV4, DnsSdTxtRecord{});
+ CreateNewInstance(record);
+
+ // Refresh services.
+ EXPECT_CALL(service_, ReinitializeQueries(kCastServiceId));
+ EXPECT_TRUE(watcher_.DiscoverNow().ok());
+ EXPECT_EQ(watcher_.GetServices().size(), size_t{1});
+ testing::Mock::VerifyAndClearExpectations(&service_);
+
+ EXPECT_CALL(service_, ReinitializeQueries(kCastServiceId));
+ EXPECT_TRUE(watcher_.ForceRefresh().ok());
+ EXPECT_EQ(watcher_.GetServices().size(), size_t{0});
+ testing::Mock::VerifyAndClearExpectations(&service_);
+}
+
+TEST_F(DnsSdServiceWatcherTests, CreatingUpdatingDeletingInstancesWork) {
+ const DnsSdInstanceRecord record("Instance", kCastServiceId, kCastDomainId,
+ kEndpointV4, DnsSdTxtRecord{});
+ const DnsSdInstanceRecord record2("Instance2", kCastServiceId, kCastDomainId,
+ kEndpointV4, DnsSdTxtRecord{});
+
+ EXPECT_FALSE(ContainsService(record));
+ EXPECT_FALSE(ContainsService(record2));
+
+ CreateNewInstance(record);
+ EXPECT_TRUE(ContainsService(record));
+ EXPECT_FALSE(ContainsService(record2));
+
+ CreateExistingInstance(record);
+ EXPECT_TRUE(ContainsService(record));
+ EXPECT_FALSE(ContainsService(record2));
+
+ UpdateNonExistingInstance(record2);
+ EXPECT_TRUE(ContainsService(record));
+ EXPECT_FALSE(ContainsService(record2));
+
+ DeleteNonExistingInstance(record2);
+ EXPECT_TRUE(ContainsService(record));
+ EXPECT_FALSE(ContainsService(record2));
+
+ CreateNewInstance(record2);
+ EXPECT_TRUE(ContainsService(record));
+ EXPECT_TRUE(ContainsService(record2));
+
+ UpdateExistingInstance(record2);
+ EXPECT_TRUE(ContainsService(record));
+ EXPECT_TRUE(ContainsService(record2));
+
+ UpdateExistingInstance(record);
+ EXPECT_TRUE(ContainsService(record));
+ EXPECT_TRUE(ContainsService(record2));
+
+ DeleteExistingInstance(record);
+ EXPECT_FALSE(ContainsService(record));
+ EXPECT_TRUE(ContainsService(record2));
+
+ UpdateNonExistingInstance(record);
+ EXPECT_FALSE(ContainsService(record));
+ EXPECT_TRUE(ContainsService(record2));
+
+ DeleteNonExistingInstance(record);
+ EXPECT_FALSE(ContainsService(record));
+ EXPECT_TRUE(ContainsService(record2));
+
+ DeleteExistingInstance(record2);
+ EXPECT_FALSE(ContainsService(record));
+ EXPECT_FALSE(ContainsService(record2));
+}
+
+} // namespace discovery
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/docs/style_guide.md b/chromium/third_party/openscreen/src/docs/style_guide.md
index 85e550e7af2..08a807c8ae0 100644
--- a/chromium/third_party/openscreen/src/docs/style_guide.md
+++ b/chromium/third_party/openscreen/src/docs/style_guide.md
@@ -9,10 +9,23 @@ C++14 language and library features are allowed in the Open Screen Library
according to the
[C++14 use in Chromium](https://chromium-cpp.appspot.com#core-whitelist) guidelines.
+## Modifications to the Chromium C++ Guidelines
+
+- `<functional>` and `std::function` objects are allowed.
+- `<chrono>` is allowed and encouraged for representation of time.
+- Abseil types are allowed based on the whitelist in [DEPS](https://chromium.googlesource.com/openscreen/+/refs/heads/master/DEPS).
+- **Do not** use Abseil types in public APIs.
+- `<thread>` and `<mutex>` are allowed, but discouraged from general use as the
+ library only needs to handle threading in very specific places;
+ see [threading.md](threading.md).
+
## Open Screen Library Features
- For public API functions that return values or errors, please return
- [`ErrorOr<T>`](https://chromium.googlesource.com/openscreen/+/master/base/error.h).
+ [`ErrorOr<T>`](https://chromium.googlesource.com/openscreen/+/master/platform/base/error.h).
+- In the implementation of public APIs invoked by the embedder, use
+ `OSP_DCHECK(TaskRunner::IsRunningOnTaskRunner())` to catch thread safety
+ problems early.
## Style Addenda
diff --git a/chromium/third_party/openscreen/src/docs/threading.md b/chromium/third_party/openscreen/src/docs/threading.md
new file mode 100644
index 00000000000..78a55b2d7d8
--- /dev/null
+++ b/chromium/third_party/openscreen/src/docs/threading.md
@@ -0,0 +1,21 @@
+# Threading
+
+The Open Screen Library is **single-threaded**; all of its code is intended to be
+run on a single sequence, with a few exceptions noted below.
+
+A library client **must** invoke all library APIs on the same sequence that is
+used to run tasks on the client's
+[TaskRunner implementation](https://chromium.googlesource.com/openscreen/+/refs/heads/master/platform/api/task_runner.h).
+
+## Exceptions
+
+* The [trace logging](trace_logging.md) framework is thread-safe.
+* The TaskRunner itself is thread-safe.
+* The [POSIX platform implementation](https://chromium.googlesource.com/openscreen/+/refs/heads/master/platform/impl/)
+ starts a network thread, and handles interactions between that thread and the
+ TaskRunner internally.
+
+
+
+
+
diff --git a/chromium/third_party/openscreen/src/docs/trace_logging.md b/chromium/third_party/openscreen/src/docs/trace_logging.md
index 0139b3f2348..8bb60915656 100644
--- a/chromium/third_party/openscreen/src/docs/trace_logging.md
+++ b/chromium/third_party/openscreen/src/docs/trace_logging.md
@@ -83,31 +83,27 @@ call. As with scoped traces, the result must be some Error::Code enum value.
## Tracing Functions
All of the below functions rely on the Platform Layer's IsTraceLoggingEnabled()
function. When logging is disabled, either for the specific category of trace
-logging which the Macro specifies or for TraceCategory::Any in all other caes,
+logging which the Macro specifies or for TraceCategory::Any in all other cases,
the below functions will be treated as a NoOp.
### Synchronous Tracing
- `TRACE_SCOPED(category, name)`
- If logging is enabled for the provided category, this function will trace
- the current function until the current scope ends with name as provided.
- When this call is used, the Trace ID Hierarchy will be determined
- automatically and the caller does not need to worry about it and, as such,
- **this call should be used in the majority of synchronous tracing cases**.
-
- `TRACE_SCOPED(category, name, traceId, parentId, rootId)`
- If logging is enabled for the provided category, this function will trace
- the current function until the current scope ends with name as provided. The
- Trace ID used for tracing this function will be set to the one provided, as
- will the parent and root ids. Each of Trace ID, Parent ID, and Root ID is
- optional, so providing only a subset of these values is also valid if the
- caller only desires to set specific ones.
-
- `TRACE_SCOPED(category, name, traceIdHierarchy)`
- This call is intended for use in conjunction with the TRACE_HIERARCHY macro
- (as described below). If logging is enabled for the provided category, this
- function will trace the current function until the current scope ends with
- name as provided. The Trace ID Hierarchy will be set as provided in the
- provided Trace ID Hierarchy parameter.
+ ```c++
+ TRACE_SCOPED(category, name)
+ TRACE_SCOPED(category, name, traceId, parentId, rootId)
+ TRACE_SCOPED(category, name, traceIdHierarchy)
+ ```
+ If logging is enabled for the provided |category|, trace the current scope. The scope
+ should encompass the operation described by |name|. The latter two uses of this macro are
+ for manually providing the trace ID hierarchy; the first auto-generates a new trace ID for
+ this scope and sets its parent trace ID to that of the encompassing scope (if any).
+
+ ```c++
+ TRACE_DEFAULT_SCOPED(category)
+ TRACE_DEFAULT_SCOPED(category, traceId, parentId, rootId)
+ TRACE_DEFAULT_SCOPED(category, traceIdHierarchy)
+ ```
+ Same as TRACE_SCOPED(), but use the current function/method signature as the operation
+ name.
### Asynchronous Tracing
`TRACE_ASYNC_START(category, name)`
@@ -229,7 +225,7 @@ For an embedder to create a custom TraceLogging implementation:
called frequently (especially `IsLoggingEnabled(TraceCategory)`) and are often
in the critical execution path of the library's code.
-2. *Call `openscreen::platform::StartTracing()` and `StopTracing()`*
+2. *Call `openscreen::StartTracing()` and `StopTracing()`*
These activate/deactivate tracing by providing the TraceLoggingPlatform
instance and later clearing references to it.
diff --git a/chromium/third_party/openscreen/src/infra/config/global/commit-queue.cfg b/chromium/third_party/openscreen/src/infra/config/global/commit-queue.cfg
index 7f3792bd8ad..d51fcc00bee 100644
--- a/chromium/third_party/openscreen/src/infra/config/global/commit-queue.cfg
+++ b/chromium/third_party/openscreen/src/infra/config/global/commit-queue.cfg
@@ -32,21 +32,20 @@ config_groups {
name: "openscreen/try/linux64_tsan"
}
builders {
+ name: "openscreen/try/linux_arm64_debug"
+ }
+ builders {
name: "openscreen/try/mac_debug"
}
builders {
name: "openscreen/try/openscreen_presubmit"
}
- # Chromium bots are declared "experimental" but at 100% so they always run
- # but aren't considered part of the commit queue pass/fail.
builders {
name: "openscreen/try/chromium_linux64_debug"
- experiment_percentage: 100
}
builders {
name: "openscreen/try/chromium_mac_debug"
- experiment_percentage: 100
}
retry_config {
diff --git a/chromium/third_party/openscreen/src/infra/config/global/cr-buildbucket.cfg b/chromium/third_party/openscreen/src/infra/config/global/cr-buildbucket.cfg
index 0b35cbf4365..447571db82f 100644
--- a/chromium/third_party/openscreen/src/infra/config/global/cr-buildbucket.cfg
+++ b/chromium/third_party/openscreen/src/infra/config/global/cr-buildbucket.cfg
@@ -74,7 +74,22 @@ builder_mixins {
builder_mixins {
name: "mac"
+
+ # NOTE: The OS version here will determine which version of XCode is being
+ # used. Relevant links; so you and I never have to spend hours finding this
+ # stuff all over again to fix things like https://crbug.com/openscreen/86:
+ #
+ # 1. The recipe code that uses the "osx_sdk" recipe module:
+ #
+ # https://cs.chromium.org/chromium/build/scripts/slave/recipes
+ # /openscreen.py?rcl=671f9f1c5f5bef81d0a39973aa8729cc83bb290e&l=74
+ #
+ # 2. The XCode version look-up table in the "osx_sdk" recipe module:
+ #
+ # https://cs.chromium.org/chromium/tools/depot_tools/recipes/recipe_modules
+ # /osx_sdk/api.py?rcl=fe18a43d590a5eac0d58e7e555b024746ba290ad&l=26
dimensions: "os:Mac-10.13"
+
caches: {
# Cache for mac_toolchain tool and XCode.app used in recipes.
name: "osx_sdk"
@@ -91,6 +106,14 @@ builder_mixins {
}
builder_mixins {
+ name: "arm64"
+ dimensions: "cpu:x86-64"
+ recipe {
+ properties: "target_cpu:arm64"
+ }
+}
+
+builder_mixins {
name: "chromium"
recipe: {
name: "chromium"
@@ -137,6 +160,13 @@ buckets {
}
builders {
+ name: "linux_arm64_debug"
+ mixins: "linux"
+ mixins: "arm64"
+ mixins: "debug"
+ }
+
+ builders {
name: "mac_debug"
mixins: "mac"
mixins: "debug"
@@ -201,6 +231,13 @@ buckets: {
}
builders {
+ name: "linux_arm64_debug"
+ mixins: "linux"
+ mixins: "arm64"
+ mixins: "debug"
+ }
+
+ builders {
name: "mac_debug"
mixins: "mac"
mixins: "debug"
@@ -212,6 +249,7 @@ buckets: {
recipe {
name: "run_presubmit"
properties: "repo_name:openscreen"
+ properties: "runhooks:true"
}
mixins: "linux"
mixins: "x64"
diff --git a/chromium/third_party/openscreen/src/infra/config/global/luci-milo.cfg b/chromium/third_party/openscreen/src/infra/config/global/luci-milo.cfg
index 458a1c80c79..b50a49ce08a 100644
--- a/chromium/third_party/openscreen/src/infra/config/global/luci-milo.cfg
+++ b/chromium/third_party/openscreen/src/infra/config/global/luci-milo.cfg
@@ -26,6 +26,12 @@ consoles {
}
builders {
+ name: "buildbucket/luci.openscreen.ci/linux_arm64_debug"
+ category: "linux|arm64"
+ short_name: "arm64"
+ }
+
+ builders {
name: "buildbucket/luci.openscreen.ci/mac_debug"
category: "mac"
short_name: "dbg"
@@ -70,8 +76,26 @@ consoles {
}
builders {
+ name: "buildbucket/luci.openscreen.ci/linux_arm64_debug"
+ category: "linux|arm64"
+ short_name: "arm64"
+ }
+
+ builders {
name: "buildbucket/luci.openscreen.try/mac_debug"
category: "mac"
short_name: "dbg"
}
+
+ builders {
+ name: "buildbucket/luci.openscreen.try/chromium_linux64_debug"
+ category: "chromium fyi"
+ short_name: "linux64"
+ }
+
+ builders {
+ name: "buildbucket/luci.openscreen.try/chromium_mac_debug"
+ category: "chromium fyi"
+ short_name: "mac"
+ }
}
diff --git a/chromium/third_party/openscreen/src/infra/config/global/luci-scheduler.cfg b/chromium/third_party/openscreen/src/infra/config/global/luci-scheduler.cfg
index 4c2822dc4c3..39efee3f7b7 100644
--- a/chromium/third_party/openscreen/src/infra/config/global/luci-scheduler.cfg
+++ b/chromium/third_party/openscreen/src/infra/config/global/luci-scheduler.cfg
@@ -25,6 +25,7 @@ trigger {
triggers: "linux64_debug"
triggers: "linux64_gcc_debug"
triggers: "linux64_tsan"
+ triggers: "linux_arm64_debug"
triggers: "mac_debug"
}
@@ -71,6 +72,16 @@ job {
}
job {
+ id: "linux_arm64_debug"
+ acl_sets: "default"
+ buildbucket: {
+ server: "cr-buildbucket.appspot.com"
+ bucket: "luci.openscreen.ci"
+ builder: "linux64_arm64_debug"
+ }
+}
+
+job {
id: "mac_debug"
acl_sets: "default"
buildbucket: {
diff --git a/chromium/third_party/openscreen/src/osp/demo/osp_demo.cc b/chromium/third_party/openscreen/src/osp/demo/osp_demo.cc
index 37c6bb432fa..daa5fa0c596 100644
--- a/chromium/third_party/openscreen/src/osp/demo/osp_demo.cc
+++ b/chromium/third_party/openscreen/src/osp/demo/osp_demo.cc
@@ -36,9 +36,6 @@
#include "third_party/tinycbor/src/src/cbor.h"
#include "util/trace_logging.h"
-using openscreen::platform::Clock;
-using openscreen::platform::PlatformClientPosix;
-
namespace {
const char* kReceiverLogFilename = "_recv_fifo";
@@ -431,8 +428,7 @@ void ListenerDemo() {
listener_config, &listener_observer,
PlatformClientPosix::GetInstance()->GetTaskRunner());
- MessageDemuxer demuxer(platform::Clock::now,
- MessageDemuxer::kDefaultBufferLimit);
+ MessageDemuxer demuxer(Clock::now, MessageDemuxer::kDefaultBufferLimit);
DemoConnectionClientObserver client_observer;
auto connection_client = ProtocolConnectionClientFactory::Create(
&demuxer, &client_observer,
@@ -440,7 +436,7 @@ void ListenerDemo() {
auto* network_service = NetworkServiceManager::Create(
std::move(mdns_listener), nullptr, std::move(connection_client), nullptr);
- auto controller = std::make_unique<Controller>(platform::Clock::now);
+ auto controller = std::make_unique<Controller>(Clock::now);
network_service->GetMdnsServiceListener()->Start();
network_service->GetProtocolConnectionClient()->Start();
@@ -524,8 +520,7 @@ void PublisherDemo(absl::string_view friendly_name) {
PlatformClientPosix::GetInstance()->GetTaskRunner());
ServerConfig server_config;
- for (const platform::InterfaceInfo& interface :
- platform::GetNetworkInterfaces()) {
+ for (const InterfaceInfo& interface : GetNetworkInterfaces()) {
OSP_VLOG << "Found interface: " << interface;
if (!interface.addresses.empty()) {
server_config.connection_endpoints.push_back(
@@ -535,8 +530,7 @@ void PublisherDemo(absl::string_view friendly_name) {
OSP_LOG_IF(WARN, server_config.connection_endpoints.empty())
<< "No network interfaces had usable addresses for mDNS publishing.";
- MessageDemuxer demuxer(platform::Clock::now,
- MessageDemuxer::kDefaultBufferLimit);
+ MessageDemuxer demuxer(Clock::now, MessageDemuxer::kDefaultBufferLimit);
DemoConnectionServerObserver server_observer;
auto connection_server = ProtocolConnectionServerFactory::Create(
server_config, &demuxer, &server_observer,
@@ -588,7 +582,9 @@ InputArgs GetInputArgs(int argc, char** argv) {
}
int main(int argc, char** argv) {
- using openscreen::platform::LogLevel;
+ using openscreen::Clock;
+ using openscreen::LogLevel;
+ using openscreen::PlatformClientPosix;
std::cout << "Usage: osp_demo [-v] [friendly_name]" << std::endl
<< "-v: enable more verbose logging" << std::endl
@@ -602,11 +598,11 @@ int main(int argc, char** argv) {
const bool is_receiver_demo = !args.friendly_server_name.empty();
const char* log_filename =
is_receiver_demo ? kReceiverLogFilename : kControllerLogFilename;
- openscreen::platform::SetLogFifoOrDie(log_filename);
+ openscreen::SetLogFifoOrDie(log_filename);
LogLevel level = args.is_verbose ? LogLevel::kVerbose : LogLevel::kInfo;
- openscreen::platform::SetLogLevel(level);
- openscreen::platform::TextTraceLoggingPlatform text_logging_platform;
+ openscreen::SetLogLevel(level);
+ openscreen::TextTraceLoggingPlatform text_logging_platform;
PlatformClientPosix::Create(Clock::duration{50}, Clock::duration{50});
diff --git a/chromium/third_party/openscreen/src/osp/impl/BUILD.gn b/chromium/third_party/openscreen/src/osp/impl/BUILD.gn
index b446aef9bd3..779300bb483 100644
--- a/chromium/third_party/openscreen/src/osp/impl/BUILD.gn
+++ b/chromium/third_party/openscreen/src/osp/impl/BUILD.gn
@@ -43,6 +43,7 @@ source_set("impl") {
]
deps = [
"../../platform",
+ "../../third_party/abseil",
"../../util",
"quic",
]
@@ -64,6 +65,7 @@ if (use_chromium_quic) {
deps = [
"../../osp/msgs",
"../../platform",
+ "../../third_party/abseil",
"../../third_party/chromium_quic",
"../../util",
"quic",
diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/BUILD.gn b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/BUILD.gn
index 4779f5264c0..cd051d1750d 100644
--- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/BUILD.gn
+++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/BUILD.gn
@@ -15,6 +15,7 @@ source_set("mdns_interface") {
public_deps = [
"../../../../platform",
+ "../../../../third_party/abseil",
"../../../../util",
]
}
diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_demo.cc b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_demo.cc
index 691658e0d8a..a6c0e508165 100644
--- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_demo.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_demo.cc
@@ -37,9 +37,6 @@
// shouldn't be expected to be a source of truth, nor should it be expected to
// be correct after running for a long time.
-using openscreen::platform::Clock;
-using openscreen::platform::PlatformClientPosix;
-
namespace openscreen {
namespace osp {
namespace {
@@ -59,22 +56,21 @@ struct Service {
std::vector<std::string> txt;
};
-class DemoSocketClient : public platform::UdpSocket::Client {
+class DemoSocketClient : public UdpSocket::Client {
public:
DemoSocketClient(MdnsResponderAdapterImpl* mdns) : mdns_(mdns) {}
- void OnError(platform::UdpSocket* socket, Error error) override {
+ void OnError(UdpSocket* socket, Error error) override {
// TODO(crbug.com/openscreen/66): Change to OSP_LOG_FATAL.
OSP_LOG_ERROR << "configuration failed for interface " << error.message();
OSP_CHECK(false);
}
- void OnSendError(platform::UdpSocket* socket, Error error) override {
+ void OnSendError(UdpSocket* socket, Error error) override {
OSP_UNIMPLEMENTED();
}
- void OnRead(platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> packet) override {
+ void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override {
mdns_->OnRead(socket, std::move(packet));
}
@@ -128,20 +124,20 @@ void SignalThings() {
OSP_LOG << "signal handlers setup" << std::endl << "pid: " << getpid();
}
-std::vector<platform::UdpSocketUniquePtr> SetUpMulticastSockets(
- platform::TaskRunner* task_runner,
- const std::vector<platform::NetworkInterfaceIndex>& index_list,
- platform::UdpSocket::Client* client) {
- std::vector<platform::UdpSocketUniquePtr> sockets;
+std::vector<std::unique_ptr<UdpSocket>> SetUpMulticastSockets(
+ TaskRunner* task_runner,
+ const std::vector<NetworkInterfaceIndex>& index_list,
+ UdpSocket::Client* client) {
+ std::vector<std::unique_ptr<UdpSocket>> sockets;
for (const auto ifindex : index_list) {
auto create_result =
- platform::UdpSocket::Create(task_runner, client, IPEndpoint{{}, 5353});
+ UdpSocket::Create(task_runner, client, IPEndpoint{{}, 5353});
if (!create_result) {
OSP_LOG_ERROR << "failed to create IPv4 socket for interface " << ifindex
<< ": " << create_result.error().message();
continue;
}
- platform::UdpSocketUniquePtr socket = std::move(create_result.value());
+ std::unique_ptr<UdpSocket> socket = std::move(create_result.value());
socket->JoinMulticastGroup(IPAddress{224, 0, 0, 251}, ifindex);
socket->SetMulticastOutboundInterface(ifindex);
@@ -250,7 +246,7 @@ void HandleEvents(MdnsResponderAdapterImpl* mdns_adapter) {
}
}
-void BrowseDemo(platform::TaskRunner* task_runner,
+void BrowseDemo(TaskRunner* task_runner,
const std::string& service_name,
const std::string& service_protocol,
const std::string& service_instance) {
@@ -269,9 +265,8 @@ void BrowseDemo(platform::TaskRunner* task_runner,
auto mdns_adapter = std::make_unique<MdnsResponderAdapterImpl>();
mdns_adapter->Init();
mdns_adapter->SetHostLabel("gigliorononomicon");
- const std::vector<platform::InterfaceInfo> interfaces =
- platform::GetNetworkInterfaces();
- std::vector<platform::NetworkInterfaceIndex> index_list;
+ const std::vector<InterfaceInfo> interfaces = GetNetworkInterfaces();
+ std::vector<NetworkInterfaceIndex> index_list;
for (const auto& interface : interfaces) {
OSP_LOG << "Found interface: " << interface;
if (!interface.addresses.empty()) {
@@ -290,10 +285,10 @@ void BrowseDemo(platform::TaskRunner* task_runner,
// Listen on all interfaces.
auto socket_it = sockets.begin();
- for (platform::NetworkInterfaceIndex index : index_list) {
+ for (NetworkInterfaceIndex index : index_list) {
const auto& interface =
*std::find_if(interfaces.begin(), interfaces.end(),
- [index](const openscreen::platform::InterfaceInfo& info) {
+ [index](const openscreen::InterfaceInfo& info) {
return info.index == index;
});
// Pick any address for the given interface.
@@ -308,7 +303,7 @@ void BrowseDemo(platform::TaskRunner* task_runner,
{{"k1", "yurtle"}, {"k2", "turtle"}});
}
- for (const platform::UdpSocketUniquePtr& socket : sockets) {
+ for (const std::unique_ptr<UdpSocket>& socket : sockets) {
mdns_adapter->StartPtrQuery(socket.get(), service_type.value());
}
@@ -332,7 +327,7 @@ void BrowseDemo(platform::TaskRunner* task_runner,
for (const auto& s : *g_services) {
LogService(s.second);
}
- for (const platform::UdpSocketUniquePtr& socket : sockets) {
+ for (const std::unique_ptr<UdpSocket>& socket : sockets) {
mdns_adapter->DeregisterInterface(socket.get());
}
mdns_adapter->Close();
@@ -343,7 +338,10 @@ void BrowseDemo(platform::TaskRunner* task_runner,
} // namespace openscreen
int main(int argc, char** argv) {
- openscreen::platform::SetLogLevel(openscreen::platform::LogLevel::kVerbose);
+ using openscreen::Clock;
+ using openscreen::PlatformClientPosix;
+
+ openscreen::SetLogLevel(openscreen::LogLevel::kVerbose);
std::string service_instance;
std::string service_type("_openscreen._udp");
diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.cc b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.cc
index 810c7bbdfce..edf25503dd1 100644
--- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.cc
@@ -9,7 +9,7 @@ namespace osp {
QueryEventHeader::QueryEventHeader() = default;
QueryEventHeader::QueryEventHeader(QueryEventHeader::Type response_type,
- platform::UdpSocket* socket)
+ UdpSocket* socket)
: response_type(response_type), socket(socket) {}
QueryEventHeader::QueryEventHeader(QueryEventHeader&&) noexcept = default;
QueryEventHeader::~QueryEventHeader() = default;
diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.h b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.h
index 376a9899fa4..66083d57a72 100644
--- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.h
+++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.h
@@ -29,13 +29,13 @@ struct QueryEventHeader {
};
QueryEventHeader();
- QueryEventHeader(Type response_type, platform::UdpSocket* socket);
+ QueryEventHeader(Type response_type, UdpSocket* socket);
QueryEventHeader(QueryEventHeader&&) noexcept;
~QueryEventHeader();
QueryEventHeader& operator=(QueryEventHeader&&) noexcept;
Type response_type;
- platform::UdpSocket* socket;
+ UdpSocket* socket;
};
struct PtrEvent {
@@ -160,7 +160,7 @@ enum class MdnsResponderErrorCode {
// called after any sequence of calls to mDNSResponder. It also returns a
// timeout value, after which it must be called again (e.g. for maintaining its
// cache).
-class MdnsResponderAdapter : public platform::UdpSocket::Client {
+class MdnsResponderAdapter : public UdpSocket::Client {
public:
MdnsResponderAdapter();
virtual ~MdnsResponderAdapter() = 0;
@@ -184,14 +184,14 @@ class MdnsResponderAdapter : public platform::UdpSocket::Client {
// mDNSResponder. |socket| will be used to identify which interface received
// the data in OnDataReceived and will be used to send data via the platform
// layer.
- virtual Error RegisterInterface(const platform::InterfaceInfo& interface_info,
- const platform::IPSubnet& interface_address,
- platform::UdpSocket* socket) = 0;
- virtual Error DeregisterInterface(platform::UdpSocket* socket) = 0;
+ virtual Error RegisterInterface(const InterfaceInfo& interface_info,
+ const IPSubnet& interface_address,
+ UdpSocket* socket) = 0;
+ virtual Error DeregisterInterface(UdpSocket* socket) = 0;
// Returns the time period after which this method must be called again, if
// any.
- virtual platform::Clock::duration RunTasks() = 0;
+ virtual Clock::duration RunTasks() = 0;
virtual std::vector<PtrEvent> TakePtrResponses() = 0;
virtual std::vector<SrvEvent> TakeSrvResponses() = 0;
@@ -200,33 +200,33 @@ class MdnsResponderAdapter : public platform::UdpSocket::Client {
virtual std::vector<AaaaEvent> TakeAaaaResponses() = 0;
virtual MdnsResponderErrorCode StartPtrQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_type) = 0;
virtual MdnsResponderErrorCode StartSrvQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) = 0;
virtual MdnsResponderErrorCode StartTxtQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) = 0;
- virtual MdnsResponderErrorCode StartAQuery(platform::UdpSocket* socket,
+ virtual MdnsResponderErrorCode StartAQuery(UdpSocket* socket,
const DomainName& domain_name) = 0;
virtual MdnsResponderErrorCode StartAaaaQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& domain_name) = 0;
virtual MdnsResponderErrorCode StopPtrQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_type) = 0;
virtual MdnsResponderErrorCode StopSrvQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) = 0;
virtual MdnsResponderErrorCode StopTxtQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) = 0;
- virtual MdnsResponderErrorCode StopAQuery(platform::UdpSocket* socket,
+ virtual MdnsResponderErrorCode StopAQuery(UdpSocket* socket,
const DomainName& domain_name) = 0;
virtual MdnsResponderErrorCode StopAaaaQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& domain_name) = 0;
// The following methods concern advertising a service via mDNS. The
diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc
index 6b986ded342..3feb8e7aaa5 100644
--- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc
@@ -13,8 +13,6 @@
#include "util/logging.h"
#include "util/trace_logging.h"
-using openscreen::platform::TraceCategory;
-
namespace openscreen {
namespace osp {
namespace {
@@ -206,7 +204,7 @@ Error MdnsResponderAdapterImpl::Init() {
}
void MdnsResponderAdapterImpl::Close() {
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::Close");
+ TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::Close");
mDNS_StartExit(&mdns_);
// Let all services send goodbyes.
while (!service_records_.empty()) {
@@ -228,7 +226,7 @@ void MdnsResponderAdapterImpl::Close() {
}
Error MdnsResponderAdapterImpl::SetHostLabel(const std::string& host_label) {
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::SetHostLabel");
+ TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::SetHostLabel");
if (host_label.size() > DomainName::kDomainNameMaxLabelLength)
return Error::Code::kDomainNameTooLong;
@@ -242,10 +240,10 @@ Error MdnsResponderAdapterImpl::SetHostLabel(const std::string& host_label) {
}
Error MdnsResponderAdapterImpl::RegisterInterface(
- const platform::InterfaceInfo& interface_info,
- const platform::IPSubnet& interface_address,
- platform::UdpSocket* socket) {
- TRACE_SCOPED(TraceCategory::mDNS,
+ const InterfaceInfo& interface_info,
+ const IPSubnet& interface_address,
+ UdpSocket* socket) {
+ TRACE_SCOPED(TraceCategory::kMdns,
"MdnsResponderAdapterImpl::RegisterInterface");
OSP_DCHECK(socket);
@@ -272,8 +270,9 @@ Error MdnsResponderAdapterImpl::RegisterInterface(
}
static_assert(sizeof(info.MAC.b) == sizeof(interface_info.hardware_address),
- "MAC addresss size mismatch.");
- memcpy(info.MAC.b, interface_info.hardware_address, sizeof(info.MAC.b));
+ "MAC address size mismatch.");
+ memcpy(info.MAC.b, interface_info.hardware_address.data(),
+ sizeof(info.MAC.b));
info.McastTxRx = 1;
platform_storage_.sockets.push_back(socket);
auto result = mDNS_RegisterInterface(&mdns_, &info, mDNSfalse);
@@ -284,9 +283,8 @@ Error MdnsResponderAdapterImpl::RegisterInterface(
: Error::Code::kMdnsRegisterFailure;
}
-Error MdnsResponderAdapterImpl::DeregisterInterface(
- platform::UdpSocket* socket) {
- TRACE_SCOPED(TraceCategory::mDNS,
+Error MdnsResponderAdapterImpl::DeregisterInterface(UdpSocket* socket) {
+ TRACE_SCOPED(TraceCategory::kMdns,
"MdnsResponderAdapterImpl::DeregisterInterface");
const auto info_it = responder_interface_info_.find(socket);
if (info_it == responder_interface_info_.end())
@@ -304,15 +302,14 @@ Error MdnsResponderAdapterImpl::DeregisterInterface(
responder_interface_info_.erase(info_it);
return Error::None();
}
-void MdnsResponderAdapterImpl::OnRead(
- platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> packet_or_error) {
+void MdnsResponderAdapterImpl::OnRead(UdpSocket* socket,
+ ErrorOr<UdpPacket> packet_or_error) {
if (packet_or_error.is_error()) {
return;
}
- platform::UdpPacket packet = std::move(packet_or_error.value());
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::OnRead");
+ UdpPacket packet = std::move(packet_or_error.value());
+ TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::OnRead");
mDNSAddr src;
if (packet.source().address.IsV4()) {
src.type = mDNSAddrType_IPv4;
@@ -341,20 +338,18 @@ void MdnsResponderAdapterImpl::OnRead(
reinterpret_cast<mDNSInterfaceID>(packet.socket()));
}
-void MdnsResponderAdapterImpl::OnSendError(platform::UdpSocket* socket,
- Error error) {
+void MdnsResponderAdapterImpl::OnSendError(UdpSocket* socket, Error error) {
// TODO(crbug.com/openscreen/67): Implement this method.
OSP_UNIMPLEMENTED();
}
-void MdnsResponderAdapterImpl::OnError(platform::UdpSocket* socket,
- Error error) {
+void MdnsResponderAdapterImpl::OnError(UdpSocket* socket, Error error) {
// TODO(crbug.com/openscreen/67): Implement this method.
OSP_UNIMPLEMENTED();
}
-platform::Clock::duration MdnsResponderAdapterImpl::RunTasks() {
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::RunTasks");
+Clock::duration MdnsResponderAdapterImpl::RunTasks() {
+ TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::RunTasks");
mDNS_Execute(&mdns_);
@@ -409,9 +404,9 @@ std::vector<AaaaEvent> MdnsResponderAdapterImpl::TakeAaaaResponses() {
}
MdnsResponderErrorCode MdnsResponderAdapterImpl::StartPtrQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_type) {
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StartPtrQuery");
+ TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StartPtrQuery");
auto& ptr_questions = socket_to_questions_[socket].ptr;
if (ptr_questions.find(service_type) != ptr_questions.end())
return MdnsResponderErrorCode::kNoError;
@@ -456,9 +451,9 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StartPtrQuery(
}
MdnsResponderErrorCode MdnsResponderAdapterImpl::StartSrvQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) {
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StartSrvQuery");
+ TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StartSrvQuery");
if (!service_instance.EndsWithLocalDomain())
return MdnsResponderErrorCode::kInvalidParameters;
@@ -494,9 +489,9 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StartSrvQuery(
}
MdnsResponderErrorCode MdnsResponderAdapterImpl::StartTxtQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) {
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StartTxtQuery");
+ TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StartTxtQuery");
if (!service_instance.EndsWithLocalDomain())
return MdnsResponderErrorCode::kInvalidParameters;
@@ -532,9 +527,9 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StartTxtQuery(
}
MdnsResponderErrorCode MdnsResponderAdapterImpl::StartAQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& domain_name) {
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StartAQuery");
+ TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StartAQuery");
if (!domain_name.EndsWithLocalDomain())
return MdnsResponderErrorCode::kInvalidParameters;
@@ -570,9 +565,10 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StartAQuery(
}
MdnsResponderErrorCode MdnsResponderAdapterImpl::StartAaaaQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& domain_name) {
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StartAaaaQuery");
+ TRACE_SCOPED(TraceCategory::kMdns,
+ "MdnsResponderAdapterImpl::StartAaaaQuery");
if (!domain_name.EndsWithLocalDomain())
return MdnsResponderErrorCode::kInvalidParameters;
@@ -608,9 +604,9 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StartAaaaQuery(
}
MdnsResponderErrorCode MdnsResponderAdapterImpl::StopPtrQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_type) {
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StopPtrQuery");
+ TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StopPtrQuery");
auto interface_entry = socket_to_questions_.find(socket);
if (interface_entry == socket_to_questions_.end())
return MdnsResponderErrorCode::kNoError;
@@ -626,9 +622,9 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StopPtrQuery(
}
MdnsResponderErrorCode MdnsResponderAdapterImpl::StopSrvQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) {
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StopSrvQuery");
+ TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StopSrvQuery");
auto interface_entry = socket_to_questions_.find(socket);
if (interface_entry == socket_to_questions_.end())
return MdnsResponderErrorCode::kNoError;
@@ -644,9 +640,9 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StopSrvQuery(
}
MdnsResponderErrorCode MdnsResponderAdapterImpl::StopTxtQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) {
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StopTxtQuery");
+ TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StopTxtQuery");
auto interface_entry = socket_to_questions_.find(socket);
if (interface_entry == socket_to_questions_.end())
return MdnsResponderErrorCode::kNoError;
@@ -662,9 +658,9 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StopTxtQuery(
}
MdnsResponderErrorCode MdnsResponderAdapterImpl::StopAQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& domain_name) {
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StopAQuery");
+ TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StopAQuery");
auto interface_entry = socket_to_questions_.find(socket);
if (interface_entry == socket_to_questions_.end())
return MdnsResponderErrorCode::kNoError;
@@ -680,9 +676,9 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StopAQuery(
}
MdnsResponderErrorCode MdnsResponderAdapterImpl::StopAaaaQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& domain_name) {
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StopAaaaQuery");
+ TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StopAaaaQuery");
auto interface_entry = socket_to_questions_.find(socket);
if (interface_entry == socket_to_questions_.end())
return MdnsResponderErrorCode::kNoError;
@@ -704,7 +700,7 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::RegisterService(
const DomainName& target_host,
uint16_t target_port,
const std::map<std::string, std::string>& txt_data) {
- TRACE_SCOPED(TraceCategory::mDNS,
+ TRACE_SCOPED(TraceCategory::kMdns,
"MdnsResponderAdapterImpl::RegisterService");
OSP_DCHECK(IsValidServiceName(service_name));
OSP_DCHECK(IsValidServiceProtocol(service_protocol));
@@ -749,7 +745,7 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::DeregisterService(
const std::string& service_instance,
const std::string& service_name,
const std::string& service_protocol) {
- TRACE_SCOPED(TraceCategory::mDNS,
+ TRACE_SCOPED(TraceCategory::kMdns,
"MdnsResponderAdapterImpl::DeregisterService");
domainlabel instance;
domainlabel name;
@@ -779,7 +775,7 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::UpdateTxtData(
const std::string& service_name,
const std::string& service_protocol,
const std::map<std::string, std::string>& txt_data) {
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::UpdateTxtData");
+ TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::UpdateTxtData");
domainlabel instance;
domainlabel name;
domainlabel protocol;
@@ -813,7 +809,8 @@ void MdnsResponderAdapterImpl::AQueryCallback(mDNS* m,
DNSQuestion* question,
const ResourceRecord* answer,
QC_result added) {
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::AQueryCallback");
+ TRACE_SCOPED(TraceCategory::kMdns,
+ "MdnsResponderAdapterImpl::AQueryCallback");
OSP_DCHECK(question);
OSP_DCHECK(answer);
OSP_DCHECK_EQ(answer->rrtype, kDNSType_A);
@@ -834,8 +831,8 @@ void MdnsResponderAdapterImpl::AQueryCallback(mDNS* m,
OSP_DCHECK_EQ(added, QC_addnocache);
}
adapter->a_responses_.emplace_back(
- QueryEventHeader{event_type, reinterpret_cast<platform::UdpSocket*>(
- answer->InterfaceID)},
+ QueryEventHeader{event_type,
+ reinterpret_cast<UdpSocket*>(answer->InterfaceID)},
std::move(domain), address);
}
@@ -844,7 +841,7 @@ void MdnsResponderAdapterImpl::AaaaQueryCallback(mDNS* m,
DNSQuestion* question,
const ResourceRecord* answer,
QC_result added) {
- TRACE_SCOPED(TraceCategory::mDNS,
+ TRACE_SCOPED(TraceCategory::kMdns,
"MdnsResponderAdapterImpl::AaaaQueryCallback");
OSP_DCHECK(question);
OSP_DCHECK(answer);
@@ -866,8 +863,8 @@ void MdnsResponderAdapterImpl::AaaaQueryCallback(mDNS* m,
OSP_DCHECK_EQ(added, QC_addnocache);
}
adapter->aaaa_responses_.emplace_back(
- QueryEventHeader{event_type, reinterpret_cast<platform::UdpSocket*>(
- answer->InterfaceID)},
+ QueryEventHeader{event_type,
+ reinterpret_cast<UdpSocket*>(answer->InterfaceID)},
std::move(domain), address);
}
@@ -876,7 +873,7 @@ void MdnsResponderAdapterImpl::PtrQueryCallback(mDNS* m,
DNSQuestion* question,
const ResourceRecord* answer,
QC_result added) {
- TRACE_SCOPED(TraceCategory::mDNS,
+ TRACE_SCOPED(TraceCategory::kMdns,
"MdnsResponderAdapterImpl::PtrQueryCallback");
OSP_DCHECK(question);
OSP_DCHECK(answer);
@@ -897,8 +894,8 @@ void MdnsResponderAdapterImpl::PtrQueryCallback(mDNS* m,
OSP_DCHECK_EQ(added, QC_addnocache);
}
adapter->ptr_responses_.emplace_back(
- QueryEventHeader{event_type, reinterpret_cast<platform::UdpSocket*>(
- answer->InterfaceID)},
+ QueryEventHeader{event_type,
+ reinterpret_cast<UdpSocket*>(answer->InterfaceID)},
std::move(result));
}
@@ -907,7 +904,7 @@ void MdnsResponderAdapterImpl::SrvQueryCallback(mDNS* m,
DNSQuestion* question,
const ResourceRecord* answer,
QC_result added) {
- TRACE_SCOPED(TraceCategory::mDNS,
+ TRACE_SCOPED(TraceCategory::kMdns,
"MdnsResponderAdapterImpl::SrvQueryCallback");
OSP_DCHECK(question);
OSP_DCHECK(answer);
@@ -932,8 +929,8 @@ void MdnsResponderAdapterImpl::SrvQueryCallback(mDNS* m,
OSP_DCHECK_EQ(added, QC_addnocache);
}
adapter->srv_responses_.emplace_back(
- QueryEventHeader{event_type, reinterpret_cast<platform::UdpSocket*>(
- answer->InterfaceID)},
+ QueryEventHeader{event_type,
+ reinterpret_cast<UdpSocket*>(answer->InterfaceID)},
std::move(service), std::move(result),
GetNetworkOrderPort(answer->rdata->u.srv.port));
}
@@ -963,8 +960,8 @@ void MdnsResponderAdapterImpl::TxtQueryCallback(mDNS* m,
OSP_DCHECK_EQ(added, QC_addnocache);
}
adapter->txt_responses_.emplace_back(
- QueryEventHeader{event_type, reinterpret_cast<platform::UdpSocket*>(
- answer->InterfaceID)},
+ QueryEventHeader{event_type,
+ reinterpret_cast<UdpSocket*>(answer->InterfaceID)},
std::move(service), std::move(lines));
}
@@ -992,10 +989,10 @@ void MdnsResponderAdapterImpl::ServiceCallback(mDNS* m,
}
void MdnsResponderAdapterImpl::AdvertiseInterfaces() {
- TRACE_SCOPED(TraceCategory::mDNS,
+ TRACE_SCOPED(TraceCategory::kMdns,
"MdnsResponderAdapterImpl::AdvertiseInterfaces");
for (auto& info : responder_interface_info_) {
- platform::UdpSocket* socket = info.first;
+ UdpSocket* socket = info.first;
NetworkInterfaceInfo& interface_info = info.second;
mDNS_SetupResourceRecord(&interface_info.RR_A, /** RDataStorage */ nullptr,
reinterpret_cast<mDNSInterfaceID>(socket),
@@ -1027,8 +1024,7 @@ void MdnsResponderAdapterImpl::DeadvertiseInterfaces() {
}
}
-void MdnsResponderAdapterImpl::RemoveQuestionsIfEmpty(
- platform::UdpSocket* socket) {
+void MdnsResponderAdapterImpl::RemoveQuestionsIfEmpty(UdpSocket* socket) {
auto entry = socket_to_questions_.find(socket);
bool empty = entry->second.a.empty() || entry->second.aaaa.empty() ||
entry->second.ptr.empty() || entry->second.srv.empty() ||
diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h
index 1e5ddeb0e39..80669e577d0 100644
--- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h
+++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h
@@ -29,17 +29,16 @@ class MdnsResponderAdapterImpl final : public MdnsResponderAdapter {
Error SetHostLabel(const std::string& host_label) override;
- Error RegisterInterface(const platform::InterfaceInfo& interface_info,
- const platform::IPSubnet& interface_address,
- platform::UdpSocket* socket) override;
- Error DeregisterInterface(platform::UdpSocket* socket) override;
+ Error RegisterInterface(const InterfaceInfo& interface_info,
+ const IPSubnet& interface_address,
+ UdpSocket* socket) override;
+ Error DeregisterInterface(UdpSocket* socket) override;
- void OnRead(platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> packet) override;
- void OnSendError(platform::UdpSocket* socket, Error error) override;
- void OnError(platform::UdpSocket* socket, Error error) override;
+ void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override;
+ void OnSendError(UdpSocket* socket, Error error) override;
+ void OnError(UdpSocket* socket, Error error) override;
- platform::Clock::duration RunTasks() override;
+ Clock::duration RunTasks() override;
std::vector<PtrEvent> TakePtrResponses() override;
std::vector<SrvEvent> TakeSrvResponses() override;
@@ -47,29 +46,29 @@ class MdnsResponderAdapterImpl final : public MdnsResponderAdapter {
std::vector<AEvent> TakeAResponses() override;
std::vector<AaaaEvent> TakeAaaaResponses() override;
- MdnsResponderErrorCode StartPtrQuery(platform::UdpSocket* socket,
+ MdnsResponderErrorCode StartPtrQuery(UdpSocket* socket,
const DomainName& service_type) override;
MdnsResponderErrorCode StartSrvQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) override;
MdnsResponderErrorCode StartTxtQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) override;
- MdnsResponderErrorCode StartAQuery(platform::UdpSocket* socket,
+ MdnsResponderErrorCode StartAQuery(UdpSocket* socket,
const DomainName& domain_name) override;
- MdnsResponderErrorCode StartAaaaQuery(platform::UdpSocket* socket,
+ MdnsResponderErrorCode StartAaaaQuery(UdpSocket* socket,
const DomainName& domain_name) override;
- MdnsResponderErrorCode StopPtrQuery(platform::UdpSocket* socket,
+ MdnsResponderErrorCode StopPtrQuery(UdpSocket* socket,
const DomainName& service_type) override;
MdnsResponderErrorCode StopSrvQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) override;
MdnsResponderErrorCode StopTxtQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) override;
- MdnsResponderErrorCode StopAQuery(platform::UdpSocket* socket,
+ MdnsResponderErrorCode StopAQuery(UdpSocket* socket,
const DomainName& domain_name) override;
- MdnsResponderErrorCode StopAaaaQuery(platform::UdpSocket* socket,
+ MdnsResponderErrorCode StopAaaaQuery(UdpSocket* socket,
const DomainName& domain_name) override;
MdnsResponderErrorCode RegisterService(
@@ -124,7 +123,7 @@ class MdnsResponderAdapterImpl final : public MdnsResponderAdapter {
void AdvertiseInterfaces();
void DeadvertiseInterfaces();
- void RemoveQuestionsIfEmpty(platform::UdpSocket* socket);
+ void RemoveQuestionsIfEmpty(UdpSocket* socket);
CacheEntity rr_cache_[kRrCacheSize];
@@ -136,10 +135,9 @@ class MdnsResponderAdapterImpl final : public MdnsResponderAdapter {
// platform sockets.
mDNS_PlatformSupport platform_storage_;
- std::map<platform::UdpSocket*, Questions> socket_to_questions_;
+ std::map<UdpSocket*, Questions> socket_to_questions_;
- std::map<platform::UdpSocket*, NetworkInterfaceInfo>
- responder_interface_info_;
+ std::map<UdpSocket*, NetworkInterfaceInfo> responder_interface_info_;
std::vector<AEvent> a_responses_;
std::vector<AaaaEvent> aaaa_responses_;
diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl_unittest.cc
index 759fe44d470..29b76679665 100644
--- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl_unittest.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl_unittest.cc
@@ -49,7 +49,7 @@ TEST(MdnsResponderAdapterImplTest, ExampleData) {
'p', 5, 'l', 'o', 'c', 'a', 'l', 0}};
const IPEndpoint mdns_endpoint{{224, 0, 0, 251}, 5353};
- platform::UdpPacket packet(std::begin(data), std::end(data));
+ UdpPacket packet(std::begin(data), std::end(data));
packet.set_source({{192, 168, 0, 2}, 6556});
packet.set_destination(mdns_endpoint);
packet.set_socket(nullptr);
diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.cc b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.cc
index 01603be4f16..4e204e37055 100644
--- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.cc
@@ -18,7 +18,6 @@
#include "third_party/mDNSResponder/src/mDNSCore/mDNSEmbeddedAPI.h"
#include "util/logging.h"
-using openscreen::platform::Clock;
using std::chrono::duration_cast;
using std::chrono::hours;
using std::chrono::milliseconds;
@@ -44,8 +43,7 @@ mStatus mDNSPlatformSendUDP(const mDNS* m,
UDPSocket* src,
const mDNSAddr* dst,
mDNSIPPort dstport) {
- auto* const socket =
- reinterpret_cast<openscreen::platform::UdpSocket*>(InterfaceID);
+ auto* const socket = reinterpret_cast<openscreen::UdpSocket*>(InterfaceID);
const auto socket_it =
std::find(m->p->sockets.begin(), m->p->sockets.end(), socket);
if (socket_it == m->p->sockets.end())
@@ -111,6 +109,8 @@ mStatus mDNSPlatformTimeInit() {
}
mDNSs32 mDNSPlatformRawTime() {
+ using openscreen::Clock;
+
const Clock::time_point now = Clock::now();
// A signed 32-bit integer counting milliseconds only gives ~24.8 days of
@@ -129,8 +129,7 @@ mDNSs32 mDNSPlatformRawTime() {
mDNSs32 mDNSPlatformUTC() {
const auto seconds_since_epoch =
- duration_cast<seconds>(openscreen::platform::GetWallTimeSinceUnixEpoch())
- .count();
+ duration_cast<seconds>(openscreen::GetWallTimeSinceUnixEpoch()).count();
// The return type will cause overflow in early 2038. Warn future developers
// a year ahead of time.
@@ -244,7 +243,11 @@ void mDNSPlatformSetDNSConfig(mDNS* const m,
mDNSBool setsearch,
domainname* const fqdn,
DNameListElem** RegDomains,
- DNameListElem** BrowseDomains) {}
+ DNameListElem** BrowseDomains) {
+ if (fqdn) {
+ std::memset(fqdn, 0, sizeof(*fqdn));
+ }
+}
mStatus mDNSPlatformGetPrimaryInterface(mDNS* const m,
mDNSAddr* v4,
diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.h b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.h
index afe04cd8653..342913fe620 100644
--- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.h
+++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.h
@@ -10,7 +10,7 @@
#include "platform/api/udp_socket.h"
struct mDNS_PlatformSupport_struct {
- std::vector<openscreen::platform::UdpSocket*> sockets;
+ std::vector<openscreen::UdpSocket*> sockets;
};
#endif // OSP_IMPL_DISCOVERY_MDNS_MDNS_RESPONDER_PLATFORM_H_
diff --git a/chromium/third_party/openscreen/src/osp/impl/internal_services.cc b/chromium/third_party/openscreen/src/osp/impl/internal_services.cc
index 983fcd7810d..5441472fbca 100644
--- a/chromium/third_party/openscreen/src/osp/impl/internal_services.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/internal_services.cc
@@ -5,6 +5,7 @@
#include "osp/impl/internal_services.h"
#include <algorithm>
+#include <utility>
#include "osp/impl/discovery/mdns/mdns_responder_adapter_impl.h"
#include "osp/impl/mdns_responder_service.h"
@@ -21,8 +22,7 @@ constexpr char kServiceProtocol[] = "_udp";
const IPAddress kMulticastAddress{224, 0, 0, 251};
const IPAddress kMulticastIPv6Address{
// ff02::fb
- 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb,
+ 0xff02, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x00fb,
};
const uint16_t kMulticastListeningPort = 5353;
@@ -37,8 +37,7 @@ class MdnsResponderAdapterImplFactory final
}
};
-Error SetUpMulticastSocket(platform::UdpSocket* socket,
- platform::NetworkInterfaceIndex ifindex) {
+Error SetUpMulticastSocket(UdpSocket* socket, NetworkInterfaceIndex ifindex) {
const IPAddress broadcast_address =
socket->IsIPv6() ? kMulticastIPv6Address : kMulticastAddress;
@@ -60,7 +59,7 @@ int g_instance_ref_count = 0;
std::unique_ptr<ServiceListener> InternalServices::CreateListener(
const MdnsServiceListenerConfig& config,
ServiceListener::Observer* observer,
- platform::TaskRunner* task_runner) {
+ TaskRunner* task_runner) {
auto* services = ReferenceSingleton(task_runner);
auto listener =
std::make_unique<ServiceListenerImpl>(&services->mdns_service_);
@@ -74,7 +73,7 @@ std::unique_ptr<ServiceListener> InternalServices::CreateListener(
std::unique_ptr<ServicePublisher> InternalServices::CreatePublisher(
const ServicePublisher::Config& config,
ServicePublisher::Observer* observer,
- platform::TaskRunner* task_runner) {
+ TaskRunner* task_runner) {
auto* services = ReferenceSingleton(task_runner);
services->mdns_service_.SetServiceConfig(
config.hostname, config.service_instance_name,
@@ -99,11 +98,10 @@ InternalServices::InternalPlatformLinkage::~InternalPlatformLinkage() {
std::vector<MdnsPlatformService::BoundInterface>
InternalServices::InternalPlatformLinkage::RegisterInterfaces(
- const std::vector<platform::NetworkInterfaceIndex>& whitelist) {
- const std::vector<platform::InterfaceInfo> interfaces =
- platform::GetNetworkInterfaces();
+ const std::vector<NetworkInterfaceIndex>& whitelist) {
+ const std::vector<InterfaceInfo> interfaces = GetNetworkInterfaces();
const bool do_filter_using_whitelist = !whitelist.empty();
- std::vector<platform::NetworkInterfaceIndex> index_list;
+ std::vector<NetworkInterfaceIndex> index_list;
for (const auto& interface : interfaces) {
OSP_VLOG << "Found interface: " << interface;
if (do_filter_using_whitelist &&
@@ -121,28 +119,26 @@ InternalServices::InternalPlatformLinkage::RegisterInterfaces(
// Set up sockets to send and listen to mDNS multicast traffic on all
// interfaces.
std::vector<BoundInterface> result;
- for (platform::NetworkInterfaceIndex index : index_list) {
- const auto& interface =
- *std::find_if(interfaces.begin(), interfaces.end(),
- [index](const platform::InterfaceInfo& info) {
- return info.index == index;
- });
+ for (NetworkInterfaceIndex index : index_list) {
+ const auto& interface = *std::find_if(
+ interfaces.begin(), interfaces.end(),
+ [index](const InterfaceInfo& info) { return info.index == index; });
if (interface.addresses.empty()) {
continue;
}
// Pick any address for the given interface.
- const platform::IPSubnet& primary_subnet = interface.addresses.front();
+ const IPSubnet& primary_subnet = interface.addresses.front();
auto create_result =
- platform::UdpSocket::Create(parent_->task_runner_, parent_,
- IPEndpoint{{}, kMulticastListeningPort});
+ UdpSocket::Create(parent_->task_runner_, parent_,
+ IPEndpoint{{}, kMulticastListeningPort});
if (!create_result) {
OSP_LOG_ERROR << "failed to create socket for interface " << index << ": "
<< create_result.error().message();
continue;
}
- platform::UdpSocketUniquePtr socket = std::move(create_result.value());
+ std::unique_ptr<UdpSocket> socket = std::move(create_result.value());
if (!SetUpMulticastSocket(socket.get(), index).ok()) {
continue;
}
@@ -158,21 +154,20 @@ InternalServices::InternalPlatformLinkage::RegisterInterfaces(
void InternalServices::InternalPlatformLinkage::DeregisterInterfaces(
const std::vector<BoundInterface>& registered_interfaces) {
for (const auto& interface : registered_interfaces) {
- platform::UdpSocket* const socket = interface.socket;
+ UdpSocket* const socket = interface.socket;
parent_->DeregisterMdnsSocket(socket);
- const auto it =
- std::find_if(open_sockets_.begin(), open_sockets_.end(),
- [socket](const platform::UdpSocketUniquePtr& s) {
- return s.get() == socket;
- });
+ const auto it = std::find_if(open_sockets_.begin(), open_sockets_.end(),
+ [socket](const std::unique_ptr<UdpSocket>& s) {
+ return s.get() == socket;
+ });
OSP_DCHECK(it != open_sockets_.end());
open_sockets_.erase(it);
}
}
-InternalServices::InternalServices(platform::ClockNowFunctionPtr now_function,
- platform::TaskRunner* task_runner)
+InternalServices::InternalServices(ClockNowFunctionPtr now_function,
+ TaskRunner* task_runner)
: mdns_service_(now_function,
task_runner,
kServiceName,
@@ -183,23 +178,23 @@ InternalServices::InternalServices(platform::ClockNowFunctionPtr now_function,
InternalServices::~InternalServices() = default;
-void InternalServices::RegisterMdnsSocket(platform::UdpSocket* socket) {
+void InternalServices::RegisterMdnsSocket(UdpSocket* socket) {
OSP_CHECK(g_instance) << "No listener or publisher is alive.";
// TODO(rwkeane): Hook this up to the new mDNS library once we swap out the
// mDNSResponder.
}
-void InternalServices::DeregisterMdnsSocket(platform::UdpSocket* socket) {
+void InternalServices::DeregisterMdnsSocket(UdpSocket* socket) {
// TODO(rwkeane): Hook this up to the new mDNS library once we swap out the
// mDNSResponder.
}
// static
InternalServices* InternalServices::ReferenceSingleton(
- platform::TaskRunner* task_runner) {
+ TaskRunner* task_runner) {
if (!g_instance) {
OSP_CHECK_EQ(g_instance_ref_count, 0);
- g_instance = new InternalServices(&platform::Clock::now, task_runner);
+ g_instance = new InternalServices(&Clock::now, task_runner);
}
++g_instance_ref_count;
return g_instance;
@@ -216,18 +211,17 @@ void InternalServices::DereferenceSingleton(void* instance) {
}
}
-void InternalServices::OnError(platform::UdpSocket* socket, Error error) {
+void InternalServices::OnError(UdpSocket* socket, Error error) {
OSP_LOG_ERROR << "failed to configure socket " << error.message();
this->DeregisterMdnsSocket(socket);
}
-void InternalServices::OnSendError(platform::UdpSocket* socket, Error error) {
+void InternalServices::OnSendError(UdpSocket* socket, Error error) {
// TODO(crbug.com/openscreen/67): Implement this method.
OSP_UNIMPLEMENTED();
}
-void InternalServices::OnRead(platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> packet) {
+void InternalServices::OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) {
g_instance->mdns_service_.OnRead(socket, std::move(packet));
}
diff --git a/chromium/third_party/openscreen/src/osp/impl/internal_services.h b/chromium/third_party/openscreen/src/osp/impl/internal_services.h
index 5a1f015a030..364c5963c65 100644
--- a/chromium/third_party/openscreen/src/osp/impl/internal_services.h
+++ b/chromium/third_party/openscreen/src/osp/impl/internal_services.h
@@ -25,9 +25,7 @@
namespace openscreen {
-namespace platform {
class TaskRunner;
-} // namespace platform
namespace osp {
@@ -36,22 +34,21 @@ namespace osp {
// event loop.
// TODO(btolsch): This may be renamed and/or split up once QUIC code lands and
// this use case is more concrete.
-class InternalServices : platform::UdpSocket::Client {
+class InternalServices : UdpSocket::Client {
public:
static std::unique_ptr<ServiceListener> CreateListener(
const MdnsServiceListenerConfig& config,
ServiceListener::Observer* observer,
- platform::TaskRunner* task_runner);
+ TaskRunner* task_runner);
static std::unique_ptr<ServicePublisher> CreatePublisher(
const ServicePublisher::Config& config,
ServicePublisher::Observer* observer,
- platform::TaskRunner* task_runner);
+ TaskRunner* task_runner);
// UdpSocket::Client overrides.
- void OnError(platform::UdpSocket* socket, Error error) override;
- void OnSendError(platform::UdpSocket* socket, Error error) override;
- void OnRead(platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> packet) override;
+ void OnError(UdpSocket* socket, Error error) override;
+ void OnSendError(UdpSocket* socket, Error error) override;
+ void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override;
private:
class InternalPlatformLinkage final : public MdnsPlatformService {
@@ -60,31 +57,29 @@ class InternalServices : platform::UdpSocket::Client {
~InternalPlatformLinkage() override;
std::vector<BoundInterface> RegisterInterfaces(
- const std::vector<platform::NetworkInterfaceIndex>& whitelist) override;
+ const std::vector<NetworkInterfaceIndex>& whitelist) override;
void DeregisterInterfaces(
const std::vector<BoundInterface>& registered_interfaces) override;
private:
InternalServices* const parent_;
- std::vector<platform::UdpSocketUniquePtr> open_sockets_;
+ std::vector<std::unique_ptr<UdpSocket>> open_sockets_;
};
// The TaskRunner provided here should live for the duration of this
// InternalService object's lifetime.
- InternalServices(platform::ClockNowFunctionPtr now_function,
- platform::TaskRunner* task_runner);
+ InternalServices(ClockNowFunctionPtr now_function, TaskRunner* task_runner);
~InternalServices() override;
- void RegisterMdnsSocket(platform::UdpSocket* socket);
- void DeregisterMdnsSocket(platform::UdpSocket* socket);
+ void RegisterMdnsSocket(UdpSocket* socket);
+ void DeregisterMdnsSocket(UdpSocket* socket);
- static InternalServices* ReferenceSingleton(
- platform::TaskRunner* task_runner);
+ static InternalServices* ReferenceSingleton(TaskRunner* task_runner);
static void DereferenceSingleton(void* instance);
MdnsResponderService mdns_service_;
- platform::TaskRunner* const task_runner_;
+ TaskRunner* const task_runner_;
OSP_DISALLOW_COPY_AND_ASSIGN(InternalServices);
};
diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.cc b/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.cc
index b4e6d8b85c5..e46829b8043 100644
--- a/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.cc
@@ -12,9 +12,9 @@ namespace openscreen {
namespace osp {
MdnsPlatformService::BoundInterface::BoundInterface(
- const platform::InterfaceInfo& interface_info,
- const platform::IPSubnet& subnet,
- platform::UdpSocket* socket)
+ const InterfaceInfo& interface_info,
+ const IPSubnet& subnet,
+ UdpSocket* socket)
: interface_info(interface_info), subnet(subnet), socket(socket) {
OSP_DCHECK(socket);
}
diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.h b/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.h
index 7cb6271fb5b..ec6ca31b481 100644
--- a/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.h
+++ b/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.h
@@ -16,23 +16,23 @@ namespace osp {
class MdnsPlatformService {
public:
struct BoundInterface {
- BoundInterface(const platform::InterfaceInfo& interface_info,
- const platform::IPSubnet& subnet,
- platform::UdpSocket* socket);
+ BoundInterface(const InterfaceInfo& interface_info,
+ const IPSubnet& subnet,
+ UdpSocket* socket);
~BoundInterface();
bool operator==(const BoundInterface& other) const;
bool operator!=(const BoundInterface& other) const;
- platform::InterfaceInfo interface_info;
- platform::IPSubnet subnet;
- platform::UdpSocket* socket;
+ InterfaceInfo interface_info;
+ IPSubnet subnet;
+ UdpSocket* socket;
};
virtual ~MdnsPlatformService() = default;
virtual std::vector<BoundInterface> RegisterInterfaces(
- const std::vector<platform::NetworkInterfaceIndex>& whitelist) = 0;
+ const std::vector<NetworkInterfaceIndex>& whitelist) = 0;
virtual void DeregisterInterfaces(
const std::vector<BoundInterface>& registered_interfaces) = 0;
};
diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.cc b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.cc
index 347f1254269..f1aedbdcaba 100644
--- a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.cc
@@ -13,8 +13,6 @@
#include "util/logging.h"
#include "util/trace_logging.h"
-using openscreen::platform::TraceCategory;
-
namespace openscreen {
namespace osp {
namespace {
@@ -33,8 +31,8 @@ std::string ServiceIdFromServiceInstanceName(
} // namespace
MdnsResponderService::MdnsResponderService(
- platform::ClockNowFunctionPtr now_function,
- platform::TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ TaskRunner* task_runner,
const std::string& service_name,
const std::string& service_protocol,
std::unique_ptr<MdnsResponderAdapterFactory> mdns_responder_factory,
@@ -51,7 +49,7 @@ void MdnsResponderService::SetServiceConfig(
const std::string& hostname,
const std::string& instance,
uint16_t port,
- const std::vector<platform::NetworkInterfaceIndex> whitelist,
+ const std::vector<NetworkInterfaceIndex> whitelist,
const std::map<std::string, std::string>& txt_data) {
OSP_DCHECK(!hostname.empty());
OSP_DCHECK(!instance.empty());
@@ -63,9 +61,9 @@ void MdnsResponderService::SetServiceConfig(
service_txt_data_ = txt_data;
}
-void MdnsResponderService::OnRead(platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> packet) {
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderService::OnRead");
+void MdnsResponderService::OnRead(UdpSocket* socket,
+ ErrorOr<UdpPacket> packet) {
+ TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderService::OnRead");
if (!mdns_responder_) {
return;
}
@@ -74,12 +72,11 @@ void MdnsResponderService::OnRead(platform::UdpSocket* socket,
HandleMdnsEvents();
}
-void MdnsResponderService::OnSendError(platform::UdpSocket* socket,
- Error error) {
+void MdnsResponderService::OnSendError(UdpSocket* socket, Error error) {
mdns_responder_->OnSendError(socket, std::move(error));
}
-void MdnsResponderService::OnError(platform::UdpSocket* socket, Error error) {
+void MdnsResponderService::OnError(UdpSocket* socket, Error error) {
mdns_responder_->OnError(socket, std::move(error));
}
@@ -214,7 +211,7 @@ bool MdnsResponderService::NetworkScopedDomainNameComparator::operator()(
}
void MdnsResponderService::HandleMdnsEvents() {
- TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderService::HandleMdnsEvents");
+ TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderService::HandleMdnsEvents");
// NOTE: In the common case, we will get a single combined packet for
// PTR/SRV/TXT/A and then no other packets. If we don't loop here, we would
// start SRV/TXT queries based on the PTR response, but never check for events
@@ -334,7 +331,7 @@ void MdnsResponderService::StopListening() {
}
network_scoped_domain_to_host_.clear();
for (const auto& service : service_by_name_) {
- platform::UdpSocket* const socket = service.second->ptr_socket;
+ UdpSocket* const socket = service.second->ptr_socket;
mdns_responder_->StopSrvQuery(socket, service.first);
mdns_responder_->StopTxtQuery(socket, service.first);
}
@@ -429,7 +426,7 @@ bool MdnsResponderService::HandlePtrEvent(
InstanceNameSet* modified_instance_names) {
bool events_possible = false;
const auto& instance_name = ptr_event.service_instance;
- platform::UdpSocket* const socket = ptr_event.header.socket;
+ UdpSocket* const socket = ptr_event.header.socket;
auto entry = service_by_name_.find(ptr_event.service_instance);
switch (ptr_event.header.response_type) {
case QueryEventHeader::Type::kAddedNoCache:
@@ -481,7 +478,7 @@ bool MdnsResponderService::HandleSrvEvent(
bool events_possible = false;
auto& domain_name = srv_event.domain_name;
const auto& instance_name = srv_event.service_instance;
- platform::UdpSocket* const socket = srv_event.header.socket;
+ UdpSocket* const socket = srv_event.header.socket;
auto entry = service_by_name_.find(srv_event.service_instance);
if (entry == service_by_name_.end())
return events_possible;
@@ -568,7 +565,7 @@ bool MdnsResponderService::HandleTxtEvent(
}
bool MdnsResponderService::HandleAddressEvent(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
QueryEventHeader::Type response_type,
const DomainName& domain_name,
bool a_event,
@@ -619,14 +616,14 @@ bool MdnsResponderService::HandleAaaaEvent(
}
MdnsResponderService::HostInfo* MdnsResponderService::AddOrGetHostInfo(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& domain_name) {
return &network_scoped_domain_to_host_[NetworkScopedDomainName{socket,
domain_name}];
}
MdnsResponderService::HostInfo* MdnsResponderService::GetHostInfo(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& domain_name) {
auto kv = network_scoped_domain_to_host_.find(
NetworkScopedDomainName{socket, domain_name});
@@ -642,16 +639,15 @@ bool MdnsResponderService::IsServiceReady(const ServiceInstance& instance,
!instance.txt_info.empty() && (host->v4_address || host->v6_address));
}
-platform::NetworkInterfaceIndex
-MdnsResponderService::GetNetworkInterfaceIndexFromSocket(
- const platform::UdpSocket* socket) const {
+NetworkInterfaceIndex MdnsResponderService::GetNetworkInterfaceIndexFromSocket(
+ const UdpSocket* socket) const {
auto it = std::find_if(
bound_interfaces_.begin(), bound_interfaces_.end(),
[socket](const MdnsPlatformService::BoundInterface& interface) {
return interface.socket == socket;
});
if (it == bound_interfaces_.end())
- return platform::kInvalidNetworkInterfaceIndex;
+ return kInvalidNetworkInterfaceIndex;
return it->interface_info.index;
}
diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.h b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.h
index c6ec126f523..a010f5a6065 100644
--- a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.h
+++ b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.h
@@ -34,29 +34,27 @@ class MdnsResponderAdapterFactory {
class MdnsResponderService : public ServiceListenerImpl::Delegate,
public ServicePublisherImpl::Delegate,
- public platform::UdpSocket::Client {
+ public UdpSocket::Client {
public:
MdnsResponderService(
- platform::ClockNowFunctionPtr now_function,
- platform::TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ TaskRunner* task_runner,
const std::string& service_name,
const std::string& service_protocol,
std::unique_ptr<MdnsResponderAdapterFactory> mdns_responder_factory,
std::unique_ptr<MdnsPlatformService> platform);
virtual ~MdnsResponderService() override;
- void SetServiceConfig(
- const std::string& hostname,
- const std::string& instance,
- uint16_t port,
- const std::vector<platform::NetworkInterfaceIndex> whitelist,
- const std::map<std::string, std::string>& txt_data);
+ void SetServiceConfig(const std::string& hostname,
+ const std::string& instance,
+ uint16_t port,
+ const std::vector<NetworkInterfaceIndex> whitelist,
+ const std::map<std::string, std::string>& txt_data);
// UdpSocket::Client overrides.
- void OnRead(platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> packet) override;
- void OnSendError(platform::UdpSocket* socket, Error error) override;
- void OnError(platform::UdpSocket* socket, Error error) override;
+ void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override;
+ void OnSendError(UdpSocket* socket, Error error) override;
+ void OnError(UdpSocket* socket, Error error) override;
// ServiceListenerImpl::Delegate overrides.
void StartListener() override;
@@ -98,7 +96,7 @@ class MdnsResponderService : public ServiceListenerImpl::Delegate,
// NOTE: service_instance implicit in map key.
struct ServiceInstance {
- platform::UdpSocket* ptr_socket = nullptr;
+ UdpSocket* ptr_socket = nullptr;
DomainName domain_name;
uint16_t port = 0;
bool has_ptr_record = false;
@@ -116,7 +114,7 @@ class MdnsResponderService : public ServiceListenerImpl::Delegate,
};
struct NetworkScopedDomainName {
- platform::UdpSocket* socket;
+ UdpSocket* socket;
DomainName domain_name;
};
@@ -144,7 +142,7 @@ class MdnsResponderService : public ServiceListenerImpl::Delegate,
InstanceNameSet* modified_instance_names);
bool HandleTxtEvent(const TxtEvent& txt_event,
InstanceNameSet* modified_instance_names);
- bool HandleAddressEvent(platform::UdpSocket* socket,
+ bool HandleAddressEvent(UdpSocket* socket,
QueryEventHeader::Type response_type,
const DomainName& domain_name,
bool a_event,
@@ -155,13 +153,11 @@ class MdnsResponderService : public ServiceListenerImpl::Delegate,
bool HandleAaaaEvent(const AaaaEvent& aaaa_event,
InstanceNameSet* modified_instance_names);
- HostInfo* AddOrGetHostInfo(platform::UdpSocket* socket,
- const DomainName& domain_name);
- HostInfo* GetHostInfo(platform::UdpSocket* socket,
- const DomainName& domain_name);
+ HostInfo* AddOrGetHostInfo(UdpSocket* socket, const DomainName& domain_name);
+ HostInfo* GetHostInfo(UdpSocket* socket, const DomainName& domain_name);
bool IsServiceReady(const ServiceInstance& instance, HostInfo* host) const;
- platform::NetworkInterfaceIndex GetNetworkInterfaceIndexFromSocket(
- const platform::UdpSocket* socket) const;
+ NetworkInterfaceIndex GetNetworkInterfaceIndexFromSocket(
+ const UdpSocket* socket) const;
// Runs background tasks to manage the internal mDNS state.
void RunBackgroundTasks();
@@ -175,7 +171,7 @@ class MdnsResponderService : public ServiceListenerImpl::Delegate,
std::string service_hostname_;
std::string service_instance_name_;
uint16_t service_port_;
- std::vector<platform::NetworkInterfaceIndex> interface_index_whitelist_;
+ std::vector<NetworkInterfaceIndex> interface_index_whitelist_;
std::map<std::string, std::string> service_txt_data_;
std::unique_ptr<MdnsResponderAdapterFactory> mdns_responder_factory_;
@@ -199,7 +195,7 @@ class MdnsResponderService : public ServiceListenerImpl::Delegate,
std::map<std::string, ServiceInfo> receiver_info_;
- platform::TaskRunner* const task_runner_;
+ TaskRunner* const task_runner_;
// Scheduled to run periodic background tasks.
Alarm background_tasks_alarm_;
diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service_unittest.cc
index fe8e1239dc4..5a9891acf5b 100644
--- a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service_unittest.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service_unittest.cc
@@ -24,12 +24,12 @@ namespace osp {
class TestingMdnsResponderService final : public MdnsResponderService {
public:
TestingMdnsResponderService(
- platform::FakeTaskRunner* task_runner,
+ FakeTaskRunner* task_runner,
const std::string& service_name,
const std::string& service_protocol,
std::unique_ptr<MdnsResponderAdapterFactory> mdns_responder_factory,
std::unique_ptr<MdnsPlatformService> platform_service)
- : MdnsResponderService(&platform::FakeClock::now,
+ : MdnsResponderService(&FakeClock::now,
task_runner,
service_name,
service_protocol,
@@ -172,10 +172,10 @@ class MockServicePublisherObserver final : public ServicePublisher::Observer {
MOCK_METHOD1(OnMetrics, void(ServicePublisher::Metrics));
};
-platform::UdpSocket* const kDefaultSocket =
- reinterpret_cast<platform::UdpSocket*>(static_cast<uintptr_t>(16));
-platform::UdpSocket* const kSecondSocket =
- reinterpret_cast<platform::UdpSocket*>(static_cast<uintptr_t>(24));
+UdpSocket* const kDefaultSocket =
+ reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(16));
+UdpSocket* const kSecondSocket =
+ reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(24));
class MdnsResponderServiceTest : public ::testing::Test {
protected:
@@ -184,8 +184,8 @@ class MdnsResponderServiceTest : public ::testing::Test {
std::make_unique<FakeMdnsResponderAdapterFactory>();
auto wrapper_factory = std::make_unique<WrapperMdnsResponderAdapterFactory>(
mdns_responder_factory_.get());
- clock_ = std::make_unique<platform::FakeClock>(platform::Clock::now());
- task_runner_ = std::make_unique<platform::FakeTaskRunner>(clock_.get());
+ clock_ = std::make_unique<FakeClock>(Clock::now());
+ task_runner_ = std::make_unique<FakeTaskRunner>(clock_.get());
auto platform_service = std::make_unique<FakeMdnsPlatformService>();
fake_platform_service_ = platform_service.get();
fake_platform_service_->set_interfaces(bound_interfaces_);
@@ -202,8 +202,8 @@ class MdnsResponderServiceTest : public ::testing::Test {
&publisher_observer_, mdns_service_.get());
}
- std::unique_ptr<platform::FakeClock> clock_;
- std::unique_ptr<platform::FakeTaskRunner> task_runner_;
+ std::unique_ptr<FakeClock> clock_;
+ std::unique_ptr<FakeTaskRunner> task_runner_;
MockServiceListenerObserver observer_;
FakeMdnsPlatformService* fake_platform_service_;
std::unique_ptr<FakeMdnsResponderAdapterFactory> mdns_responder_factory_;
@@ -213,22 +213,22 @@ class MdnsResponderServiceTest : public ::testing::Test {
std::unique_ptr<ServicePublisherImpl> service_publisher_;
const uint8_t default_mac_[6] = {0, 11, 22, 33, 44, 55};
const uint8_t second_mac_[6] = {55, 33, 22, 33, 44, 77};
- const platform::IPSubnet default_subnet_{IPAddress{192, 168, 3, 2}, 24};
- const platform::IPSubnet second_subnet_{IPAddress{10, 0, 0, 3}, 24};
+ const IPSubnet default_subnet_{IPAddress{192, 168, 3, 2}, 24};
+ const IPSubnet second_subnet_{IPAddress{10, 0, 0, 3}, 24};
std::vector<MdnsPlatformService::BoundInterface> bound_interfaces_{
MdnsPlatformService::BoundInterface{
- platform::InterfaceInfo{1,
- default_mac_,
- "eth0",
- platform::InterfaceInfo::Type::kEthernet,
- {default_subnet_}},
+ InterfaceInfo{1,
+ default_mac_,
+ "eth0",
+ InterfaceInfo::Type::kEthernet,
+ {default_subnet_}},
default_subnet_, kDefaultSocket},
MdnsPlatformService::BoundInterface{
- platform::InterfaceInfo{2,
- second_mac_,
- "eth1",
- platform::InterfaceInfo::Type::kEthernet,
- {second_subnet_}},
+ InterfaceInfo{2,
+ second_mac_,
+ "eth1",
+ InterfaceInfo::Type::kEthernet,
+ {second_subnet_}},
second_subnet_, kSecondSocket},
};
};
@@ -284,10 +284,9 @@ TEST_F(MdnsResponderServiceTest, BasicServiceStates) {
TEST_F(MdnsResponderServiceTest, NetworkNetworkInterfaceIndex) {
constexpr uint8_t mac[6] = {12, 34, 56, 78, 90};
- const platform::IPSubnet subnet{IPAddress{10, 0, 0, 2}, 24};
+ const IPSubnet subnet{IPAddress{10, 0, 0, 2}, 24};
bound_interfaces_.emplace_back(
- platform::InterfaceInfo{
- 2, mac, "wlan0", platform::InterfaceInfo::Type::kWifi, {subnet}},
+ InterfaceInfo{2, mac, "wlan0", InterfaceInfo::Type::kWifi, {subnet}},
subnet, kSecondSocket);
fake_platform_service_->set_interfaces(bound_interfaces_);
EXPECT_CALL(observer_, OnStarted());
diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_service_listener_factory.cc b/chromium/third_party/openscreen/src/osp/impl/mdns_service_listener_factory.cc
index a94d81081ac..cae4a341ec7 100644
--- a/chromium/third_party/openscreen/src/osp/impl/mdns_service_listener_factory.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/mdns_service_listener_factory.cc
@@ -8,9 +8,7 @@
namespace openscreen {
-namespace platform {
class TaskRunner;
-} // namespace platform
namespace osp {
@@ -18,7 +16,7 @@ namespace osp {
std::unique_ptr<ServiceListener> MdnsServiceListenerFactory::Create(
const MdnsServiceListenerConfig& config,
ServiceListener::Observer* observer,
- platform::TaskRunner* task_runner) {
+ TaskRunner* task_runner) {
return InternalServices::CreateListener(config, observer, task_runner);
}
diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_service_publisher_factory.cc b/chromium/third_party/openscreen/src/osp/impl/mdns_service_publisher_factory.cc
index cbd8d122465..f055e772afb 100644
--- a/chromium/third_party/openscreen/src/osp/impl/mdns_service_publisher_factory.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/mdns_service_publisher_factory.cc
@@ -8,9 +8,7 @@
namespace openscreen {
-namespace platform {
class TaskRunner;
-} // namespace platform
namespace osp {
@@ -18,7 +16,7 @@ namespace osp {
std::unique_ptr<ServicePublisher> MdnsServicePublisherFactory::Create(
const ServicePublisher::Config& config,
ServicePublisher::Observer* observer,
- platform::TaskRunner* task_runner) {
+ TaskRunner* task_runner) {
return InternalServices::CreatePublisher(config, observer, task_runner);
}
diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection.cc
index 1bd6f408a34..9220b63b724 100644
--- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection.cc
@@ -181,13 +181,12 @@ void ConnectionManager::RemoveConnection(Connection* connection) {
// TODO(jophba): refine the RegisterWatch/OnStreamMessage API. We
// should add a layer between the message logic and the parse/dispatch
// logic, and remove the CBOR information from ConnectionManager.
-ErrorOr<size_t> ConnectionManager::OnStreamMessage(
- uint64_t endpoint_id,
- uint64_t connection_id,
- msgs::Type message_type,
- const uint8_t* buffer,
- size_t buffer_size,
- platform::Clock::time_point now) {
+ErrorOr<size_t> ConnectionManager::OnStreamMessage(uint64_t endpoint_id,
+ uint64_t connection_id,
+ msgs::Type message_type,
+ const uint8_t* buffer,
+ size_t buffer_size,
+ Clock::time_point now) {
switch (message_type) {
case msgs::Type::kPresentationConnectionMessage: {
msgs::PresentationConnectionMessage message;
diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection_unittest.cc
index 6f611dda25a..f4c829a4117 100644
--- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection_unittest.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection_unittest.cc
@@ -58,12 +58,11 @@ class MockConnectRequest final
class ConnectionTest : public ::testing::Test {
public:
ConnectionTest() {
- fake_clock_ = std::make_unique<platform::FakeClock>(
- platform::Clock::time_point(std::chrono::milliseconds(1298424)));
- task_runner_ =
- std::make_unique<platform::FakeTaskRunner>(fake_clock_.get());
- quic_bridge_ = std::make_unique<FakeQuicBridge>(task_runner_.get(),
- platform::FakeClock::now);
+ fake_clock_ = std::make_unique<FakeClock>(
+ Clock::time_point(std::chrono::milliseconds(1298424)));
+ task_runner_ = std::make_unique<FakeTaskRunner>(fake_clock_.get());
+ quic_bridge_ =
+ std::make_unique<FakeQuicBridge>(task_runner_.get(), FakeClock::now);
controller_connection_manager_ = std::make_unique<ConnectionManager>(
quic_bridge_->controller_demuxer.get());
receiver_connection_manager_ = std::make_unique<ConnectionManager>(
@@ -89,8 +88,8 @@ class ConnectionTest : public ::testing::Test {
return response;
}
- std::unique_ptr<platform::FakeClock> fake_clock_;
- std::unique_ptr<platform::FakeTaskRunner> task_runner_;
+ std::unique_ptr<FakeClock> fake_clock_;
+ std::unique_ptr<FakeTaskRunner> task_runner_;
std::unique_ptr<FakeQuicBridge> quic_bridge_;
std::unique_ptr<ConnectionManager> controller_connection_manager_;
std::unique_ptr<ConnectionManager> receiver_connection_manager_;
diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller.cc
index 5a54cb2b99c..94ee4e2d1f4 100644
--- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller.cc
@@ -418,7 +418,7 @@ void swap(Controller::ConnectRequest& a, Controller::ConnectRequest& b) {
swap(a.controller_, b.controller_);
}
-Controller::Controller(platform::ClockNowFunctionPtr now_function) {
+Controller::Controller(ClockNowFunctionPtr now_function) {
availability_requester_ =
std::make_unique<UrlAvailabilityRequester>(now_function);
connection_manager_ =
@@ -621,7 +621,7 @@ class Controller::TerminationListener final
msgs::Type message_type,
const uint8_t* buffer,
size_t buffer_size,
- platform::Clock::time_point now) override;
+ Clock::time_point now) override;
private:
Controller* const controller_;
@@ -650,7 +650,7 @@ ErrorOr<size_t> Controller::TerminationListener::OnStreamMessage(
msgs::Type message_type,
const uint8_t* buffer,
size_t buffer_size,
- platform::Clock::time_point now) {
+ Clock::time_point now) {
OSP_CHECK_EQ(static_cast<int>(msgs::Type::kPresentationTerminationEvent),
static_cast<int>(message_type));
msgs::PresentationTerminationEvent event;
diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller_unittest.cc
index 5b32c12f458..bad4c66616a 100644
--- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller_unittest.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller_unittest.cc
@@ -77,12 +77,11 @@ class MockRequestDelegate final : public RequestDelegate {
class ControllerTest : public ::testing::Test {
public:
ControllerTest() {
- fake_clock_ = std::make_unique<platform::FakeClock>(
- platform::Clock::time_point(seconds(11111)));
- task_runner_ =
- std::make_unique<platform::FakeTaskRunner>(fake_clock_.get());
- quic_bridge_ = std::make_unique<FakeQuicBridge>(task_runner_.get(),
- platform::FakeClock::now);
+ fake_clock_ =
+ std::make_unique<FakeClock>(Clock::time_point(seconds(11111)));
+ task_runner_ = std::make_unique<FakeTaskRunner>(fake_clock_.get());
+ quic_bridge_ =
+ std::make_unique<FakeQuicBridge>(task_runner_.get(), FakeClock::now);
receiver_info1 = {
"service-id1", "lucas-auer", 1, quic_bridge_->kReceiverEndpoint, {}};
}
@@ -94,7 +93,7 @@ class ControllerTest : public ::testing::Test {
NetworkServiceManager::Create(std::move(service_listener), nullptr,
std::move(quic_bridge_->quic_client),
std::move(quic_bridge_->quic_server));
- controller_ = std::make_unique<Controller>(platform::FakeClock::now);
+ controller_ = std::make_unique<Controller>(FakeClock::now);
ON_CALL(quic_bridge_->mock_server_observer, OnIncomingConnectionMock(_))
.WillByDefault(
Invoke([this](std::unique_ptr<ProtocolConnection>& connection) {
@@ -117,16 +116,15 @@ class ControllerTest : public ::testing::Test {
ssize_t decode_result = -1;
msgs::Type msg_type;
EXPECT_CALL(mock_callback_, OnStreamMessage(_, _, _, _, _, _))
- .WillOnce(
- Invoke([request, &msg_type, &decode_result](
- uint64_t endpoint_id, uint64_t cid,
- msgs::Type message_type, const uint8_t* buffer,
- size_t buffer_size, platform::Clock::time_point now) {
- msg_type = message_type;
- decode_result = msgs::DecodePresentationUrlAvailabilityRequest(
- buffer, buffer_size, request);
- return decode_result;
- }));
+ .WillOnce(Invoke([request, &msg_type, &decode_result](
+ uint64_t endpoint_id, uint64_t cid,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ msg_type = message_type;
+ decode_result = msgs::DecodePresentationUrlAvailabilityRequest(
+ buffer, buffer_size, request);
+ return decode_result;
+ }));
quic_bridge_->RunTasksUntilIdle();
ASSERT_EQ(msg_type, msgs::Type::kPresentationUrlAvailabilityRequest);
ASSERT_GT(decode_result, 0);
@@ -206,16 +204,15 @@ class ControllerTest : public ::testing::Test {
ssize_t decode_result = -1;
msgs::Type msg_type;
EXPECT_CALL(*mock_callback, OnStreamMessage(_, _, _, _, _, _))
- .WillOnce(
- Invoke([request, &msg_type, &decode_result](
- uint64_t endpoint_id, uint64_t cid,
- msgs::Type message_type, const uint8_t* buffer,
- size_t buffer_size, platform::Clock::time_point now) {
- msg_type = message_type;
- decode_result = msgs::DecodePresentationConnectionCloseRequest(
- buffer, buffer_size, request);
- return decode_result;
- }));
+ .WillOnce(Invoke([request, &msg_type, &decode_result](
+ uint64_t endpoint_id, uint64_t cid,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ msg_type = message_type;
+ decode_result = msgs::DecodePresentationConnectionCloseRequest(
+ buffer, buffer_size, request);
+ return decode_result;
+ }));
connection->Close(Connection::CloseReason::kClosed);
EXPECT_EQ(connection->state(), Connection::State::kClosed);
quic_bridge_->RunTasksUntilIdle();
@@ -264,16 +261,15 @@ class ControllerTest : public ::testing::Test {
msgs::PresentationStartRequest request;
msgs::Type msg_type;
EXPECT_CALL(*mock_callback, OnStreamMessage(_, _, _, _, _, _))
- .WillOnce(
- Invoke([&request, &msg_type](
- uint64_t endpoint_id, uint64_t cid,
- msgs::Type message_type, const uint8_t* buffer,
- size_t buffer_size, platform::Clock::time_point now) {
- msg_type = message_type;
- ssize_t result = msgs::DecodePresentationStartRequest(
- buffer, buffer_size, &request);
- return result;
- }));
+ .WillOnce(Invoke([&request, &msg_type](
+ uint64_t endpoint_id, uint64_t cid,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ msg_type = message_type;
+ ssize_t result = msgs::DecodePresentationStartRequest(
+ buffer, buffer_size, &request);
+ return result;
+ }));
Controller::ConnectRequest connect_request = controller_->StartPresentation(
"https://example.com/receiver.html", receiver_info1.service_id,
&mock_request_delegate, mock_connection_delegate);
@@ -297,8 +293,8 @@ class ControllerTest : public ::testing::Test {
ASSERT_TRUE(*connection);
}
- std::unique_ptr<platform::FakeClock> fake_clock_;
- std::unique_ptr<platform::FakeTaskRunner> task_runner_;
+ std::unique_ptr<FakeClock> fake_clock_;
+ std::unique_ptr<FakeTaskRunner> task_runner_;
MessageDemuxer::MessageWatch availability_watch_;
MockMessageCallback mock_callback_;
std::unique_ptr<FakeQuicBridge> quic_bridge_;
@@ -413,16 +409,15 @@ TEST_F(ControllerTest, TerminatePresentationFromController) {
msgs::PresentationTerminationRequest termination_request;
msgs::Type msg_type;
EXPECT_CALL(mock_callback, OnStreamMessage(_, _, _, _, _, _))
- .WillOnce(
- Invoke([&termination_request, &msg_type](
- uint64_t endpoint_id, uint64_t cid,
- msgs::Type message_type, const uint8_t* buffer,
- size_t buffer_size, platform::Clock::time_point now) {
- msg_type = message_type;
- ssize_t result = msgs::DecodePresentationTerminationRequest(
- buffer, buffer_size, &termination_request);
- return result;
- }));
+ .WillOnce(Invoke([&termination_request, &msg_type](
+ uint64_t endpoint_id, uint64_t cid,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ msg_type = message_type;
+ ssize_t result = msgs::DecodePresentationTerminationRequest(
+ buffer, buffer_size, &termination_request);
+ return result;
+ }));
connection->Terminate(TerminationReason::kControllerTerminateCalled);
quic_bridge_->RunTasksUntilIdle();
@@ -505,16 +500,15 @@ TEST_F(ControllerTest, Reconnect) {
ssize_t decode_result = -1;
msgs::Type msg_type;
EXPECT_CALL(mock_callback, OnStreamMessage(_, _, _, _, _, _))
- .WillOnce(
- Invoke([&open_request, &msg_type, &decode_result](
- uint64_t endpoint_id, uint64_t cid,
- msgs::Type message_type, const uint8_t* buffer,
- size_t buffer_size, platform::Clock::time_point now) {
- msg_type = message_type;
- decode_result = msgs::DecodePresentationConnectionOpenRequest(
- buffer, buffer_size, &open_request);
- return decode_result;
- }));
+ .WillOnce(Invoke([&open_request, &msg_type, &decode_result](
+ uint64_t endpoint_id, uint64_t cid,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ msg_type = message_type;
+ decode_result = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &open_request);
+ return decode_result;
+ }));
quic_bridge_->RunTasksUntilIdle();
ASSERT_FALSE(connection);
diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver.cc
index 17babae038b..eb60e0bb085 100644
--- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver.cc
@@ -16,8 +16,6 @@
#include "util/logging.h"
#include "util/trace_logging.h"
-using openscreen::platform::TraceCategory;
-
namespace openscreen {
namespace osp {
namespace {
@@ -107,11 +105,11 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id,
msgs::Type message_type,
const uint8_t* buffer,
size_t buffer_size,
- platform::Clock::time_point now) {
- TRACE_SCOPED(TraceCategory::Presentation, "Receiver::OnStreamMessage");
+ Clock::time_point now) {
+ TRACE_SCOPED(TraceCategory::kPresentation, "Receiver::OnStreamMessage");
switch (message_type) {
case msgs::Type::kPresentationUrlAvailabilityRequest: {
- TRACE_SCOPED(TraceCategory::Presentation,
+ TRACE_SCOPED(TraceCategory::kPresentation,
"kPresentationUrlAvailabilityRequest");
OSP_VLOG << "got presentation-url-availability-request";
msgs::PresentationUrlAvailabilityRequest request;
@@ -137,7 +135,7 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id,
}
case msgs::Type::kPresentationStartRequest: {
- TRACE_SCOPED(TraceCategory::Presentation, "kPresentationStartRequest");
+ TRACE_SCOPED(TraceCategory::kPresentation, "kPresentationStartRequest");
OSP_VLOG << "got presentation-start-request";
msgs::PresentationStartRequest request;
const ssize_t result =
@@ -198,7 +196,7 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id,
}
case msgs::Type::kPresentationConnectionOpenRequest: {
- TRACE_SCOPED(TraceCategory::Presentation,
+ TRACE_SCOPED(TraceCategory::kPresentation,
"kPresentationConnectionOpenRequest");
OSP_VLOG << "Got a presentation-connection-open-request";
msgs::PresentationConnectionOpenRequest request;
@@ -266,7 +264,7 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id,
}
case msgs::Type::kPresentationTerminationRequest: {
- TRACE_SCOPED(TraceCategory::Presentation,
+ TRACE_SCOPED(TraceCategory::kPresentation,
"kPresentationTerminationRequest");
OSP_VLOG << "got presentation-termination-request";
msgs::PresentationTerminationRequest request;
diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver_unittest.cc
index 098df0d3436..173b9fd561a 100644
--- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver_unittest.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver_unittest.cc
@@ -69,12 +69,11 @@ class MockReceiverDelegate final : public ReceiverDelegate {
class PresentationReceiverTest : public ::testing::Test {
public:
PresentationReceiverTest() {
- fake_clock_ = std::make_unique<platform::FakeClock>(
- platform::Clock::time_point(std::chrono::milliseconds(1298424)));
- task_runner_ =
- std::make_unique<platform::FakeTaskRunner>(fake_clock_.get());
- quic_bridge_ = std::make_unique<FakeQuicBridge>(task_runner_.get(),
- platform::FakeClock::now);
+ fake_clock_ = std::make_unique<FakeClock>(
+ Clock::time_point(std::chrono::milliseconds(1298424)));
+ task_runner_ = std::make_unique<FakeTaskRunner>(fake_clock_.get());
+ quic_bridge_ =
+ std::make_unique<FakeQuicBridge>(task_runner_.get(), FakeClock::now);
}
protected:
@@ -103,8 +102,8 @@ class PresentationReceiverTest : public ::testing::Test {
NetworkServiceManager::Dispose();
}
- std::unique_ptr<platform::FakeClock> fake_clock_;
- std::unique_ptr<platform::FakeTaskRunner> task_runner_;
+ std::unique_ptr<FakeClock> fake_clock_;
+ std::unique_ptr<FakeTaskRunner> task_runner_;
const std::string url1_{"https://www.example.com/receiver.html"};
std::unique_ptr<FakeQuicBridge> quic_bridge_;
MockReceiverDelegate mock_receiver_delegate_;
@@ -142,14 +141,14 @@ TEST_F(PresentationReceiverTest, QueryAvailability) {
msgs::PresentationUrlAvailabilityResponse response;
EXPECT_CALL(mock_callback, OnStreamMessage(_, _, _, _, _, _))
- .WillOnce(Invoke([&response](uint64_t endpoint_id, uint64_t cid,
- msgs::Type message_type,
- const uint8_t* buffer, size_t buffer_size,
- platform::Clock::time_point now) {
- ssize_t result = msgs::DecodePresentationUrlAvailabilityResponse(
- buffer, buffer_size, &response);
- return result;
- }));
+ .WillOnce(
+ Invoke([&response](uint64_t endpoint_id, uint64_t cid,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ ssize_t result = msgs::DecodePresentationUrlAvailabilityResponse(
+ buffer, buffer_size, &response);
+ return result;
+ }));
quic_bridge_->RunTasksUntilIdle();
EXPECT_EQ(request.request_id, response.request_id);
EXPECT_EQ(
@@ -190,14 +189,14 @@ TEST_F(PresentationReceiverTest, StartPresentation) {
ResponseResult::kSuccess);
msgs::PresentationStartResponse response;
EXPECT_CALL(mock_callback, OnStreamMessage(_, _, _, _, _, _))
- .WillOnce(Invoke([&response](uint64_t endpoint_id, uint64_t cid,
- msgs::Type message_type,
- const uint8_t* buffer, size_t buffer_size,
- platform::Clock::time_point now) {
- ssize_t result = msgs::DecodePresentationStartResponse(
- buffer, buffer_size, &response);
- return result;
- }));
+ .WillOnce(
+ Invoke([&response](uint64_t endpoint_id, uint64_t cid,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ ssize_t result = msgs::DecodePresentationStartResponse(
+ buffer, buffer_size, &response);
+ return result;
+ }));
quic_bridge_->RunTasksUntilIdle();
EXPECT_EQ(msgs::Result::kSuccess, response.result);
EXPECT_EQ(connection.connection_id(), response.connection_id);
diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.cc
index 9213447c26e..b5e3bae7730 100644
--- a/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.cc
@@ -12,7 +12,6 @@
#include "osp/public/network_service_manager.h"
#include "util/logging.h"
-using openscreen::platform::Clock;
using std::chrono::seconds;
namespace openscreen {
@@ -48,7 +47,7 @@ uint64_t GetNextRequestId(const uint64_t endpoint_id) {
} // namespace
UrlAvailabilityRequester::UrlAvailabilityRequester(
- platform::ClockNowFunctionPtr now_function)
+ ClockNowFunctionPtr now_function)
: now_function_(now_function) {
OSP_DCHECK(now_function_);
}
diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.h b/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.h
index 82f5bb6bca6..dd0309a5970 100644
--- a/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.h
+++ b/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester.h
@@ -29,7 +29,7 @@ namespace osp {
// given URL.
class UrlAvailabilityRequester {
public:
- explicit UrlAvailabilityRequester(platform::ClockNowFunctionPtr now_function);
+ explicit UrlAvailabilityRequester(ClockNowFunctionPtr now_function);
~UrlAvailabilityRequester();
// Adds a persistent availability request for |urls| to all known receivers.
@@ -63,7 +63,7 @@ class UrlAvailabilityRequester {
// Ensures that all open availability watches (to all receivers) that are
// about to expire are refreshed by sending a new request with the same URLs.
// Returns the time point at which this should next be scheduled to run.
- platform::Clock::time_point RefreshWatches();
+ Clock::time_point RefreshWatches();
private:
// Handles Presentation API URL availability requests and watches for one
@@ -82,7 +82,7 @@ class UrlAvailabilityRequester {
};
struct Watch {
- platform::Clock::time_point deadline;
+ Clock::time_point deadline;
std::vector<std::string> urls;
};
@@ -97,7 +97,7 @@ class UrlAvailabilityRequester {
void RequestUrlAvailabilities(std::vector<std::string> urls);
ErrorOr<uint64_t> SendRequest(uint64_t request_id,
const std::vector<std::string>& urls);
- platform::Clock::time_point RefreshWatches(platform::Clock::time_point now);
+ Clock::time_point RefreshWatches(Clock::time_point now);
Error::Code UpdateAvailabilities(
const std::vector<std::string>& urls,
const std::vector<msgs::UrlAvailability>& availabilities);
@@ -117,7 +117,7 @@ class UrlAvailabilityRequester {
msgs::Type message_type,
const uint8_t* buffer,
size_t buffer_size,
- platform::Clock::time_point now) override;
+ Clock::time_point now) override;
UrlAvailabilityRequester* const listener;
@@ -138,7 +138,7 @@ class UrlAvailabilityRequester {
std::map<std::string, msgs::UrlAvailability> known_availability_by_url;
};
- const platform::ClockNowFunctionPtr now_function_;
+ const ClockNowFunctionPtr now_function_;
std::map<std::string, std::vector<ReceiverObserver*>> observers_by_url_;
diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester_unittest.cc
index a0e19639b84..578c861220a 100644
--- a/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester_unittest.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester_unittest.cc
@@ -46,12 +46,11 @@ class MockReceiverObserver : public ReceiverObserver {
class UrlAvailabilityRequesterTest : public Test {
public:
UrlAvailabilityRequesterTest() {
- fake_clock_ = std::make_unique<platform::FakeClock>(
- platform::Clock::time_point(milliseconds(1298424)));
- task_runner_ =
- std::make_unique<platform::FakeTaskRunner>(fake_clock_.get());
- quic_bridge_ = std::make_unique<FakeQuicBridge>(task_runner_.get(),
- platform::FakeClock::now);
+ fake_clock_ =
+ std::make_unique<FakeClock>(Clock::time_point(milliseconds(1298424)));
+ task_runner_ = std::make_unique<FakeTaskRunner>(fake_clock_.get());
+ quic_bridge_ =
+ std::make_unique<FakeQuicBridge>(task_runner_.get(), FakeClock::now);
info1_ = {service_id_, friendly_name_, 1, quic_bridge_->kReceiverEndpoint};
}
@@ -86,16 +85,16 @@ class UrlAvailabilityRequesterTest : public Test {
void ExpectStreamMessage(MockMessageCallback* mock_callback,
msgs::PresentationUrlAvailabilityRequest* request) {
EXPECT_CALL(*mock_callback, OnStreamMessage(_, _, _, _, _, _))
- .WillOnce(Invoke([request](uint64_t endpoint_id, uint64_t cid,
- msgs::Type message_type,
- const uint8_t* buffer, size_t buffer_size,
- platform::Clock::time_point now) {
- ssize_t request_result_size =
- msgs::DecodePresentationUrlAvailabilityRequest(
- buffer, buffer_size, request);
- OSP_DCHECK_GT(request_result_size, 0);
- return request_result_size;
- }));
+ .WillOnce(
+ Invoke([request](uint64_t endpoint_id, uint64_t cid,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ ssize_t request_result_size =
+ msgs::DecodePresentationUrlAvailabilityRequest(
+ buffer, buffer_size, request);
+ OSP_DCHECK_GT(request_result_size, 0);
+ return request_result_size;
+ }));
}
void SendAvailabilityResponse(
@@ -126,12 +125,12 @@ class UrlAvailabilityRequesterTest : public Test {
stream->Write(buffer.data(), buffer.size());
}
- std::unique_ptr<platform::FakeClock> fake_clock_;
- std::unique_ptr<platform::FakeTaskRunner> task_runner_;
+ std::unique_ptr<FakeClock> fake_clock_;
+ std::unique_ptr<FakeTaskRunner> task_runner_;
MockMessageCallback mock_callback_;
MessageDemuxer::MessageWatch availability_watch_;
std::unique_ptr<FakeQuicBridge> quic_bridge_;
- UrlAvailabilityRequester listener_{platform::FakeClock::now};
+ UrlAvailabilityRequester listener_{FakeClock::now};
std::string url1_{"https://example.com/foo.html"};
std::string url2_{"https://example.com/bar.html"};
diff --git a/chromium/third_party/openscreen/src/osp/impl/protocol_connection_client_factory.cc b/chromium/third_party/openscreen/src/osp/impl/protocol_connection_client_factory.cc
index 90b14590396..c4a72b7970b 100644
--- a/chromium/third_party/openscreen/src/osp/impl/protocol_connection_client_factory.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/protocol_connection_client_factory.cc
@@ -20,10 +20,10 @@ std::unique_ptr<ProtocolConnectionClient>
ProtocolConnectionClientFactory::Create(
MessageDemuxer* demuxer,
ProtocolConnectionServiceObserver* observer,
- platform::TaskRunner* task_runner) {
+ TaskRunner* task_runner) {
return std::make_unique<QuicClient>(
demuxer, std::make_unique<QuicConnectionFactoryImpl>(task_runner),
- observer, &platform::Clock::now, task_runner);
+ observer, &Clock::now, task_runner);
}
} // namespace osp
diff --git a/chromium/third_party/openscreen/src/osp/impl/protocol_connection_server_factory.cc b/chromium/third_party/openscreen/src/osp/impl/protocol_connection_server_factory.cc
index e0a3b03455d..2984429c5b8 100644
--- a/chromium/third_party/openscreen/src/osp/impl/protocol_connection_server_factory.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/protocol_connection_server_factory.cc
@@ -21,10 +21,10 @@ ProtocolConnectionServerFactory::Create(
const ServerConfig& config,
MessageDemuxer* demuxer,
ProtocolConnectionServer::Observer* observer,
- platform::TaskRunner* task_runner) {
+ TaskRunner* task_runner) {
return std::make_unique<QuicServer>(
config, demuxer, std::make_unique<QuicConnectionFactoryImpl>(task_runner),
- observer, &platform::Clock::now, task_runner);
+ observer, &Clock::now, task_runner);
}
} // namespace osp
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/BUILD.gn b/chromium/third_party/openscreen/src/osp/impl/quic/BUILD.gn
index 9b0c80288ec..221af394bdb 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/BUILD.gn
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/BUILD.gn
@@ -16,6 +16,7 @@ source_set("quic") {
deps = [
"../../../platform",
+ "../../../third_party/abseil",
"../../../util",
"../../public",
]
@@ -34,6 +35,7 @@ source_set("test_support") {
deps = [
"../../../platform",
+ "../../../third_party/abseil",
"../../../third_party/googletest:gmock",
"../../../util",
"../../msgs",
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.cc
index 41303ad7629..7d0a310301e 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.cc
@@ -19,8 +19,8 @@ QuicClient::QuicClient(
MessageDemuxer* demuxer,
std::unique_ptr<QuicConnectionFactory> connection_factory,
ProtocolConnectionServiceObserver* observer,
- platform::ClockNowFunctionPtr now_function,
- platform::TaskRunner* task_runner)
+ ClockNowFunctionPtr now_function,
+ TaskRunner* task_runner)
: ProtocolConnectionClient(demuxer, observer),
connection_factory_(std::move(connection_factory)),
cleanup_alarm_(now_function, task_runner) {}
@@ -63,8 +63,7 @@ void QuicClient::Cleanup() {
}
delete_connections_.clear();
- constexpr platform::Clock::duration kQuicCleanupPeriod =
- std::chrono::milliseconds(500);
+ constexpr Clock::duration kQuicCleanupPeriod = std::chrono::milliseconds(500);
if (state_ != State::kStopped) {
cleanup_alarm_.ScheduleFromNow([this] { Cleanup(); }, kQuicCleanupPeriod);
}
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.h
index 50fb5503971..548cbff51f5 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.h
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.h
@@ -43,8 +43,8 @@ class QuicClient final : public ProtocolConnectionClient,
QuicClient(MessageDemuxer* demuxer,
std::unique_ptr<QuicConnectionFactory> connection_factory,
ProtocolConnectionServiceObserver* observer,
- platform::ClockNowFunctionPtr now_function,
- platform::TaskRunner* task_runner);
+ ClockNowFunctionPtr now_function,
+ TaskRunner* task_runner);
~QuicClient() override;
// ProtocolConnectionClient overrides.
@@ -103,7 +103,7 @@ class QuicClient final : public ProtocolConnectionClient,
// Maps an IPEndpoint to a generated endpoint ID. This is used to insulate
// callers from post-handshake changes to a connections actual peer endpoint.
- std::map<IPEndpoint, uint64_t, IPEndpointComparator> endpoint_map_;
+ std::map<IPEndpoint, uint64_t> endpoint_map_;
// Value that will be used for the next new endpoint in a Connect call.
uint64_t next_endpoint_id_ = 0;
@@ -119,8 +119,7 @@ class QuicClient final : public ProtocolConnectionClient,
// Maps endpoint addresses to data about connections that haven't successfully
// completed the QUIC handshake.
- std::map<IPEndpoint, PendingConnectionData, IPEndpointComparator>
- pending_connections_;
+ std::map<IPEndpoint, PendingConnectionData> pending_connections_;
// Maps endpoint IDs to data about connections that have successfully
// completed the QUIC handshake.
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client_unittest.cc
index 954fa11ce5d..71a74ab4165 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client_unittest.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client_unittest.cc
@@ -60,12 +60,11 @@ class ConnectionCallback final
class QuicClientTest : public ::testing::Test {
public:
QuicClientTest() {
- fake_clock_ = std::make_unique<platform::FakeClock>(
- platform::Clock::time_point(std::chrono::milliseconds(1298424)));
- task_runner_ =
- std::make_unique<platform::FakeTaskRunner>(fake_clock_.get());
- quic_bridge_ = std::make_unique<FakeQuicBridge>(task_runner_.get(),
- platform::FakeClock::now);
+ fake_clock_ = std::make_unique<FakeClock>(
+ Clock::time_point(std::chrono::milliseconds(1298424)));
+ task_runner_ = std::make_unique<FakeTaskRunner>(fake_clock_.get());
+ quic_bridge_ =
+ std::make_unique<FakeQuicBridge>(task_runner_.get(), FakeClock::now);
}
protected:
@@ -100,17 +99,16 @@ class QuicClientTest : public ::testing::Test {
mock_message_callback,
OnStreamMessage(0, connection->id(),
msgs::Type::kPresentationConnectionMessage, _, _, _))
- .WillOnce(
- Invoke([&decode_result, &received_message](
- uint64_t endpoint_id, uint64_t connection_id,
- msgs::Type message_type, const uint8_t* buffer,
- size_t buffer_size, platform::Clock::time_point now) {
- decode_result = msgs::DecodePresentationConnectionMessage(
- buffer, buffer_size, &received_message);
- if (decode_result < 0)
- return ErrorOr<size_t>(Error::Code::kCborParsing);
- return ErrorOr<size_t>(decode_result);
- }));
+ .WillOnce(Invoke([&decode_result, &received_message](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ decode_result = msgs::DecodePresentationConnectionMessage(
+ buffer, buffer_size, &received_message);
+ if (decode_result < 0)
+ return ErrorOr<size_t>(Error::Code::kCborParsing);
+ return ErrorOr<size_t>(decode_result);
+ }));
quic_bridge_->RunTasksUntilIdle();
ASSERT_GT(decode_result, 0);
@@ -121,8 +119,8 @@ class QuicClientTest : public ::testing::Test {
EXPECT_EQ(received_message.message.str, message.message.str);
}
- std::unique_ptr<platform::FakeClock> fake_clock_;
- std::unique_ptr<platform::FakeTaskRunner> task_runner_;
+ std::unique_ptr<FakeClock> fake_clock_;
+ std::unique_ptr<FakeTaskRunner> task_runner_;
std::unique_ptr<FakeQuicBridge> quic_bridge_;
QuicClient* client_;
};
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection.h
index 2f1cb9360cf..e00e25a044d 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection.h
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection.h
@@ -37,7 +37,7 @@ class QuicStream {
uint64_t id_;
};
-class QuicConnection : public platform::UdpSocket::Client {
+class QuicConnection : public UdpSocket::Client {
public:
class Delegate {
public:
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory.h
index 2e3e10b38d0..f396419d4d6 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory.h
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory.h
@@ -18,7 +18,7 @@ namespace osp {
// This interface provides a way to make new QUIC connections to endpoints. It
// also provides a way to receive incoming QUIC connections (as a server).
-class QuicConnectionFactory : public platform::UdpSocket::Client {
+class QuicConnectionFactory : public UdpSocket::Client {
public:
class ServerDelegate {
public:
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.cc
index 0ecc17e2aea..6514b79119c 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.cc
@@ -18,13 +18,12 @@
#include "util/logging.h"
#include "util/trace_logging.h"
-using openscreen::platform::TraceCategory;
-
namespace openscreen {
namespace osp {
+
class QuicTaskRunner final : public ::base::TaskRunner {
public:
- explicit QuicTaskRunner(platform::TaskRunner* task_runner);
+ explicit QuicTaskRunner(openscreen::TaskRunner* task_runner);
~QuicTaskRunner() override;
void RunTasks();
@@ -37,11 +36,12 @@ class QuicTaskRunner final : public ::base::TaskRunner {
bool RunsTasksInCurrentSequence() const override;
private:
- platform::TaskRunner* const task_runner_;
+ openscreen::TaskRunner* const task_runner_;
};
-QuicTaskRunner::QuicTaskRunner(platform::TaskRunner* task_runner)
+QuicTaskRunner::QuicTaskRunner(openscreen::TaskRunner* task_runner)
: task_runner_(task_runner) {}
+
QuicTaskRunner::~QuicTaskRunner() = default;
void QuicTaskRunner::RunTasks() {}
@@ -49,8 +49,7 @@ void QuicTaskRunner::RunTasks() {}
bool QuicTaskRunner::PostDelayedTask(const ::base::Location& whence,
::base::OnceClosure task,
::base::TimeDelta delay) {
- platform::Clock::duration wait =
- platform::Clock::duration(delay.InMilliseconds());
+ Clock::duration wait = Clock::duration(delay.InMilliseconds());
task_runner_->PostTaskWithDelay(
[closure = std::move(task)]() mutable { std::move(closure).Run(); },
wait);
@@ -61,8 +60,7 @@ bool QuicTaskRunner::RunsTasksInCurrentSequence() const {
return true;
}
-QuicConnectionFactoryImpl::QuicConnectionFactoryImpl(
- platform::TaskRunner* task_runner)
+QuicConnectionFactoryImpl::QuicConnectionFactoryImpl(TaskRunner* task_runner)
: task_runner_(task_runner) {
quic_task_runner_ = ::base::MakeRefCounted<QuicTaskRunner>(task_runner);
alarm_factory_ = std::make_unique<::net::QuicChromiumAlarmFactory>(
@@ -90,34 +88,31 @@ void QuicConnectionFactoryImpl::SetServerDelegate(
// create/bind errors occur. Maybe return an Error immediately, and undo
// partial progress (i.e. "unwatch" all the sockets and call
// sockets_.clear() to close the sockets)?
- auto create_result =
- platform::UdpSocket::Create(task_runner_, this, endpoint);
+ auto create_result = UdpSocket::Create(task_runner_, this, endpoint);
if (!create_result) {
OSP_LOG_ERROR << "failed to create socket (for " << endpoint
<< "): " << create_result.error().message();
continue;
}
- platform::UdpSocketUniquePtr server_socket =
- std::move(create_result.value());
+ std::unique_ptr<UdpSocket> server_socket = std::move(create_result.value());
server_socket->Bind();
sockets_.emplace_back(std::move(server_socket));
}
}
-void QuicConnectionFactoryImpl::OnRead(
- platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> packet_or_error) {
- TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionFactoryImpl::OnRead");
+void QuicConnectionFactoryImpl::OnRead(UdpSocket* socket,
+ ErrorOr<UdpPacket> packet_or_error) {
+ TRACE_SCOPED(TraceCategory::kQuic, "QuicConnectionFactoryImpl::OnRead");
if (packet_or_error.is_error()) {
return;
}
- platform::UdpPacket packet = std::move(packet_or_error.value());
+ UdpPacket packet = std::move(packet_or_error.value());
// Ensure that |packet.socket| is one of the instances owned by
// QuicConnectionFactoryImpl.
auto packet_ptr = &packet;
OSP_DCHECK(std::find_if(sockets_.begin(), sockets_.end(),
- [packet_ptr](const platform::UdpSocketUniquePtr& s) {
+ [packet_ptr](const std::unique_ptr<UdpSocket>& s) {
return s.get() == packet_ptr->socket();
}) != sockets_.end());
@@ -153,15 +148,14 @@ void QuicConnectionFactoryImpl::OnRead(
std::unique_ptr<QuicConnection> QuicConnectionFactoryImpl::Connect(
const IPEndpoint& endpoint,
QuicConnection::Delegate* connection_delegate) {
- auto create_result =
- platform::UdpSocket::Create(task_runner_, this, endpoint);
+ auto create_result = UdpSocket::Create(task_runner_, this, endpoint);
if (!create_result) {
OSP_LOG_ERROR << "failed to create socket: "
<< create_result.error().message();
// TODO(mfoltz): This method should return ErrorOr<uni_ptr<QuicConnection>>.
return nullptr;
}
- platform::UdpSocketUniquePtr socket = std::move(create_result.value());
+ std::unique_ptr<UdpSocket> socket = std::move(create_result.value());
auto transport = std::make_unique<UdpTransport>(socket.get(), endpoint);
::quic::QuartcSessionConfig session_config;
@@ -192,7 +186,7 @@ void QuicConnectionFactoryImpl::OnConnectionClosed(QuicConnection* connection) {
return entry.second.connection == connection;
});
OSP_DCHECK(entry != connections_.end());
- platform::UdpSocket* const socket = entry->second.socket;
+ UdpSocket* const socket = entry->second.socket;
connections_.erase(entry);
// If none of the remaining |connections_| reference the socket, close/destroy
@@ -203,7 +197,7 @@ void QuicConnectionFactoryImpl::OnConnectionClosed(QuicConnection* connection) {
}) == connections_.end()) {
auto socket_it =
std::find_if(sockets_.begin(), sockets_.end(),
- [socket](const platform::UdpSocketUniquePtr& s) {
+ [socket](const std::unique_ptr<UdpSocket>& s) {
return s.get() == socket;
});
OSP_DCHECK(socket_it != sockets_.end());
@@ -211,13 +205,11 @@ void QuicConnectionFactoryImpl::OnConnectionClosed(QuicConnection* connection) {
}
}
-void QuicConnectionFactoryImpl::OnError(platform::UdpSocket* socket,
- Error error) {
+void QuicConnectionFactoryImpl::OnError(UdpSocket* socket, Error error) {
OSP_LOG_ERROR << "failed to configure socket " << error.message();
}
-void QuicConnectionFactoryImpl::OnSendError(platform::UdpSocket* socket,
- Error error) {
+void QuicConnectionFactoryImpl::OnSendError(UdpSocket* socket, Error error) {
// TODO(crbug.com/openscreen/67): Implement this method.
OSP_UNIMPLEMENTED();
}
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.h
index 8e29579d3b6..e3588f6a2f0 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.h
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.h
@@ -22,14 +22,13 @@ class QuicTaskRunner;
class QuicConnectionFactoryImpl final : public QuicConnectionFactory {
public:
- QuicConnectionFactoryImpl(platform::TaskRunner* task_runner);
+ QuicConnectionFactoryImpl(TaskRunner* task_runner);
~QuicConnectionFactoryImpl() override;
// UdpSocket::Client overrides.
- void OnError(platform::UdpSocket* socket, Error error) override;
- void OnSendError(platform::UdpSocket* socket, Error error) override;
- void OnRead(platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> packet) override;
+ void OnError(UdpSocket* socket, Error error) override;
+ void OnSendError(UdpSocket* socket, Error error) override;
+ void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override;
// QuicConnectionFactory overrides.
void SetServerDelegate(ServerDelegate* delegate,
@@ -48,18 +47,18 @@ class QuicConnectionFactoryImpl final : public QuicConnectionFactory {
ServerDelegate* server_delegate_ = nullptr;
- std::vector<platform::UdpSocketUniquePtr> sockets_;
+ std::vector<std::unique_ptr<UdpSocket>> sockets_;
struct OpenConnection {
QuicConnection* connection;
- platform::UdpSocket* socket; // References one of the owned |sockets_|.
+ UdpSocket* socket; // References one of the owned |sockets_|.
};
- std::map<IPEndpoint, OpenConnection, IPEndpointComparator> connections_;
+ std::map<IPEndpoint, OpenConnection> connections_;
// NOTE: Must be provided in constructor and stored as an instance variable
// rather than using the static accessor method to allow for UTs to mock this
// layer.
- platform::TaskRunner* const task_runner_;
+ TaskRunner* const task_runner_;
};
} // namespace osp
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.cc
index 8ee43c332b8..778d6b4de82 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.cc
@@ -14,13 +14,10 @@
#include "util/logging.h"
#include "util/trace_logging.h"
-using openscreen::platform::TraceCategory;
-
namespace openscreen {
namespace osp {
-UdpTransport::UdpTransport(platform::UdpSocket* socket,
- const IPEndpoint& destination)
+UdpTransport::UdpTransport(UdpSocket* socket, const IPEndpoint& destination)
: socket_(socket), destination_(destination) {
OSP_DCHECK(socket_);
}
@@ -33,7 +30,7 @@ UdpTransport& UdpTransport::operator=(UdpTransport&&) noexcept = default;
int UdpTransport::Write(const char* buffer,
size_t buffer_length,
const PacketInfo& info) {
- TRACE_SCOPED(TraceCategory::Quic, "UdpTransport::Write");
+ TRACE_SCOPED(TraceCategory::kQuic, "UdpTransport::Write");
socket_->SendMessage(buffer, buffer_length, destination_);
OSP_DCHECK_LE(buffer_length,
static_cast<size_t>(std::numeric_limits<int>::max()));
@@ -49,7 +46,7 @@ QuicStreamImpl::QuicStreamImpl(QuicStream::Delegate* delegate,
QuicStreamImpl::~QuicStreamImpl() = default;
void QuicStreamImpl::Write(const uint8_t* data, size_t data_size) {
- TRACE_SCOPED(TraceCategory::Quic, "QuicStreamImpl::Write");
+ TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::Write");
OSP_DCHECK(!stream_->write_side_closed());
stream_->WriteOrBufferData(
::quic::QuicStringPiece(reinterpret_cast<const char*>(data), data_size),
@@ -57,7 +54,7 @@ void QuicStreamImpl::Write(const uint8_t* data, size_t data_size) {
}
void QuicStreamImpl::CloseWriteEnd() {
- TRACE_SCOPED(TraceCategory::Quic, "QuicStreamImpl::CloseWriteEnd");
+ TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::CloseWriteEnd");
if (!stream_->write_side_closed())
stream_->FinishWriting();
}
@@ -65,12 +62,12 @@ void QuicStreamImpl::CloseWriteEnd() {
void QuicStreamImpl::OnReceived(::quic::QuartcStream* stream,
const char* data,
size_t data_size) {
- TRACE_SCOPED(TraceCategory::Quic, "QuicStreamImpl::OnReceived");
+ TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::OnReceived");
delegate_->OnReceived(this, data, data_size);
}
void QuicStreamImpl::OnClose(::quic::QuartcStream* stream) {
- TRACE_SCOPED(TraceCategory::Quic, "QuicStreamImpl::OnClose");
+ TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::OnClose");
delegate_->OnClose(stream->id());
}
@@ -79,9 +76,8 @@ void QuicStreamImpl::OnBufferChanged(::quic::QuartcStream* stream) {}
// Passes a received UDP packet to the QUIC implementation. If this contains
// any stream data, it will be passed automatically to the relevant
// QuicStream::Delegate objects.
-void QuicConnectionImpl::OnRead(platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> data) {
- TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::OnRead");
+void QuicConnectionImpl::OnRead(UdpSocket* socket, ErrorOr<UdpPacket> data) {
+ TRACE_SCOPED(TraceCategory::kQuic, "QuicConnectionImpl::OnRead");
if (data.is_error()) {
TRACE_SET_RESULT(data.error());
return;
@@ -91,12 +87,12 @@ void QuicConnectionImpl::OnRead(platform::UdpSocket* socket,
reinterpret_cast<const char*>(data.value().data()), data.value().size());
}
-void QuicConnectionImpl::OnSendError(platform::UdpSocket* socket, Error error) {
+void QuicConnectionImpl::OnSendError(UdpSocket* socket, Error error) {
// TODO(crbug.com/openscreen/67): Implement this method.
OSP_UNIMPLEMENTED();
}
-void QuicConnectionImpl::OnError(platform::UdpSocket* socket, Error error) {
+void QuicConnectionImpl::OnError(UdpSocket* socket, Error error) {
// TODO(crbug.com/openscreen/67): Implement this method.
OSP_UNIMPLEMENTED();
}
@@ -110,7 +106,7 @@ QuicConnectionImpl::QuicConnectionImpl(
parent_factory_(parent_factory),
session_(std::move(session)),
udp_transport_(std::move(udp_transport)) {
- TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::QuicConnectionImpl");
+ TRACE_SCOPED(TraceCategory::kQuic, "QuicConnectionImpl::QuicConnectionImpl");
session_->SetDelegate(this);
session_->OnTransportCanWrite();
session_->StartCryptoHandshake();
@@ -120,24 +116,24 @@ QuicConnectionImpl::~QuicConnectionImpl() = default;
std::unique_ptr<QuicStream> QuicConnectionImpl::MakeOutgoingStream(
QuicStream::Delegate* delegate) {
- TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::MakeOutgoingStream");
+ TRACE_SCOPED(TraceCategory::kQuic, "QuicConnectionImpl::MakeOutgoingStream");
::quic::QuartcStream* stream = session_->CreateOutgoingDynamicStream();
return std::make_unique<QuicStreamImpl>(delegate, stream);
}
void QuicConnectionImpl::Close() {
- TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::Close");
+ TRACE_SCOPED(TraceCategory::kQuic, "QuicConnectionImpl::Close");
session_->CloseConnection("closed");
}
void QuicConnectionImpl::OnCryptoHandshakeComplete() {
- TRACE_SCOPED(TraceCategory::Quic,
+ TRACE_SCOPED(TraceCategory::kQuic,
"QuicConnectionImpl::OnCryptoHandshakeComplete");
delegate_->OnCryptoHandshakeComplete(session_->connection_id());
}
void QuicConnectionImpl::OnIncomingStream(::quic::QuartcStream* stream) {
- TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::OnIncomingStream");
+ TRACE_SCOPED(TraceCategory::kQuic, "QuicConnectionImpl::OnIncomingStream");
auto public_stream = std::make_unique<QuicStreamImpl>(
delegate_->NextStreamDelegate(session_->connection_id(), stream->id()),
stream);
@@ -150,7 +146,7 @@ void QuicConnectionImpl::OnConnectionClosed(
::quic::QuicErrorCode error_code,
const ::quic::QuicString& error_details,
::quic::ConnectionCloseSource source) {
- TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::OnConnectionClosed");
+ TRACE_SCOPED(TraceCategory::kQuic, "QuicConnectionImpl::OnConnectionClosed");
parent_factory_->OnConnectionClosed(this);
delegate_->OnConnectionClosed(session_->connection_id());
}
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.h
index c7697ec117b..e2609deb694 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.h
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.h
@@ -26,7 +26,7 @@ class QuicConnectionFactoryImpl;
class UdpTransport final : public ::quic::QuartcPacketTransport {
public:
- UdpTransport(platform::UdpSocket* socket, const IPEndpoint& destination);
+ UdpTransport(UdpSocket* socket, const IPEndpoint& destination);
UdpTransport(UdpTransport&&) noexcept;
~UdpTransport() override;
@@ -37,10 +37,10 @@ class UdpTransport final : public ::quic::QuartcPacketTransport {
size_t buffer_length,
const PacketInfo& info) override;
- platform::UdpSocket* socket() const { return socket_; }
+ UdpSocket* socket() const { return socket_; }
private:
- platform::UdpSocket* socket_;
+ UdpSocket* socket_;
IPEndpoint destination_;
};
@@ -76,10 +76,9 @@ class QuicConnectionImpl final : public QuicConnection,
~QuicConnectionImpl() override;
// UdpSocket::Client overrides.
- void OnRead(platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> data) override;
- void OnError(platform::UdpSocket* socket, Error error) override;
- void OnSendError(platform::UdpSocket* socket, Error error) override;
+ void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> data) override;
+ void OnError(UdpSocket* socket, Error error) override;
+ void OnSendError(UdpSocket* socket, Error error) override;
// QuicConnection overrides.
std::unique_ptr<QuicStream> MakeOutgoingStream(
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.cc
index 17a90973f94..e1afc58c712 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.cc
@@ -19,8 +19,8 @@ QuicServer::QuicServer(
MessageDemuxer* demuxer,
std::unique_ptr<QuicConnectionFactory> connection_factory,
ProtocolConnectionServer::Observer* observer,
- platform::ClockNowFunctionPtr now_function,
- platform::TaskRunner* task_runner)
+ ClockNowFunctionPtr now_function,
+ TaskRunner* task_runner)
: ProtocolConnectionServer(demuxer, observer),
connection_endpoints_(config.connection_endpoints),
connection_factory_(std::move(connection_factory)),
@@ -80,8 +80,7 @@ void QuicServer::Cleanup() {
}
delete_connections_.clear();
- constexpr platform::Clock::duration kQuicCleanupPeriod =
- std::chrono::milliseconds(500);
+ constexpr Clock::duration kQuicCleanupPeriod = std::chrono::milliseconds(500);
if (state_ != State::kStopped) {
cleanup_alarm_.ScheduleFromNow([this] { Cleanup(); }, kQuicCleanupPeriod);
}
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.h
index 48652d53b16..ad195635461 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.h
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.h
@@ -37,8 +37,8 @@ class QuicServer final : public ProtocolConnectionServer,
MessageDemuxer* demuxer,
std::unique_ptr<QuicConnectionFactory> connection_factory,
ProtocolConnectionServer::Observer* observer,
- platform::ClockNowFunctionPtr now_function,
- platform::TaskRunner* task_runner);
+ ClockNowFunctionPtr now_function,
+ TaskRunner* task_runner);
~QuicServer() override;
// ProtocolConnectionServer overrides.
@@ -84,15 +84,14 @@ class QuicServer final : public ProtocolConnectionServer,
// Maps an IPEndpoint to a generated endpoint ID. This is used to insulate
// callers from post-handshake changes to a connections actual peer endpoint.
- std::map<IPEndpoint, uint64_t, IPEndpointComparator> endpoint_map_;
+ std::map<IPEndpoint, uint64_t> endpoint_map_;
// Value that will be used for the next new endpoint in a Connect call.
uint64_t next_endpoint_id_ = 0;
// Maps endpoint addresses to data about connections that haven't successfully
// completed the QUIC handshake.
- std::map<IPEndpoint, ServiceConnectionData, IPEndpointComparator>
- pending_connections_;
+ std::map<IPEndpoint, ServiceConnectionData> pending_connections_;
// Maps endpoint IDs to data about connections that have successfully
// completed the QUIC handshake.
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server_unittest.cc
index ed1c33627a8..12ad8c33b06 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server_unittest.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server_unittest.cc
@@ -49,12 +49,11 @@ class MockConnectionObserver final : public ProtocolConnection::Observer {
class QuicServerTest : public Test {
public:
QuicServerTest() {
- fake_clock_ = std::make_unique<platform::FakeClock>(
- platform::Clock::time_point(std::chrono::milliseconds(1298424)));
- task_runner_ =
- std::make_unique<platform::FakeTaskRunner>(fake_clock_.get());
- quic_bridge_ = std::make_unique<FakeQuicBridge>(task_runner_.get(),
- platform::FakeClock::now);
+ fake_clock_ = std::make_unique<FakeClock>(
+ Clock::time_point(std::chrono::milliseconds(1298424)));
+ task_runner_ = std::make_unique<FakeTaskRunner>(fake_clock_.get());
+ quic_bridge_ =
+ std::make_unique<FakeQuicBridge>(task_runner_.get(), FakeClock::now);
}
protected:
@@ -103,17 +102,16 @@ class QuicServerTest : public Test {
EXPECT_CALL(mock_message_callback,
OnStreamMessage(
0, _, msgs::Type::kPresentationConnectionMessage, _, _, _))
- .WillOnce(
- Invoke([&decode_result, &received_message](
- uint64_t endpoint_id, uint64_t connection_id,
- msgs::Type message_type, const uint8_t* buffer,
- size_t buffer_size, platform::Clock::time_point now) {
- decode_result = msgs::DecodePresentationConnectionMessage(
- buffer, buffer_size, &received_message);
- if (decode_result < 0)
- return ErrorOr<size_t>(Error::Code::kCborParsing);
- return ErrorOr<size_t>(decode_result);
- }));
+ .WillOnce(Invoke([&decode_result, &received_message](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ decode_result = msgs::DecodePresentationConnectionMessage(
+ buffer, buffer_size, &received_message);
+ if (decode_result < 0)
+ return ErrorOr<size_t>(Error::Code::kCborParsing);
+ return ErrorOr<size_t>(decode_result);
+ }));
quic_bridge_->RunTasksUntilIdle();
ASSERT_GT(decode_result, 0);
@@ -124,8 +122,8 @@ class QuicServerTest : public Test {
EXPECT_EQ(received_message.message.str, message.message.str);
}
- std::unique_ptr<platform::FakeClock> fake_clock_;
- std::unique_ptr<platform::FakeTaskRunner> task_runner_;
+ std::unique_ptr<FakeClock> fake_clock_;
+ std::unique_ptr<FakeTaskRunner> task_runner_;
std::unique_ptr<FakeQuicBridge> quic_bridge_;
QuicServer* server_;
};
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.cc b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.cc
index c04ef4331c3..52f7241cb3a 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.cc
@@ -61,16 +61,15 @@ std::unique_ptr<FakeQuicStream> FakeQuicConnection::MakeIncomingStream() {
return result;
}
-void FakeQuicConnection::OnRead(platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> data) {
+void FakeQuicConnection::OnRead(UdpSocket* socket, ErrorOr<UdpPacket> data) {
OSP_NOTREACHED() << "data should go directly to fake streams";
}
-void FakeQuicConnection::OnSendError(platform::UdpSocket* socket, Error error) {
+void FakeQuicConnection::OnSendError(UdpSocket* socket, Error error) {
OSP_NOTREACHED() << "data should go directly to fake streams";
}
-void FakeQuicConnection::OnError(platform::UdpSocket* socket, Error error) {
+void FakeQuicConnection::OnError(UdpSocket* socket, Error error) {
OSP_NOTREACHED() << "data should go directly to fake streams";
}
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.h b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.h
index 47943f0b8df..9b11d7b4d55 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.h
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.h
@@ -58,10 +58,9 @@ class FakeQuicConnection final : public QuicConnection {
std::unique_ptr<FakeQuicStream> MakeIncomingStream();
// UdpSocket::Client overrides.
- void OnRead(platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> data) override;
- void OnSendError(platform::UdpSocket* socket, Error error) override;
- void OnError(platform::UdpSocket* socket, Error error) override;
+ void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> data) override;
+ void OnSendError(UdpSocket* socket, Error error) override;
+ void OnError(UdpSocket* socket, Error error) override;
// QuicConnection overrides.
std::unique_ptr<QuicStream> MakeOutgoingStream(
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.cc b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.cc
index c68d256c6f6..47f3ab8757d 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.cc
@@ -167,19 +167,17 @@ std::unique_ptr<QuicConnection> FakeClientQuicConnectionFactory::Connect(
return bridge_->Connect(endpoint, connection_delegate);
}
-void FakeClientQuicConnectionFactory::OnError(platform::UdpSocket* socket,
- Error error) {
+void FakeClientQuicConnectionFactory::OnError(UdpSocket* socket, Error error) {
OSP_UNIMPLEMENTED();
}
-void FakeClientQuicConnectionFactory::OnSendError(platform::UdpSocket* socket,
+void FakeClientQuicConnectionFactory::OnSendError(UdpSocket* socket,
Error error) {
OSP_UNIMPLEMENTED();
}
-void FakeClientQuicConnectionFactory::OnRead(
- platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> packet) {
+void FakeClientQuicConnectionFactory::OnRead(UdpSocket* socket,
+ ErrorOr<UdpPacket> packet) {
bridge_->RunTasks(true);
idle_ = bridge_->client_idle();
}
@@ -207,19 +205,17 @@ std::unique_ptr<QuicConnection> FakeServerQuicConnectionFactory::Connect(
return nullptr;
}
-void FakeServerQuicConnectionFactory::OnError(platform::UdpSocket* socket,
- Error error) {
+void FakeServerQuicConnectionFactory::OnError(UdpSocket* socket, Error error) {
OSP_UNIMPLEMENTED();
}
-void FakeServerQuicConnectionFactory::OnSendError(platform::UdpSocket* socket,
+void FakeServerQuicConnectionFactory::OnSendError(UdpSocket* socket,
Error error) {
OSP_UNIMPLEMENTED();
}
-void FakeServerQuicConnectionFactory::OnRead(
- platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> packet) {
+void FakeServerQuicConnectionFactory::OnRead(UdpSocket* socket,
+ ErrorOr<UdpPacket> packet) {
bridge_->RunTasks(false);
idle_ = bridge_->server_idle();
}
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.h b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.h
index 128e6372241..2aac1145359 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.h
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.h
@@ -55,10 +55,9 @@ class FakeClientQuicConnectionFactory final : public QuicConnectionFactory {
~FakeClientQuicConnectionFactory() override;
// UdpSocket::Client overrides.
- void OnError(platform::UdpSocket* socket, Error error) override;
- void OnSendError(platform::UdpSocket* socket, Error error) override;
- void OnRead(platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> packet) override;
+ void OnError(UdpSocket* socket, Error error) override;
+ void OnSendError(UdpSocket* socket, Error error) override;
+ void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override;
// QuicConnectionFactory overrides.
void SetServerDelegate(ServerDelegate* delegate,
@@ -69,7 +68,7 @@ class FakeClientQuicConnectionFactory final : public QuicConnectionFactory {
bool idle() const { return idle_; }
- std::unique_ptr<platform::UdpSocket> socket_;
+ std::unique_ptr<UdpSocket> socket_;
private:
FakeQuicConnectionFactoryBridge* bridge_;
@@ -83,10 +82,9 @@ class FakeServerQuicConnectionFactory final : public QuicConnectionFactory {
~FakeServerQuicConnectionFactory() override;
// UdpSocket::Client overrides.
- void OnError(platform::UdpSocket* socket, Error error) override;
- void OnSendError(platform::UdpSocket* socket, Error error) override;
- void OnRead(platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> packet) override;
+ void OnError(UdpSocket* socket, Error error) override;
+ void OnSendError(UdpSocket* socket, Error error) override;
+ void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override;
// QuicConnectionFactory overrides.
void SetServerDelegate(ServerDelegate* delegate,
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.cc b/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.cc
index ed5a7007602..2ca327139f1 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.cc
@@ -14,8 +14,8 @@
namespace openscreen {
namespace osp {
-FakeQuicBridge::FakeQuicBridge(platform::FakeTaskRunner* task_runner,
- platform::ClockNowFunctionPtr now_function)
+FakeQuicBridge::FakeQuicBridge(FakeTaskRunner* task_runner,
+ ClockNowFunctionPtr now_function)
: task_runner_(task_runner) {
fake_bridge =
std::make_unique<FakeQuicConnectionFactoryBridge>(kControllerEndpoint);
@@ -27,8 +27,8 @@ FakeQuicBridge::FakeQuicBridge(platform::FakeTaskRunner* task_runner,
auto fake_client_factory =
std::make_unique<FakeClientQuicConnectionFactory>(fake_bridge.get());
- client_socket_ = std::make_unique<platform::FakeUdpSocket>(
- task_runner_, fake_client_factory.get());
+ client_socket_ =
+ std::make_unique<FakeUdpSocket>(task_runner_, fake_client_factory.get());
quic_client = std::make_unique<QuicClient>(
controller_demuxer.get(), std::move(fake_client_factory),
@@ -36,8 +36,8 @@ FakeQuicBridge::FakeQuicBridge(platform::FakeTaskRunner* task_runner,
auto fake_server_factory =
std::make_unique<FakeServerQuicConnectionFactory>(fake_bridge.get());
- server_socket_ = std::make_unique<platform::FakeUdpSocket>(
- task_runner_, fake_server_factory.get());
+ server_socket_ =
+ std::make_unique<FakeUdpSocket>(task_runner_, fake_server_factory.get());
ServerConfig config;
config.connection_endpoints.push_back(kReceiverEndpoint);
quic_server = std::make_unique<QuicServer>(
@@ -51,13 +51,13 @@ FakeQuicBridge::FakeQuicBridge(platform::FakeTaskRunner* task_runner,
FakeQuicBridge::~FakeQuicBridge() = default;
void FakeQuicBridge::PostClientPacket() {
- platform::UdpPacket packet;
+ UdpPacket packet;
packet.set_socket(client_socket_.get());
client_socket_->MockReceivePacket(std::move(packet));
}
void FakeQuicBridge::PostServerPacket() {
- platform::UdpPacket packet;
+ UdpPacket packet;
packet.set_socket(server_socket_.get());
server_socket_->MockReceivePacket(std::move(packet));
}
diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.h b/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.h
index 3e8719aa25a..a7ae7c33b63 100644
--- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.h
+++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.h
@@ -56,8 +56,7 @@ class MockServerObserver : public ProtocolConnectionServer::Observer {
class FakeQuicBridge {
public:
- FakeQuicBridge(platform::FakeTaskRunner* task_runner,
- platform::ClockNowFunctionPtr now_function);
+ FakeQuicBridge(FakeTaskRunner* task_runner, ClockNowFunctionPtr now_function);
~FakeQuicBridge();
const IPEndpoint kControllerEndpoint{{192, 168, 1, 3}, 4321};
@@ -79,10 +78,10 @@ class FakeQuicBridge {
void PostPacketsUntilIdle();
FakeClientQuicConnectionFactory* GetClientFactory();
FakeServerQuicConnectionFactory* GetServerFactory();
- platform::FakeTaskRunner* task_runner_;
+ FakeTaskRunner* task_runner_;
- std::unique_ptr<platform::FakeUdpSocket> client_socket_;
- std::unique_ptr<platform::FakeUdpSocket> server_socket_;
+ std::unique_ptr<FakeUdpSocket> client_socket_;
+ std::unique_ptr<FakeUdpSocket> server_socket_;
};
} // namespace osp
diff --git a/chromium/third_party/openscreen/src/osp/impl/service_listener_impl.cc b/chromium/third_party/openscreen/src/osp/impl/service_listener_impl.cc
index f0963014601..43e5479c9f2 100644
--- a/chromium/third_party/openscreen/src/osp/impl/service_listener_impl.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/service_listener_impl.cc
@@ -4,6 +4,8 @@
#include "osp/impl/service_listener_impl.h"
+#include <algorithm>
+
#include "platform/base/error.h"
#include "util/logging.h"
diff --git a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.cc b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.cc
index 7930eedf779..371e8789acc 100644
--- a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.cc
@@ -16,7 +16,7 @@ FakeMdnsPlatformService::~FakeMdnsPlatformService() = default;
std::vector<MdnsPlatformService::BoundInterface>
FakeMdnsPlatformService::RegisterInterfaces(
- const std::vector<platform::NetworkInterfaceIndex>& whitelist) {
+ const std::vector<NetworkInterfaceIndex>& whitelist) {
OSP_CHECK(registered_interfaces_.empty());
if (whitelist.empty()) {
registered_interfaces_ = interfaces_;
diff --git a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.h b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.h
index 81e148c16c8..2c5d6b88794 100644
--- a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.h
+++ b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service.h
@@ -23,8 +23,8 @@ class FakeMdnsPlatformService final : public MdnsPlatformService {
// PlatformService overrides.
std::vector<BoundInterface> RegisterInterfaces(
- const std::vector<platform::NetworkInterfaceIndex>&
- interface_index_whitelist) override;
+ const std::vector<NetworkInterfaceIndex>& interface_index_whitelist)
+ override;
void DeregisterInterfaces(
const std::vector<BoundInterface>& registered_interfaces) override;
diff --git a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service_unittest.cc
index 69595ac3598..c2058d4d254 100644
--- a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service_unittest.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_platform_service_unittest.cc
@@ -12,32 +12,33 @@ namespace openscreen {
namespace osp {
namespace {
-platform::UdpSocket* const kDefaultSocket =
- reinterpret_cast<platform::UdpSocket*>(static_cast<uintptr_t>(16));
-platform::UdpSocket* const kSecondSocket =
- reinterpret_cast<platform::UdpSocket*>(static_cast<uintptr_t>(24));
+UdpSocket* const kDefaultSocket =
+ reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(16));
+UdpSocket* const kSecondSocket =
+ reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(24));
class FakeMdnsPlatformServiceTest : public ::testing::Test {
protected:
const uint8_t mac1_[6] = {11, 22, 33, 44, 55, 66};
const uint8_t mac2_[6] = {12, 23, 34, 45, 56, 67};
- const platform::IPSubnet subnet1_{IPAddress{192, 168, 3, 2}, 24};
- const platform::IPSubnet subnet2_{
- IPAddress{1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3, 4, 5, 6, 7, 8}, 24};
+ const IPSubnet subnet1_{IPAddress{192, 168, 3, 2}, 24};
+ const IPSubnet subnet2_{
+ IPAddress{0x0102, 0x0304, 0x0504, 0x0302, 0x0102, 0x0304, 0x0506, 0x0708},
+ 24};
std::vector<MdnsPlatformService::BoundInterface> bound_interfaces_{
MdnsPlatformService::BoundInterface{
- platform::InterfaceInfo{1,
- mac1_,
- "eth0",
- platform::InterfaceInfo::Type::kEthernet,
- {subnet1_}},
+ InterfaceInfo{1,
+ mac1_,
+ "eth0",
+ InterfaceInfo::Type::kEthernet,
+ {subnet1_}},
subnet1_, kDefaultSocket},
MdnsPlatformService::BoundInterface{
- platform::InterfaceInfo{2,
- mac2_,
- "eth1",
- platform::InterfaceInfo::Type::kEthernet,
- {subnet2_}},
+ InterfaceInfo{2,
+ mac2_,
+ "eth1",
+ InterfaceInfo::Type::kEthernet,
+ {subnet2_}},
subnet2_, kSecondSocket}};
};
diff --git a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.cc b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.cc
index 37e744348ee..fad1c125b30 100644
--- a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.cc
@@ -17,7 +17,7 @@ constexpr char kLocalDomain[] = "local";
PtrEvent MakePtrEvent(const std::string& service_instance,
const std::string& service_type,
const std::string& service_protocol,
- platform::UdpSocket* socket) {
+ UdpSocket* socket) {
const auto labels = std::vector<std::string>{service_instance, service_type,
service_protocol, kLocalDomain};
ErrorOr<DomainName> full_instance_name =
@@ -33,7 +33,7 @@ SrvEvent MakeSrvEvent(const std::string& service_instance,
const std::string& service_protocol,
const std::string& hostname,
uint16_t port,
- platform::UdpSocket* socket) {
+ UdpSocket* socket) {
const auto instance_labels = std::vector<std::string>{
service_instance, service_type, service_protocol, kLocalDomain};
ErrorOr<DomainName> full_instance_name =
@@ -54,7 +54,7 @@ TxtEvent MakeTxtEvent(const std::string& service_instance,
const std::string& service_type,
const std::string& service_protocol,
const std::vector<std::string>& txt_lines,
- platform::UdpSocket* socket) {
+ UdpSocket* socket) {
const auto labels = std::vector<std::string>{service_instance, service_type,
service_protocol, kLocalDomain};
ErrorOr<DomainName> domain_name =
@@ -67,7 +67,7 @@ TxtEvent MakeTxtEvent(const std::string& service_instance,
AEvent MakeAEvent(const std::string& hostname,
IPAddress address,
- platform::UdpSocket* socket) {
+ UdpSocket* socket) {
const auto labels = std::vector<std::string>{hostname, kLocalDomain};
ErrorOr<DomainName> domain_name =
DomainName::FromLabels(labels.begin(), labels.end());
@@ -79,7 +79,7 @@ AEvent MakeAEvent(const std::string& hostname,
AaaaEvent MakeAaaaEvent(const std::string& hostname,
IPAddress address,
- platform::UdpSocket* socket) {
+ UdpSocket* socket) {
const auto labels = std::vector<std::string>{hostname, kLocalDomain};
ErrorOr<DomainName> domain_name =
DomainName::FromLabels(labels.begin(), labels.end());
@@ -97,7 +97,7 @@ void AddEventsForNewService(FakeMdnsResponderAdapter* mdns_responder,
uint16_t port,
const std::vector<std::string>& txt_lines,
const IPAddress& address,
- platform::UdpSocket* socket) {
+ UdpSocket* socket) {
mdns_responder->AddPtrEvent(
MakePtrEvent(service_instance, service_name, service_protocol, socket));
mdns_responder->AddSrvEvent(MakeSrvEvent(service_instance, service_name,
@@ -202,9 +202,9 @@ Error FakeMdnsResponderAdapter::SetHostLabel(const std::string& host_label) {
}
Error FakeMdnsResponderAdapter::RegisterInterface(
- const platform::InterfaceInfo& interface_info,
- const platform::IPSubnet& interface_address,
- platform::UdpSocket* socket) {
+ const InterfaceInfo& interface_info,
+ const IPSubnet& interface_address,
+ UdpSocket* socket) {
if (!running_)
return Error::Code::kOperationInvalid;
@@ -218,8 +218,7 @@ Error FakeMdnsResponderAdapter::RegisterInterface(
return Error::None();
}
-Error FakeMdnsResponderAdapter::DeregisterInterface(
- platform::UdpSocket* socket) {
+Error FakeMdnsResponderAdapter::DeregisterInterface(UdpSocket* socket) {
auto it =
std::find_if(registered_interfaces_.begin(), registered_interfaces_.end(),
[&socket](const RegisteredInterface& interface) {
@@ -232,22 +231,20 @@ Error FakeMdnsResponderAdapter::DeregisterInterface(
return Error::None();
}
-void FakeMdnsResponderAdapter::OnRead(platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> packet) {
+void FakeMdnsResponderAdapter::OnRead(UdpSocket* socket,
+ ErrorOr<UdpPacket> packet) {
OSP_NOTREACHED() << "Tests should not drive this class with packets";
}
-void FakeMdnsResponderAdapter::OnSendError(platform::UdpSocket* socket,
- Error error) {
+void FakeMdnsResponderAdapter::OnSendError(UdpSocket* socket, Error error) {
OSP_NOTREACHED() << "Tests should not drive this class with packets";
}
-void FakeMdnsResponderAdapter::OnError(platform::UdpSocket* socket,
- Error error) {
+void FakeMdnsResponderAdapter::OnError(UdpSocket* socket, Error error) {
OSP_NOTREACHED() << "Tests should not drive this class with packets";
}
-platform::Clock::duration FakeMdnsResponderAdapter::RunTasks() {
+Clock::duration FakeMdnsResponderAdapter::RunTasks() {
return std::chrono::seconds(1);
}
@@ -369,7 +366,7 @@ std::vector<AaaaEvent> FakeMdnsResponderAdapter::TakeAaaaResponses() {
}
MdnsResponderErrorCode FakeMdnsResponderAdapter::StartPtrQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_type) {
if (!running_)
return MdnsResponderErrorCode::kUnknownError;
@@ -388,7 +385,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StartPtrQuery(
}
MdnsResponderErrorCode FakeMdnsResponderAdapter::StartSrvQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) {
if (!running_)
return MdnsResponderErrorCode::kUnknownError;
@@ -402,7 +399,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StartSrvQuery(
}
MdnsResponderErrorCode FakeMdnsResponderAdapter::StartTxtQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) {
if (!running_)
return MdnsResponderErrorCode::kUnknownError;
@@ -416,7 +413,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StartTxtQuery(
}
MdnsResponderErrorCode FakeMdnsResponderAdapter::StartAQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& domain_name) {
if (!running_)
return MdnsResponderErrorCode::kUnknownError;
@@ -430,7 +427,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StartAQuery(
}
MdnsResponderErrorCode FakeMdnsResponderAdapter::StartAaaaQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& domain_name) {
if (!running_)
return MdnsResponderErrorCode::kUnknownError;
@@ -444,7 +441,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StartAaaaQuery(
}
MdnsResponderErrorCode FakeMdnsResponderAdapter::StopPtrQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_type) {
auto interface_entry = queries_.find(socket);
if (interface_entry == queries_.end())
@@ -463,7 +460,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StopPtrQuery(
}
MdnsResponderErrorCode FakeMdnsResponderAdapter::StopSrvQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) {
auto interface_entry = queries_.find(socket);
if (interface_entry == queries_.end())
@@ -478,7 +475,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StopSrvQuery(
}
MdnsResponderErrorCode FakeMdnsResponderAdapter::StopTxtQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) {
auto interface_entry = queries_.find(socket);
if (interface_entry == queries_.end())
@@ -493,7 +490,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StopTxtQuery(
}
MdnsResponderErrorCode FakeMdnsResponderAdapter::StopAQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& domain_name) {
auto interface_entry = queries_.find(socket);
if (interface_entry == queries_.end())
@@ -508,7 +505,7 @@ MdnsResponderErrorCode FakeMdnsResponderAdapter::StopAQuery(
}
MdnsResponderErrorCode FakeMdnsResponderAdapter::StopAaaaQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& domain_name) {
auto interface_entry = queries_.find(socket);
if (interface_entry == queries_.end())
diff --git a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.h b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.h
index a7287e45ddb..d4fdad1fea0 100644
--- a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.h
+++ b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.h
@@ -18,28 +18,28 @@ class FakeMdnsResponderAdapter;
PtrEvent MakePtrEvent(const std::string& service_instance,
const std::string& service_type,
const std::string& service_protocol,
- platform::UdpSocket* socket);
+ UdpSocket* socket);
SrvEvent MakeSrvEvent(const std::string& service_instance,
const std::string& service_type,
const std::string& service_protocol,
const std::string& hostname,
uint16_t port,
- platform::UdpSocket* socket);
+ UdpSocket* socket);
TxtEvent MakeTxtEvent(const std::string& service_instance,
const std::string& service_type,
const std::string& service_protocol,
const std::vector<std::string>& txt_lines,
- platform::UdpSocket* socket);
+ UdpSocket* socket);
AEvent MakeAEvent(const std::string& hostname,
IPAddress address,
- platform::UdpSocket* socket);
+ UdpSocket* socket);
AaaaEvent MakeAaaaEvent(const std::string& hostname,
IPAddress address,
- platform::UdpSocket* socket);
+ UdpSocket* socket);
void AddEventsForNewService(FakeMdnsResponderAdapter* mdns_responder,
const std::string& service_instance,
@@ -49,14 +49,14 @@ void AddEventsForNewService(FakeMdnsResponderAdapter* mdns_responder,
uint16_t port,
const std::vector<std::string>& txt_lines,
const IPAddress& address,
- platform::UdpSocket* socket);
+ UdpSocket* socket);
class FakeMdnsResponderAdapter final : public MdnsResponderAdapter {
public:
struct RegisteredInterface {
- platform::InterfaceInfo interface_info;
- platform::IPSubnet interface_address;
- platform::UdpSocket* socket;
+ InterfaceInfo interface_info;
+ IPSubnet interface_address;
+ UdpSocket* socket;
};
struct RegisteredService {
@@ -99,10 +99,9 @@ class FakeMdnsResponderAdapter final : public MdnsResponderAdapter {
bool running() const { return running_; }
// UdpSocket::Client overrides.
- void OnRead(platform::UdpSocket* socket,
- ErrorOr<platform::UdpPacket> packet) override;
- void OnSendError(platform::UdpSocket* socket, Error error) override;
- void OnError(platform::UdpSocket* socket, Error error) override;
+ void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override;
+ void OnSendError(UdpSocket* socket, Error error) override;
+ void OnError(UdpSocket* socket, Error error) override;
// MdnsResponderAdapter overrides.
Error Init() override;
@@ -112,12 +111,12 @@ class FakeMdnsResponderAdapter final : public MdnsResponderAdapter {
// TODO(btolsch): Reject/OSP_CHECK events that don't match any registered
// interface?
- Error RegisterInterface(const platform::InterfaceInfo& interface_info,
- const platform::IPSubnet& interface_address,
- platform::UdpSocket* socket) override;
- Error DeregisterInterface(platform::UdpSocket* socket) override;
+ Error RegisterInterface(const InterfaceInfo& interface_info,
+ const IPSubnet& interface_address,
+ UdpSocket* socket) override;
+ Error DeregisterInterface(UdpSocket* socket) override;
- platform::Clock::duration RunTasks() override;
+ Clock::duration RunTasks() override;
std::vector<PtrEvent> TakePtrResponses() override;
std::vector<SrvEvent> TakeSrvResponses() override;
@@ -125,30 +124,30 @@ class FakeMdnsResponderAdapter final : public MdnsResponderAdapter {
std::vector<AEvent> TakeAResponses() override;
std::vector<AaaaEvent> TakeAaaaResponses() override;
- MdnsResponderErrorCode StartPtrQuery(platform::UdpSocket* socket,
+ MdnsResponderErrorCode StartPtrQuery(UdpSocket* socket,
const DomainName& service_type) override;
MdnsResponderErrorCode StartSrvQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) override;
MdnsResponderErrorCode StartTxtQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) override;
- MdnsResponderErrorCode StartAQuery(platform::UdpSocket* socket,
+ MdnsResponderErrorCode StartAQuery(UdpSocket* socket,
const DomainName& domain_name) override;
- MdnsResponderErrorCode StartAaaaQuery(platform::UdpSocket* socket,
+ MdnsResponderErrorCode StartAaaaQuery(UdpSocket* socket,
const DomainName& domain_name) override;
- MdnsResponderErrorCode StopPtrQuery(platform::UdpSocket* socket,
+ MdnsResponderErrorCode StopPtrQuery(UdpSocket* socket,
const DomainName& service_type) override;
MdnsResponderErrorCode StopSrvQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) override;
MdnsResponderErrorCode StopTxtQuery(
- platform::UdpSocket* socket,
+ UdpSocket* socket,
const DomainName& service_instance) override;
- MdnsResponderErrorCode StopAQuery(platform::UdpSocket* socket,
+ MdnsResponderErrorCode StopAQuery(UdpSocket* socket,
const DomainName& domain_name) override;
- MdnsResponderErrorCode StopAaaaQuery(platform::UdpSocket* socket,
+ MdnsResponderErrorCode StopAaaaQuery(UdpSocket* socket,
const DomainName& domain_name) override;
MdnsResponderErrorCode RegisterService(
@@ -180,7 +179,7 @@ class FakeMdnsResponderAdapter final : public MdnsResponderAdapter {
bool running_ = false;
LifetimeObserver* observer_ = nullptr;
- std::map<platform::UdpSocket*, InterfaceQueries> queries_;
+ std::map<UdpSocket*, InterfaceQueries> queries_;
// NOTE: One of many simplifications here is that there is no cache. This
// means that calling StartQuery, StopQuery, StartQuery will only return an
// event the first time, unless the test also adds the event a second time.
diff --git a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter_unittest.cc
index f11ef33bbee..06cec89a5ad 100644
--- a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter_unittest.cc
+++ b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter_unittest.cc
@@ -15,10 +15,10 @@ constexpr char kTestServiceInstance[] = "turtle";
constexpr char kTestServiceName[] = "_foo";
constexpr char kTestServiceProtocol[] = "_udp";
-platform::UdpSocket* const kDefaultSocket =
- reinterpret_cast<platform::UdpSocket*>(static_cast<uintptr_t>(8));
-platform::UdpSocket* const kSecondSocket =
- reinterpret_cast<platform::UdpSocket*>(static_cast<uintptr_t>(32));
+UdpSocket* const kDefaultSocket =
+ reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(8));
+UdpSocket* const kSecondSocket =
+ reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(32));
} // namespace
@@ -260,18 +260,18 @@ TEST(FakeMdnsResponderAdapterTest, RegisterInterfaces) {
ASSERT_TRUE(mdns_responder.running());
EXPECT_EQ(0u, mdns_responder.registered_interfaces().size());
- Error result = mdns_responder.RegisterInterface(
- platform::InterfaceInfo{}, platform::IPSubnet{}, kDefaultSocket);
+ Error result = mdns_responder.RegisterInterface(InterfaceInfo{}, IPSubnet{},
+ kDefaultSocket);
EXPECT_TRUE(result.ok());
EXPECT_EQ(1u, mdns_responder.registered_interfaces().size());
- result = mdns_responder.RegisterInterface(
- platform::InterfaceInfo{}, platform::IPSubnet{}, kDefaultSocket);
+ result = mdns_responder.RegisterInterface(InterfaceInfo{}, IPSubnet{},
+ kDefaultSocket);
EXPECT_FALSE(result.ok());
EXPECT_EQ(1u, mdns_responder.registered_interfaces().size());
- result = mdns_responder.RegisterInterface(
- platform::InterfaceInfo{}, platform::IPSubnet{}, kSecondSocket);
+ result = mdns_responder.RegisterInterface(InterfaceInfo{}, IPSubnet{},
+ kSecondSocket);
EXPECT_TRUE(result.ok());
EXPECT_EQ(2u, mdns_responder.registered_interfaces().size());
@@ -286,8 +286,8 @@ TEST(FakeMdnsResponderAdapterTest, RegisterInterfaces) {
ASSERT_FALSE(mdns_responder.running());
EXPECT_EQ(0u, mdns_responder.registered_interfaces().size());
- result = mdns_responder.RegisterInterface(
- platform::InterfaceInfo{}, platform::IPSubnet{}, kDefaultSocket);
+ result = mdns_responder.RegisterInterface(InterfaceInfo{}, IPSubnet{},
+ kDefaultSocket);
EXPECT_FALSE(result.ok());
EXPECT_EQ(0u, mdns_responder.registered_interfaces().size());
}
diff --git a/chromium/third_party/openscreen/src/osp/msgs/request_response_handler.h b/chromium/third_party/openscreen/src/osp/msgs/request_response_handler.h
index 579739c7dad..82874420b12 100644
--- a/chromium/third_party/openscreen/src/osp/msgs/request_response_handler.h
+++ b/chromium/third_party/openscreen/src/osp/msgs/request_response_handler.h
@@ -160,7 +160,7 @@ class RequestResponseHandler : public MessageDemuxer::MessageCallback {
msgs::Type message_type,
const uint8_t* buffer,
size_t buffer_size,
- platform::Clock::time_point now) override {
+ Clock::time_point now) override {
if (message_type != RequestT::kResponseType) {
return 0;
}
diff --git a/chromium/third_party/openscreen/src/osp/public/client_config.h b/chromium/third_party/openscreen/src/osp/public/client_config.h
index 732b52b6fd2..ae172377d61 100644
--- a/chromium/third_party/openscreen/src/osp/public/client_config.h
+++ b/chromium/third_party/openscreen/src/osp/public/client_config.h
@@ -19,8 +19,8 @@ struct ClientConfig {
// The indexes of network interfaces that should be used by the Open Screen
// Library. The indexes derive from the values of
- // openscreen::platform::InterfaceInfo::index.
- std::vector<platform::NetworkInterfaceIndex> interface_indexes;
+ // openscreen::InterfaceInfo::index.
+ std::vector<NetworkInterfaceIndex> interface_indexes;
};
} // namespace osp
diff --git a/chromium/third_party/openscreen/src/osp/public/mdns_service_listener_factory.h b/chromium/third_party/openscreen/src/osp/public/mdns_service_listener_factory.h
index 73118bdaaa9..663d060fbbe 100644
--- a/chromium/third_party/openscreen/src/osp/public/mdns_service_listener_factory.h
+++ b/chromium/third_party/openscreen/src/osp/public/mdns_service_listener_factory.h
@@ -11,9 +11,7 @@
namespace openscreen {
-namespace platform {
class TaskRunner;
-} // namespace platform
namespace osp {
@@ -27,7 +25,7 @@ class MdnsServiceListenerFactory {
static std::unique_ptr<ServiceListener> Create(
const MdnsServiceListenerConfig& config,
ServiceListener::Observer* observer,
- platform::TaskRunner* task_runner);
+ TaskRunner* task_runner);
};
} // namespace osp
diff --git a/chromium/third_party/openscreen/src/osp/public/mdns_service_publisher_factory.h b/chromium/third_party/openscreen/src/osp/public/mdns_service_publisher_factory.h
index 4710d0fb437..075137a7243 100644
--- a/chromium/third_party/openscreen/src/osp/public/mdns_service_publisher_factory.h
+++ b/chromium/third_party/openscreen/src/osp/public/mdns_service_publisher_factory.h
@@ -11,9 +11,7 @@
namespace openscreen {
-namespace platform {
class TaskRunner;
-} // namespace platform
namespace osp {
@@ -22,7 +20,7 @@ class MdnsServicePublisherFactory {
static std::unique_ptr<ServicePublisher> Create(
const ServicePublisher::Config& config,
ServicePublisher::Observer* observer,
- platform::TaskRunner* task_runner);
+ TaskRunner* task_runner);
};
} // namespace osp
diff --git a/chromium/third_party/openscreen/src/osp/public/message_demuxer.cc b/chromium/third_party/openscreen/src/osp/public/message_demuxer.cc
index cabc910fcfd..d67f848a140 100644
--- a/chromium/third_party/openscreen/src/osp/public/message_demuxer.cc
+++ b/chromium/third_party/openscreen/src/osp/public/message_demuxer.cc
@@ -116,7 +116,7 @@ MessageDemuxer::MessageWatch& MessageDemuxer::MessageWatch::operator=(
return *this;
}
-MessageDemuxer::MessageDemuxer(platform::ClockNowFunctionPtr now_function,
+MessageDemuxer::MessageDemuxer(ClockNowFunctionPtr now_function,
size_t buffer_limit = kDefaultBufferLimit)
: now_function_(now_function), buffer_limit_(buffer_limit) {
OSP_DCHECK(now_function_);
diff --git a/chromium/third_party/openscreen/src/osp/public/message_demuxer.h b/chromium/third_party/openscreen/src/osp/public/message_demuxer.h
index 5592ee8d430..284e43ba8fa 100644
--- a/chromium/third_party/openscreen/src/osp/public/message_demuxer.h
+++ b/chromium/third_party/openscreen/src/osp/public/message_demuxer.h
@@ -33,13 +33,12 @@ class MessageDemuxer {
// error code of Error::Code::kCborIncompleteMessage. This way,
// the MessageDemuxer knows to neither consume the data nor discard it as
// bad.
- virtual ErrorOr<size_t> OnStreamMessage(
- uint64_t endpoint_id,
- uint64_t connection_id,
- msgs::Type message_type,
- const uint8_t* buffer,
- size_t buffer_size,
- platform::Clock::time_point now) = 0;
+ virtual ErrorOr<size_t> OnStreamMessage(uint64_t endpoint_id,
+ uint64_t connection_id,
+ msgs::Type message_type,
+ const uint8_t* buffer,
+ size_t buffer_size,
+ Clock::time_point now) = 0;
};
class MessageWatch {
@@ -64,8 +63,7 @@ class MessageDemuxer {
static constexpr size_t kDefaultBufferLimit = 1 << 16;
- MessageDemuxer(platform::ClockNowFunctionPtr now_function,
- size_t buffer_limit);
+ MessageDemuxer(ClockNowFunctionPtr now_function, size_t buffer_limit);
~MessageDemuxer();
// Starts watching for messages of type |message_type| from the endpoint
@@ -110,7 +108,7 @@ class MessageDemuxer {
std::map<msgs::Type, MessageCallback*>* message_callbacks,
std::vector<uint8_t>* buffer);
- const platform::ClockNowFunctionPtr now_function_;
+ const ClockNowFunctionPtr now_function_;
const size_t buffer_limit_;
std::map<uint64_t, std::map<msgs::Type, MessageCallback*>> message_callbacks_;
std::map<msgs::Type, MessageCallback*> default_callbacks_;
diff --git a/chromium/third_party/openscreen/src/osp/public/message_demuxer_unittest.cc b/chromium/third_party/openscreen/src/osp/public/message_demuxer_unittest.cc
index 7e830815d7a..63a03735770 100644
--- a/chromium/third_party/openscreen/src/osp/public/message_demuxer_unittest.cc
+++ b/chromium/third_party/openscreen/src/osp/public/message_demuxer_unittest.cc
@@ -48,14 +48,12 @@ class MessageDemuxerTest : public ::testing::Test {
const uint64_t endpoint_id_ = 13;
const uint64_t connection_id_ = 45;
- platform::FakeClock fake_clock_{
- platform::Clock::time_point(std::chrono::milliseconds(1298424))};
+ FakeClock fake_clock_{Clock::time_point(std::chrono::milliseconds(1298424))};
msgs::CborEncodeBuffer buffer_;
msgs::PresentationConnectionOpenRequest request_{1, "fry-am-the-egg-man",
"url"};
MockMessageCallback mock_callback_;
- MessageDemuxer demuxer_{platform::FakeClock::now,
- MessageDemuxer::kDefaultBufferLimit};
+ MessageDemuxer demuxer_{FakeClock::now, MessageDemuxer::kDefaultBufferLimit};
};
} // namespace
@@ -75,15 +73,14 @@ TEST_F(MessageDemuxerTest, WatchStartStop) {
mock_callback_,
OnStreamMessage(endpoint_id_, connection_id_,
msgs::Type::kPresentationConnectionOpenRequest, _, _, _))
- .WillOnce(
- Invoke([&decode_result, &received_request](
- uint64_t endpoint_id, uint64_t connection_id,
- msgs::Type message_type, const uint8_t* buffer,
- size_t buffer_size, platform::Clock::time_point now) {
- decode_result = msgs::DecodePresentationConnectionOpenRequest(
- buffer, buffer_size, &received_request);
- return ConvertDecodeResult(decode_result);
- }));
+ .WillOnce(Invoke([&decode_result, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ decode_result = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &received_request);
+ return ConvertDecodeResult(decode_result);
+ }));
demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
buffer_.size());
ExpectDecodedRequest(decode_result, received_request);
@@ -110,15 +107,14 @@ TEST_F(MessageDemuxerTest, BufferPartialMessage) {
OnStreamMessage(endpoint_id_, connection_id_,
msgs::Type::kPresentationConnectionOpenRequest, _, _, _))
.Times(2)
- .WillRepeatedly(
- Invoke([&decode_result, &received_request](
- uint64_t endpoint_id, uint64_t connection_id,
- msgs::Type message_type, const uint8_t* buffer,
- size_t buffer_size, platform::Clock::time_point now) {
- decode_result = msgs::DecodePresentationConnectionOpenRequest(
- buffer, buffer_size, &received_request);
- return ConvertDecodeResult(decode_result);
- }));
+ .WillRepeatedly(Invoke([&decode_result, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ decode_result = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &received_request);
+ return ConvertDecodeResult(decode_result);
+ }));
demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
buffer_.size() - 3);
demuxer_.OnStreamData(endpoint_id_, connection_id_,
@@ -140,15 +136,14 @@ TEST_F(MessageDemuxerTest, DefaultWatch) {
mock_callback_,
OnStreamMessage(endpoint_id_, connection_id_,
msgs::Type::kPresentationConnectionOpenRequest, _, _, _))
- .WillOnce(
- Invoke([&decode_result, &received_request](
- uint64_t endpoint_id, uint64_t connection_id,
- msgs::Type message_type, const uint8_t* buffer,
- size_t buffer_size, platform::Clock::time_point now) {
- decode_result = msgs::DecodePresentationConnectionOpenRequest(
- buffer, buffer_size, &received_request);
- return ConvertDecodeResult(decode_result);
- }));
+ .WillOnce(Invoke([&decode_result, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ decode_result = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &received_request);
+ return ConvertDecodeResult(decode_result);
+ }));
demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
buffer_.size());
ExpectDecodedRequest(decode_result, received_request);
@@ -176,15 +171,14 @@ TEST_F(MessageDemuxerTest, DefaultWatchOverridden) {
mock_callback_global,
OnStreamMessage(endpoint_id_ + 1, 14,
msgs::Type::kPresentationConnectionOpenRequest, _, _, _))
- .WillOnce(
- Invoke([&decode_result, &received_request](
- uint64_t endpoint_id, uint64_t connection_id,
- msgs::Type message_type, const uint8_t* buffer,
- size_t buffer_size, platform::Clock::time_point now) {
- decode_result = msgs::DecodePresentationConnectionOpenRequest(
- buffer, buffer_size, &received_request);
- return ConvertDecodeResult(decode_result);
- }));
+ .WillOnce(Invoke([&decode_result, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ decode_result = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &received_request);
+ return ConvertDecodeResult(decode_result);
+ }));
demuxer_.OnStreamData(endpoint_id_ + 1, 14, buffer_.data(), buffer_.size());
ExpectDecodedRequest(decode_result, received_request);
@@ -193,15 +187,14 @@ TEST_F(MessageDemuxerTest, DefaultWatchOverridden) {
mock_callback_,
OnStreamMessage(endpoint_id_, connection_id_,
msgs::Type::kPresentationConnectionOpenRequest, _, _, _))
- .WillOnce(
- Invoke([&decode_result, &received_request](
- uint64_t endpoint_id, uint64_t connection_id,
- msgs::Type message_type, const uint8_t* buffer,
- size_t buffer_size, platform::Clock::time_point now) {
- decode_result = msgs::DecodePresentationConnectionOpenRequest(
- buffer, buffer_size, &received_request);
- return ConvertDecodeResult(decode_result);
- }));
+ .WillOnce(Invoke([&decode_result, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ decode_result = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &received_request);
+ return ConvertDecodeResult(decode_result);
+ }));
demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
buffer_.size());
ExpectDecodedRequest(decode_result, received_request);
@@ -214,15 +207,14 @@ TEST_F(MessageDemuxerTest, WatchAfterData) {
mock_callback_,
OnStreamMessage(endpoint_id_, connection_id_,
msgs::Type::kPresentationConnectionOpenRequest, _, _, _))
- .WillOnce(
- Invoke([&decode_result, &received_request](
- uint64_t endpoint_id, uint64_t connection_id,
- msgs::Type message_type, const uint8_t* buffer,
- size_t buffer_size, platform::Clock::time_point now) {
- decode_result = msgs::DecodePresentationConnectionOpenRequest(
- buffer, buffer_size, &received_request);
- return ConvertDecodeResult(decode_result);
- }));
+ .WillOnce(Invoke([&decode_result, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ decode_result = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &received_request);
+ return ConvertDecodeResult(decode_result);
+ }));
MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType(
endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
&mock_callback_);
@@ -245,27 +237,25 @@ TEST_F(MessageDemuxerTest, WatchAfterMultipleData) {
mock_callback_,
OnStreamMessage(endpoint_id_, connection_id_,
msgs::Type::kPresentationConnectionOpenRequest, _, _, _))
- .WillOnce(
- Invoke([&decode_result1, &received_request](
- uint64_t endpoint_id, uint64_t connection_id,
- msgs::Type message_type, const uint8_t* buffer,
- size_t buffer_size, platform::Clock::time_point now) {
- decode_result1 = msgs::DecodePresentationConnectionOpenRequest(
- buffer, buffer_size, &received_request);
- return ConvertDecodeResult(decode_result1);
- }));
+ .WillOnce(Invoke([&decode_result1, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ decode_result1 = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &received_request);
+ return ConvertDecodeResult(decode_result1);
+ }));
EXPECT_CALL(mock_init_callback,
OnStreamMessage(endpoint_id_, connection_id_,
msgs::Type::kPresentationStartRequest, _, _, _))
- .WillOnce(
- Invoke([&decode_result2, &received_init_request](
- uint64_t endpoint_id, uint64_t connection_id,
- msgs::Type message_type, const uint8_t* buffer,
- size_t buffer_size, platform::Clock::time_point now) {
- decode_result2 = msgs::DecodePresentationStartRequest(
- buffer, buffer_size, &received_init_request);
- return ConvertDecodeResult(decode_result2);
- }));
+ .WillOnce(Invoke([&decode_result2, &received_init_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ decode_result2 = msgs::DecodePresentationStartRequest(
+ buffer, buffer_size, &received_init_request);
+ return ConvertDecodeResult(decode_result2);
+ }));
MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType(
endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
&mock_callback_);
@@ -296,15 +286,14 @@ TEST_F(MessageDemuxerTest, GlobalWatchAfterData) {
mock_callback_,
OnStreamMessage(endpoint_id_, connection_id_,
msgs::Type::kPresentationConnectionOpenRequest, _, _, _))
- .WillOnce(
- Invoke([&decode_result, &received_request](
- uint64_t endpoint_id, uint64_t connection_id,
- msgs::Type message_type, const uint8_t* buffer,
- size_t buffer_size, platform::Clock::time_point now) {
- decode_result = msgs::DecodePresentationConnectionOpenRequest(
- buffer, buffer_size, &received_request);
- return ConvertDecodeResult(decode_result);
- }));
+ .WillOnce(Invoke([&decode_result, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ decode_result = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &received_request);
+ return ConvertDecodeResult(decode_result);
+ }));
MessageDemuxer::MessageWatch watch = demuxer_.SetDefaultMessageTypeWatch(
msgs::Type::kPresentationConnectionOpenRequest, &mock_callback_);
ASSERT_TRUE(watch);
@@ -314,7 +303,7 @@ TEST_F(MessageDemuxerTest, GlobalWatchAfterData) {
}
TEST_F(MessageDemuxerTest, BufferLimit) {
- MessageDemuxer demuxer(platform::FakeClock::now, 10);
+ MessageDemuxer demuxer(FakeClock::now, 10);
demuxer.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
buffer_.size());
@@ -329,15 +318,14 @@ TEST_F(MessageDemuxerTest, BufferLimit) {
mock_callback_,
OnStreamMessage(endpoint_id_, connection_id_,
msgs::Type::kPresentationConnectionOpenRequest, _, _, _))
- .WillOnce(
- Invoke([&decode_result, &received_request](
- uint64_t endpoint_id, uint64_t connection_id,
- msgs::Type message_type, const uint8_t* buffer,
- size_t buffer_size, platform::Clock::time_point now) {
- decode_result = msgs::DecodePresentationConnectionOpenRequest(
- buffer, buffer_size, &received_request);
- return ConvertDecodeResult(decode_result);
- }));
+ .WillOnce(Invoke([&decode_result, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size, Clock::time_point now) {
+ decode_result = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &received_request);
+ return ConvertDecodeResult(decode_result);
+ }));
demuxer.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
buffer_.size());
ExpectDecodedRequest(decode_result, received_request);
diff --git a/chromium/third_party/openscreen/src/osp/public/presentation/presentation_connection.h b/chromium/third_party/openscreen/src/osp/public/presentation/presentation_connection.h
index 5382fec3eba..293e942ff4b 100644
--- a/chromium/third_party/openscreen/src/osp/public/presentation/presentation_connection.h
+++ b/chromium/third_party/openscreen/src/osp/public/presentation/presentation_connection.h
@@ -196,7 +196,7 @@ class ConnectionManager final : public MessageDemuxer::MessageCallback {
msgs::Type message_type,
const uint8_t* buffer,
size_t buffer_size,
- platform::Clock::time_point now) override;
+ Clock::time_point now) override;
Connection* GetConnection(uint64_t connection_id);
diff --git a/chromium/third_party/openscreen/src/osp/public/presentation/presentation_controller.h b/chromium/third_party/openscreen/src/osp/public/presentation/presentation_controller.h
index 6d51329d1ba..d0a5ee8197c 100644
--- a/chromium/third_party/openscreen/src/osp/public/presentation/presentation_controller.h
+++ b/chromium/third_party/openscreen/src/osp/public/presentation/presentation_controller.h
@@ -96,7 +96,7 @@ class Controller final : public ServiceListener::Observer,
Controller* controller_;
};
- explicit Controller(platform::ClockNowFunctionPtr now_function);
+ explicit Controller(ClockNowFunctionPtr now_function);
~Controller();
// Requests receivers compatible with all urls in |urls| and registers
diff --git a/chromium/third_party/openscreen/src/osp/public/presentation/presentation_receiver.h b/chromium/third_party/openscreen/src/osp/public/presentation/presentation_receiver.h
index 3f87c0c5418..4eb4a04c26b 100644
--- a/chromium/third_party/openscreen/src/osp/public/presentation/presentation_receiver.h
+++ b/chromium/third_party/openscreen/src/osp/public/presentation/presentation_receiver.h
@@ -97,7 +97,7 @@ class Receiver final : public MessageDemuxer::MessageCallback,
msgs::Type message_type,
const uint8_t* buffer,
size_t buffer_size,
- platform::Clock::time_point now) override;
+ Clock::time_point now) override;
private:
struct QueuedResponse {
diff --git a/chromium/third_party/openscreen/src/osp/public/protocol_connection_client_factory.h b/chromium/third_party/openscreen/src/osp/public/protocol_connection_client_factory.h
index 8672c16b7eb..e30d6216634 100644
--- a/chromium/third_party/openscreen/src/osp/public/protocol_connection_client_factory.h
+++ b/chromium/third_party/openscreen/src/osp/public/protocol_connection_client_factory.h
@@ -11,9 +11,7 @@
namespace openscreen {
-namespace platform {
class TaskRunner;
-} // namespace platform
namespace osp {
@@ -22,7 +20,7 @@ class ProtocolConnectionClientFactory {
static std::unique_ptr<ProtocolConnectionClient> Create(
MessageDemuxer* demuxer,
ProtocolConnectionServiceObserver* observer,
- platform::TaskRunner* task_runner);
+ TaskRunner* task_runner);
};
} // namespace osp
diff --git a/chromium/third_party/openscreen/src/osp/public/protocol_connection_server_factory.h b/chromium/third_party/openscreen/src/osp/public/protocol_connection_server_factory.h
index 0f11b8a2752..0e55eeddc5e 100644
--- a/chromium/third_party/openscreen/src/osp/public/protocol_connection_server_factory.h
+++ b/chromium/third_party/openscreen/src/osp/public/protocol_connection_server_factory.h
@@ -12,9 +12,7 @@
namespace openscreen {
-namespace platform {
class TaskRunner;
-}
namespace osp {
@@ -24,7 +22,7 @@ class ProtocolConnectionServerFactory {
const ServerConfig& config,
MessageDemuxer* demuxer,
ProtocolConnectionServer::Observer* observer,
- platform::TaskRunner* task_runner);
+ TaskRunner* task_runner);
};
} // namespace osp
diff --git a/chromium/third_party/openscreen/src/osp/public/server_config.h b/chromium/third_party/openscreen/src/osp/public/server_config.h
index 579fe36166e..e4f95d8a6c1 100644
--- a/chromium/third_party/openscreen/src/osp/public/server_config.h
+++ b/chromium/third_party/openscreen/src/osp/public/server_config.h
@@ -20,8 +20,8 @@ struct ServerConfig {
// The indexes of network interfaces that should be used by the Open Screen
// Library. The indexes derive from the values of
- // openscreen::platform::InterfaceInfo::index.
- std::vector<platform::NetworkInterfaceIndex> interface_indexes;
+ // openscreen::InterfaceInfo::index.
+ std::vector<NetworkInterfaceIndex> interface_indexes;
// The list of connection endpoints that are advertised for Open Screen
// protocol connections. These must be reachable via one interface in
diff --git a/chromium/third_party/openscreen/src/osp/public/service_info.cc b/chromium/third_party/openscreen/src/osp/public/service_info.cc
index ebd41b2779f..77943ca46e1 100644
--- a/chromium/third_party/openscreen/src/osp/public/service_info.cc
+++ b/chromium/third_party/openscreen/src/osp/public/service_info.cc
@@ -23,11 +23,10 @@ bool ServiceInfo::operator!=(const ServiceInfo& other) const {
return !(*this == other);
}
-bool ServiceInfo::Update(
- std::string new_friendly_name,
- platform::NetworkInterfaceIndex new_network_interface_index,
- const IPEndpoint& new_v4_endpoint,
- const IPEndpoint& new_v6_endpoint) {
+bool ServiceInfo::Update(std::string new_friendly_name,
+ NetworkInterfaceIndex new_network_interface_index,
+ const IPEndpoint& new_v4_endpoint,
+ const IPEndpoint& new_v6_endpoint) {
OSP_DCHECK(!new_v4_endpoint.address ||
IPAddress::Version::kV4 == new_v4_endpoint.address.version());
OSP_DCHECK(!new_v6_endpoint.address ||
diff --git a/chromium/third_party/openscreen/src/osp/public/service_info.h b/chromium/third_party/openscreen/src/osp/public/service_info.h
index df9d8e3638d..73e43eb6303 100644
--- a/chromium/third_party/openscreen/src/osp/public/service_info.h
+++ b/chromium/third_party/openscreen/src/osp/public/service_info.h
@@ -29,7 +29,7 @@ struct ServiceInfo {
bool operator!=(const ServiceInfo& other) const;
bool Update(std::string friendly_name,
- platform::NetworkInterfaceIndex network_interface_index,
+ NetworkInterfaceIndex network_interface_index,
const IPEndpoint& v4_endpoint,
const IPEndpoint& v6_endpoint);
@@ -40,8 +40,7 @@ struct ServiceInfo {
std::string friendly_name;
// The index of the network interface that the screen was discovered on.
- platform::NetworkInterfaceIndex network_interface_index =
- platform::kInvalidNetworkInterfaceIndex;
+ NetworkInterfaceIndex network_interface_index = kInvalidNetworkInterfaceIndex;
// The network endpoints to create a new connection to the Open Screen
// service.
diff --git a/chromium/third_party/openscreen/src/osp/public/service_listener.h b/chromium/third_party/openscreen/src/osp/public/service_listener.h
index 5d81536d429..2a11e44c42b 100644
--- a/chromium/third_party/openscreen/src/osp/public/service_listener.h
+++ b/chromium/third_party/openscreen/src/osp/public/service_listener.h
@@ -5,7 +5,6 @@
#ifndef OSP_PUBLIC_SERVICE_LISTENER_H_
#define OSP_PUBLIC_SERVICE_LISTENER_H_
-#include <atomic>
#include <cstdint>
#include <string>
#include <vector>
@@ -150,7 +149,7 @@ class ServiceListener {
protected:
ServiceListener();
- std::atomic<State> state_;
+ State state_;
ServiceListenerError last_error_;
std::vector<Observer*> observers_;
diff --git a/chromium/third_party/openscreen/src/osp/public/service_publisher.h b/chromium/third_party/openscreen/src/osp/public/service_publisher.h
index a5fac6a23fe..b31f59fc066 100644
--- a/chromium/third_party/openscreen/src/osp/public/service_publisher.h
+++ b/chromium/third_party/openscreen/src/osp/public/service_publisher.h
@@ -5,7 +5,6 @@
#ifndef OSP_PUBLIC_SERVICE_PUBLISHER_H_
#define OSP_PUBLIC_SERVICE_PUBLISHER_H_
-#include <atomic>
#include <cstdint>
#include <string>
#include <vector>
@@ -108,7 +107,7 @@ class ServicePublisher {
// By default, all enabled Ethernet and WiFi interfaces are used.
// This configuration must be identical to the interfaces configured
// in the ScreenConnectionServer.
- std::vector<platform::NetworkInterfaceIndex> network_interface_indices;
+ std::vector<NetworkInterfaceIndex> network_interface_indices;
};
virtual ~ServicePublisher();
@@ -145,7 +144,7 @@ class ServicePublisher {
protected:
explicit ServicePublisher(Observer* observer);
- std::atomic<State> state_;
+ State state_;
ServicePublisherError last_error_;
Observer* observer_;
diff --git a/chromium/third_party/openscreen/src/osp/public/testing/message_demuxer_test_support.h b/chromium/third_party/openscreen/src/osp/public/testing/message_demuxer_test_support.h
index 7ba1e36b4ff..cde458c22f2 100644
--- a/chromium/third_party/openscreen/src/osp/public/testing/message_demuxer_test_support.h
+++ b/chromium/third_party/openscreen/src/osp/public/testing/message_demuxer_test_support.h
@@ -22,7 +22,7 @@ class MockMessageCallback final : public MessageDemuxer::MessageCallback {
msgs::Type message_type,
const uint8_t* buffer,
size_t buffer_size,
- platform::Clock::time_point now));
+ Clock::time_point now));
};
} // namespace osp
diff --git a/chromium/third_party/openscreen/src/platform/BUILD.gn b/chromium/third_party/openscreen/src/platform/BUILD.gn
index bd24e0b0090..9545232cf87 100644
--- a/chromium/third_party/openscreen/src/platform/BUILD.gn
+++ b/chromium/third_party/openscreen/src/platform/BUILD.gn
@@ -4,26 +4,11 @@
import("//build_overrides/build.gni")
-source_set("platform") {
+# Source files that depend on nothing (all your base/ are belong to us).
+source_set("base") {
defines = []
sources = [
- "api/logging.h",
- "api/network_interface.h",
- "api/scoped_wake_lock.cc",
- "api/scoped_wake_lock.h",
- "api/socket.h",
- "api/task_runner.h",
- "api/time.h",
- "api/tls_connection.cc",
- "api/tls_connection.h",
- "api/tls_connection_factory.cc",
- "api/tls_connection_factory.h",
- "api/trace_logging_platform.cc",
- "api/trace_logging_platform.h",
- "api/udp_read_callback.h",
- "api/udp_socket.cc",
- "api/udp_socket.h",
"base/error.cc",
"base/error.h",
"base/interface_info.cc",
@@ -46,20 +31,49 @@ source_set("platform") {
"base/udp_packet.h",
]
+ public_configs = [ "../build:openscreen_include_dirs" ]
+}
+
+# Public API source files. These may depend on nothing except :base.
+source_set("api") {
+ defines = []
+
+ sources = [
+ "api/export.h",
+ "api/logging.h",
+ "api/network_interface.h",
+ "api/scoped_wake_lock.cc",
+ "api/scoped_wake_lock.h",
+ "api/task_runner.h",
+ "api/time.h",
+ "api/tls_connection.cc",
+ "api/tls_connection.h",
+ "api/tls_connection_factory.cc",
+ "api/tls_connection_factory.h",
+ "api/trace_logging_platform.cc",
+ "api/trace_logging_platform.h",
+ "api/udp_socket.cc",
+ "api/udp_socket.h",
+ ]
+
public_deps = [
- "../third_party/abseil",
- "../third_party/boringssl",
- "../util",
+ ":base",
]
+}
- public_configs = [ "../build:openscreen_include_dirs" ]
+# The following target is only activated in standalone builds (see :platform).
+if (!build_with_chromium) {
+ source_set("standalone_impl") {
+ defines = []
- if (!build_with_chromium) {
- sources += [
+ sources = [
"impl/logging.h",
+ "impl/network_interface.cc",
+ "impl/network_interface.h",
"impl/socket_handle.h",
"impl/socket_handle_waiter.cc",
"impl/socket_handle_waiter.h",
+ "impl/stream_socket.h",
"impl/task_runner.cc",
"impl/task_runner.h",
"impl/text_trace_logging_platform.cc",
@@ -67,11 +81,14 @@ source_set("platform") {
"impl/time.cc",
"impl/tls_write_buffer.cc",
"impl/tls_write_buffer.h",
- "impl/weak_ptr.h",
]
if (is_linux) {
- sources += [ "impl/network_interface_linux.cc" ]
+ sources += [
+ "impl/network_interface_linux.cc",
+ "impl/scoped_wake_lock_linux.cc",
+ "impl/scoped_wake_lock_linux.h",
+ ]
} else if (is_mac) {
defines += [
# Required, to use the new IPv6 Sockets options introduced by RFC 3542.
@@ -102,7 +119,6 @@ source_set("platform") {
"impl/socket_handle_posix.h",
"impl/socket_handle_waiter_posix.cc",
"impl/socket_handle_waiter_posix.h",
- "impl/stream_socket.h",
"impl/stream_socket_posix.cc",
"impl/stream_socket_posix.h",
"impl/timeval_posix.cc",
@@ -119,9 +135,30 @@ source_set("platform") {
"impl/udp_socket_reader_posix.h",
]
}
+
+ deps = [
+ ":api",
+ "../third_party/abseil",
+ "../third_party/boringssl",
+ "../util",
+ ]
}
}
+# The main target, which either assumes an embedder will link-in the platform
+# API implementation elsewhere, or links-in the :standalone_impl in the build.
+source_set("platform") {
+ public_deps = [
+ ":api",
+ ]
+ if (!build_with_chromium) {
+ deps = [
+ ":standalone_impl",
+ ]
+ }
+}
+
+# Test helpers, referenced in other Open Screen BUILD.gn test targets.
source_set("test") {
testonly = true
sources = [
@@ -132,12 +169,15 @@ source_set("test") {
"test/fake_udp_socket.cc",
"test/fake_udp_socket.h",
"test/mock_tls_connection.h",
+ "test/mock_udp_socket.h",
"test/trace_logging_helpers.h",
]
deps = [
":platform",
+ "../third_party/abseil",
"../third_party/googletest:gmock",
+ "../util",
]
}
@@ -145,6 +185,7 @@ source_set("unittests") {
testonly = true
sources = [
+ "api/socket_integration_unittest.cc",
"api/time_unittest.cc",
"base/error_unittest.cc",
"base/ip_address_unittest.cc",
@@ -155,12 +196,8 @@ source_set("unittests") {
# Exclude them if an embedder is providing the implementation.
if (!build_with_chromium) {
sources += [
- # TODO(jophba): move over to general sources when UDP socket create
- # is implemented in Chromium, as part of the NetworkRunner work.
- "api/socket_integration_unittest.cc",
"impl/task_runner_unittest.cc",
"impl/time_unittest.cc",
- "impl/weak_ptr_unittest.cc",
]
if (is_posix) {
@@ -178,7 +215,11 @@ source_set("unittests") {
deps = [
":platform",
+ ":test",
+ "../third_party/abseil",
+ "../third_party/boringssl",
"../third_party/googletest:gmock",
"../third_party/googletest:gtest",
+ "../util",
]
}
diff --git a/chromium/third_party/openscreen/src/platform/api/DEPS b/chromium/third_party/openscreen/src/platform/api/DEPS
index f0fc67f29bb..f833653aeb9 100644
--- a/chromium/third_party/openscreen/src/platform/api/DEPS
+++ b/chromium/third_party/openscreen/src/platform/api/DEPS
@@ -3,7 +3,20 @@
# found in the LICENSE file.
include_rules = [
+ # Platform API code should depend on no outside code/libraries, other than
+ # the standard toolchain libraries (C, STL) and platform/base.
+ '-absl',
+ '-platform',
+ '+platform/api',
'+platform/base',
'-util',
'-third_party',
]
+
+specific_include_rules = {
+ ".*_unittest\.cc": [
+ '+platform/test',
+ '+util',
+ '+third_party',
+ ],
+}
diff --git a/chromium/third_party/openscreen/src/platform/api/logging.h b/chromium/third_party/openscreen/src/platform/api/logging.h
index 9ff79d1a75a..6080b77c8b2 100644
--- a/chromium/third_party/openscreen/src/platform/api/logging.h
+++ b/chromium/third_party/openscreen/src/platform/api/logging.h
@@ -5,10 +5,9 @@
#ifndef PLATFORM_API_LOGGING_H_
#define PLATFORM_API_LOGGING_H_
-#include "absl/strings/string_view.h"
+#include <sstream>
namespace openscreen {
-namespace platform {
enum class LogLevel {
// Very detailed information, often used for evaluating performance or
@@ -36,17 +35,20 @@ enum class LogLevel {
// Returns true if |level| is at or above the level where the embedder will
// record/emit log entries from the code in |file|.
-bool IsLoggingOn(LogLevel level, absl::string_view file);
+bool IsLoggingOn(LogLevel level, const char* file);
// Record a log entry, consisting of its logging level, location and message.
// The embedder may filter-out entries according to its own policy, but this
// function will not be called if IsLoggingOn(level, file) returns false.
// Whenever |level| is kFatal, Open Screen will call Break() immediately after
// this returns.
+//
+// |message| is passed as a string stream to avoid unnecessary string copies.
+// Embedders can call its rdbuf() or str() methods to access the log message.
void LogWithLevel(LogLevel level,
- absl::string_view file,
+ const char* file,
int line,
- absl::string_view msg);
+ std::stringstream message);
// Breaks into the debugger, if one is present. Otherwise, aborts the current
// process (i.e., this function should not return). In production builds, an
@@ -55,7 +57,6 @@ void LogWithLevel(LogLevel level,
// aborting the process.
void Break();
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_API_LOGGING_H_
diff --git a/chromium/third_party/openscreen/src/platform/api/network_interface.h b/chromium/third_party/openscreen/src/platform/api/network_interface.h
index f0e53e17b11..28ebd2fed11 100644
--- a/chromium/third_party/openscreen/src/platform/api/network_interface.h
+++ b/chromium/third_party/openscreen/src/platform/api/network_interface.h
@@ -10,7 +10,6 @@
#include "platform/base/interface_info.h"
namespace openscreen {
-namespace platform {
// Returns an InterfaceInfo for each currently active network interface on the
// system. No two entries in this vector can have the same NetworkInterfaceIndex
@@ -22,7 +21,6 @@ namespace platform {
// discovery) are not being used.
std::vector<InterfaceInfo> GetNetworkInterfaces();
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_API_NETWORK_INTERFACE_H_
diff --git a/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.cc b/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.cc
index ad8442e091f..73a946e1e9a 100644
--- a/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.cc
+++ b/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.cc
@@ -5,10 +5,8 @@
#include "platform/api/scoped_wake_lock.h"
namespace openscreen {
-namespace platform {
ScopedWakeLock::ScopedWakeLock() = default;
ScopedWakeLock::~ScopedWakeLock() = default;
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.h b/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.h
index 64d69ecb867..2843be4eea5 100644
--- a/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.h
+++ b/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.h
@@ -8,7 +8,6 @@
#include <memory>
namespace openscreen {
-namespace platform {
// Ensures that the device does not got to sleep. This is used, for example,
// while Open Screen is communicating with peers over the network for things
@@ -33,7 +32,6 @@ class ScopedWakeLock {
virtual ~ScopedWakeLock();
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_API_SCOPED_WAKE_LOCK_H_
diff --git a/chromium/third_party/openscreen/src/platform/api/socket_integration_unittest.cc b/chromium/third_party/openscreen/src/platform/api/socket_integration_unittest.cc
index fe15377ba56..3704dbcbee3 100644
--- a/chromium/third_party/openscreen/src/platform/api/socket_integration_unittest.cc
+++ b/chromium/third_party/openscreen/src/platform/api/socket_integration_unittest.cc
@@ -9,7 +9,6 @@
#include "platform/test/fake_udp_socket.h"
namespace openscreen {
-namespace platform {
using testing::_;
@@ -21,7 +20,7 @@ TEST(SocketIntegrationTest, ResolvesLocalEndpoint_IPv4) {
FakeClock clock(Clock::now());
FakeTaskRunner task_runner(&clock);
FakeUdpSocket::MockClient client;
- ErrorOr<UdpSocketUniquePtr> create_result = UdpSocket::Create(
+ ErrorOr<std::unique_ptr<UdpSocket>> create_result = UdpSocket::Create(
&task_runner, &client, IPEndpoint{IPAddress(kIpV4AddrAny), 0});
ASSERT_TRUE(create_result) << create_result.error();
const auto socket = std::move(create_result.value());
@@ -35,11 +34,11 @@ TEST(SocketIntegrationTest, ResolvesLocalEndpoint_IPv4) {
// successfully Bind(), and that the operating system will return the
// auto-assigned socket name (i.e., the local endpoint's port will not be zero).
TEST(SocketIntegrationTest, ResolvesLocalEndpoint_IPv6) {
- const uint8_t kIpV6AddrAny[16] = {};
+ const uint16_t kIpV6AddrAny[8] = {};
FakeClock clock(Clock::now());
FakeTaskRunner task_runner(&clock);
FakeUdpSocket::MockClient client;
- ErrorOr<UdpSocketUniquePtr> create_result = UdpSocket::Create(
+ ErrorOr<std::unique_ptr<UdpSocket>> create_result = UdpSocket::Create(
&task_runner, &client, IPEndpoint{IPAddress(kIpV6AddrAny), 0});
ASSERT_TRUE(create_result) << create_result.error();
const auto socket = std::move(create_result.value());
@@ -49,5 +48,4 @@ TEST(SocketIntegrationTest, ResolvesLocalEndpoint_IPv6) {
EXPECT_NE(local_endpoint.port, 0) << local_endpoint;
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/api/task_runner.h b/chromium/third_party/openscreen/src/platform/api/task_runner.h
index c94777fd0a6..e114db68860 100644
--- a/chromium/third_party/openscreen/src/platform/api/task_runner.h
+++ b/chromium/third_party/openscreen/src/platform/api/task_runner.h
@@ -5,21 +5,21 @@
#ifndef PLATFORM_API_TASK_RUNNER_H_
#define PLATFORM_API_TASK_RUNNER_H_
-#include <future>
+#include <future> // NOLINT
+#include <utility>
-#include "absl/types/optional.h"
#include "platform/api/time.h"
namespace openscreen {
-namespace platform {
// A thread-safe API surface that allows for posting tasks. The underlying
// implementation may be single or multi-threaded, and all complication should
-// be handled by the implementation class. It is the expectation of this API
-// that the underlying impl gives the following guarantees:
+// be handled by the implementation class. The implementation must guarantee:
// (1) Tasks shall not overlap in time/CPU.
// (2) Tasks shall run sequentially, e.g. posting task A then B implies
// that A shall run before B.
+// (3) If task A is posted before task B, then any mutation in A happens-before
+// B runs (even if A and B run on different threads).
class TaskRunner {
public:
using Task = std::packaged_task<void() noexcept>;
@@ -50,10 +50,9 @@ class TaskRunner {
// Return true if the calling thread is the thread that task runner is using
// to run tasks, false otherwise.
- virtual bool IsRunningOnTaskRunner() { return true; }
+ virtual bool IsRunningOnTaskRunner() = 0;
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_API_TASK_RUNNER_H_
diff --git a/chromium/third_party/openscreen/src/platform/api/time.h b/chromium/third_party/openscreen/src/platform/api/time.h
index af0dda7573d..dcf24accc88 100644
--- a/chromium/third_party/openscreen/src/platform/api/time.h
+++ b/chromium/third_party/openscreen/src/platform/api/time.h
@@ -10,7 +10,6 @@
#include "platform/base/trivial_clock_traits.h"
namespace openscreen {
-namespace platform {
// The "reasonably high-resolution" source of monotonic time from the embedder,
// exhibiting the traits described in TrivialClockTraits. This class is not
@@ -33,7 +32,6 @@ class Clock : public TrivialClockTraits {
// time."
std::chrono::seconds GetWallTimeSinceUnixEpoch() noexcept;
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_API_TIME_H_
diff --git a/chromium/third_party/openscreen/src/platform/api/time_unittest.cc b/chromium/third_party/openscreen/src/platform/api/time_unittest.cc
index 9a4f40cbd10..2bef8250f04 100644
--- a/chromium/third_party/openscreen/src/platform/api/time_unittest.cc
+++ b/chromium/third_party/openscreen/src/platform/api/time_unittest.cc
@@ -12,7 +12,6 @@ using std::chrono::microseconds;
using std::chrono::milliseconds;
namespace openscreen {
-namespace platform {
namespace {
// Tests that the clock always seems to tick forward. If this test is broken, or
@@ -56,5 +55,4 @@ TEST(TimeTest, PlatformClockHasSufficientResolution) {
}
} // namespace
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/api/tls_connection.cc b/chromium/third_party/openscreen/src/platform/api/tls_connection.cc
index fc3a8c8ebef..9668c114512 100644
--- a/chromium/third_party/openscreen/src/platform/api/tls_connection.cc
+++ b/chromium/third_party/openscreen/src/platform/api/tls_connection.cc
@@ -5,10 +5,8 @@
#include "platform/api/tls_connection.h"
namespace openscreen {
-namespace platform {
TlsConnection::TlsConnection() = default;
TlsConnection::~TlsConnection() = default;
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/api/tls_connection.h b/chromium/third_party/openscreen/src/platform/api/tls_connection.h
index 7c2777fb204..4d409cb54f9 100644
--- a/chromium/third_party/openscreen/src/platform/api/tls_connection.h
+++ b/chromium/third_party/openscreen/src/platform/api/tls_connection.h
@@ -12,20 +12,12 @@
#include "platform/base/ip_address.h"
namespace openscreen {
-namespace platform {
class TlsConnection {
public:
// Client callbacks are run via the TaskRunner used by TlsConnectionFactory.
class Client {
public:
- // Called when |connection| writing is blocked and unblocked, respectively.
- // Note that implementations should do best effort to buffer packets even in
- // blocked state, and should call OnError if we actually overflow the
- // buffer.
- virtual void OnWriteBlocked(TlsConnection* connection) = 0;
- virtual void OnWriteUnblocked(TlsConnection* connection) = 0;
-
// Called when |connection| experiences an error, such as a read error.
virtual void OnError(TlsConnection* connection, Error error) = 0;
@@ -45,8 +37,8 @@ class TlsConnection {
// the Client.
virtual void SetClient(Client* client) = 0;
- // Sends a message.
- virtual void Write(const void* data, size_t len) = 0;
+ // Sends a message. Returns true iff the message will be sent.
+ [[nodiscard]] virtual bool Send(const void* data, size_t len) = 0;
// Get the local address.
virtual IPEndpoint GetLocalEndpoint() const = 0;
@@ -58,7 +50,6 @@ class TlsConnection {
TlsConnection();
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_API_TLS_CONNECTION_H_
diff --git a/chromium/third_party/openscreen/src/platform/api/tls_connection_factory.cc b/chromium/third_party/openscreen/src/platform/api/tls_connection_factory.cc
index a3bfb287962..e64078f1727 100644
--- a/chromium/third_party/openscreen/src/platform/api/tls_connection_factory.cc
+++ b/chromium/third_party/openscreen/src/platform/api/tls_connection_factory.cc
@@ -5,10 +5,8 @@
#include "platform/api/tls_connection_factory.h"
namespace openscreen {
-namespace platform {
TlsConnectionFactory::TlsConnectionFactory() = default;
TlsConnectionFactory::~TlsConnectionFactory() = default;
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/api/tls_connection_factory.h b/chromium/third_party/openscreen/src/platform/api/tls_connection_factory.h
index eed6893f404..80dc8ac6367 100644
--- a/chromium/third_party/openscreen/src/platform/api/tls_connection_factory.h
+++ b/chromium/third_party/openscreen/src/platform/api/tls_connection_factory.h
@@ -13,7 +13,6 @@
#include "platform/base/ip_address.h"
namespace openscreen {
-namespace platform {
class TaskRunner;
class TlsConnection;
@@ -30,7 +29,7 @@ class TlsConnectionFactory {
public:
// Provides a new |connection| that resulted from listening on the local
// socket. |der_x509_peer_cert| is the DER-encoded X509 certificate from the
- // peer.
+ // peer if present, or empty if the peer didn't provide one.
virtual void OnAccepted(TlsConnectionFactory* factory,
std::vector<uint8_t> der_x509_peer_cert,
std::unique_ptr<TlsConnection> connection) = 0;
@@ -75,7 +74,6 @@ class TlsConnectionFactory {
TlsConnectionFactory();
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_API_TLS_CONNECTION_FACTORY_H_
diff --git a/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.cc b/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.cc
index 2725b404958..35be2bf329b 100644
--- a/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.cc
+++ b/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.cc
@@ -5,9 +5,7 @@
#include "platform/api/trace_logging_platform.h"
namespace openscreen {
-namespace platform {
TraceLoggingPlatform::~TraceLoggingPlatform() = default;
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.h b/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.h
index 76c7e7a974d..882c1a7db6c 100644
--- a/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.h
+++ b/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.h
@@ -11,12 +11,13 @@
#include "platform/base/trace_logging_types.h"
namespace openscreen {
-namespace platform {
// Optional platform API to support logging trace events from Open Screen. To
// use this, implement the TraceLoggingPlatform interface and call
// StartTracing() and StopTracing() to turn tracing on/off (see
// platform/base/trace_logging_activation.h).
+//
+// All methods must be thread-safe and re-entrant.
class TraceLoggingPlatform {
public:
virtual ~TraceLoggingPlatform();
@@ -48,7 +49,6 @@ class TraceLoggingPlatform {
Error::Code error) = 0;
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_API_TRACE_LOGGING_PLATFORM_H_
diff --git a/chromium/third_party/openscreen/src/platform/api/udp_socket.cc b/chromium/third_party/openscreen/src/platform/api/udp_socket.cc
index 60262cab8c4..2da7a73f429 100644
--- a/chromium/third_party/openscreen/src/platform/api/udp_socket.cc
+++ b/chromium/third_party/openscreen/src/platform/api/udp_socket.cc
@@ -5,10 +5,8 @@
#include "platform/api/udp_socket.h"
namespace openscreen {
-namespace platform {
UdpSocket::UdpSocket() = default;
UdpSocket::~UdpSocket() = default;
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/api/udp_socket.h b/chromium/third_party/openscreen/src/platform/api/udp_socket.h
index e76b9eb8586..e668db12a50 100644
--- a/chromium/third_party/openscreen/src/platform/api/udp_socket.h
+++ b/chromium/third_party/openscreen/src/platform/api/udp_socket.h
@@ -5,11 +5,10 @@
#ifndef PLATFORM_API_UDP_SOCKET_H_
#define PLATFORM_API_UDP_SOCKET_H_
-#include <atomic>
-#include <cstdint>
-#include <functional>
+#include <stddef.h> // size_t
+#include <stdint.h> // uint8_t
+
#include <memory>
-#include <mutex>
#include "platform/api/network_interface.h"
#include "platform/base/error.h"
@@ -17,12 +16,8 @@
#include "platform/base/udp_packet.h"
namespace openscreen {
-namespace platform {
class TaskRunner;
-class UdpSocket;
-
-using UdpSocketUniquePtr = std::unique_ptr<UdpSocket>;
// An open UDP socket for sending/receiving datagrams to/from either specific
// endpoints or over IP multicast.
@@ -75,9 +70,10 @@ class UdpSocket {
// will be queued on the provided |task_runner|. For this reason, the provided
// TaskRunner and Client must exist for the duration of the created socket's
// lifetime.
- static ErrorOr<UdpSocketUniquePtr> Create(TaskRunner* task_runner,
- Client* client,
- const IPEndpoint& local_endpoint);
+ static ErrorOr<std::unique_ptr<UdpSocket>> Create(
+ TaskRunner* task_runner,
+ Client* client,
+ const IPEndpoint& local_endpoint);
virtual ~UdpSocket();
@@ -119,7 +115,6 @@ class UdpSocket {
UdpSocket();
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_API_UDP_SOCKET_H_
diff --git a/chromium/third_party/openscreen/src/platform/base/DEPS b/chromium/third_party/openscreen/src/platform/base/DEPS
index 46dad90c1db..f4c412d5eec 100644
--- a/chromium/third_party/openscreen/src/platform/base/DEPS
+++ b/chromium/third_party/openscreen/src/platform/base/DEPS
@@ -3,13 +3,18 @@
# found in the LICENSE file.
include_rules = [
- '-platform'
+ # Platform base code should depend on no outside code/libraries, other than
+ # the standard toolchain libraries (C, STL).
+ '-absl',
+ '-platform',
+ '+platform/base',
'-util',
'-third_party',
]
specific_include_rules = {
".*_unittest\.cc": [
+ '+platform/test',
'+util',
'+third_party',
],
diff --git a/chromium/third_party/openscreen/src/platform/base/error.cc b/chromium/third_party/openscreen/src/platform/base/error.cc
index fd300c1abb3..ea58c91ba49 100644
--- a/chromium/third_party/openscreen/src/platform/base/error.cc
+++ b/chromium/third_party/openscreen/src/platform/base/error.cc
@@ -4,6 +4,8 @@
#include "platform/base/error.h"
+#include <sstream>
+
namespace openscreen {
Error::Error() = default;
@@ -120,6 +122,8 @@ std::ostream& operator<<(std::ostream& os, const Error::Code& code) {
return os << "Failure: FatalSSLError";
case Error::Code::kRSAKeyGenerationFailure:
return os << "Failure: RSAKeyGenerationFailure";
+ case Error::Code::kRSAKeyParseError:
+ return os << "Failure: RSAKeyParseError";
case Error::Code::kEVPInitializationError:
return os << "Failure: EVPInitializationError";
case Error::Code::kCertificateCreationError:
@@ -166,6 +170,8 @@ std::ostream& operator<<(std::ostream& os, const Error::Code& code) {
return os << "Failure: ItemNotFound";
case Error::Code::kOperationInvalid:
return os << "Failure: OperationInvalid";
+ case Error::Code::kOperationInProgress:
+ return os << "Failure: OperationInProgress";
case Error::Code::kOperationCancelled:
return os << "Failure: OperationCancelled";
case Error::Code::kCastV2PeerCertEmpty:
@@ -220,12 +226,24 @@ std::ostream& operator<<(std::ostream& os, const Error::Code& code) {
return os << "Failure: kCastV2PingTimeout";
case Error::Code::kCastV2ChannelPolicyMismatch:
return os << "Failure: kCastV2ChannelPolicyMismatch";
+ case Error::Code::kCreateSignatureFailed:
+ return os << "Failure: kCreateSignatureFailed";
+ case Error::Code::kUpdateReceivedRecordFailure:
+ return os << "Failure: kUpdateReceivedRecordFailure";
+ case Error::Code::kRecordPublicationError:
+ return os << "Failure: kRecordPublicationError";
}
// Unused 'return' to get around failure on GCC.
return os;
}
+std::string Error::ToString() const {
+ std::stringstream ss;
+ ss << *this;
+ return ss.str();
+}
+
std::ostream& operator<<(std::ostream& out, const Error& error) {
out << error.code() << " = \"" << error.message() << "\"";
return out;
diff --git a/chromium/third_party/openscreen/src/platform/base/error.h b/chromium/third_party/openscreen/src/platform/base/error.h
index aedbf29a2a0..aa79a81567c 100644
--- a/chromium/third_party/openscreen/src/platform/base/error.h
+++ b/chromium/third_party/openscreen/src/platform/base/error.h
@@ -5,11 +5,11 @@
#ifndef PLATFORM_BASE_ERROR_H_
#define PLATFORM_BASE_ERROR_H_
+#include <cassert>
#include <ostream>
#include <string>
#include <utility>
-#include "absl/types/variant.h"
#include "platform/base/macros.h"
namespace openscreen {
@@ -81,6 +81,7 @@ class Error {
// Was unable to generate an RSA key.
kRSAKeyGenerationFailure,
+ kRSAKeyParseError,
// Was unable to initialize an EVP_PKEY type.
kEVPInitializationError,
@@ -155,6 +156,12 @@ class Error {
kCastV2PingTimeout,
kCastV2ChannelPolicyMismatch,
+ kCreateSignatureFailed,
+
+ // Discovery errors.
+ kUpdateReceivedRecordFailure,
+ kRecordPublicationError,
+
// Generic errors.
kUnknownError,
kNotImplemented,
@@ -166,6 +173,7 @@ class Error {
kItemAlreadyExists,
kItemNotFound,
kOperationInvalid,
+ kOperationInProgress,
kOperationCancelled,
};
@@ -186,9 +194,12 @@ class Error {
Code code() const { return code_; }
const std::string& message() const { return message_; }
+ std::string& message() { return message_; }
static const Error& None();
+ std::string ToString() const;
+
private:
Code code_ = Code::kNone;
std::string message_;
@@ -222,35 +233,89 @@ class ErrorOr {
return error;
}
- ErrorOr(ErrorOr&& other) = default;
- ErrorOr& operator=(ErrorOr&& other) = default;
+ ErrorOr(const ValueType& value) : value_(value), is_value_(true) {} // NOLINT
+ ErrorOr(ValueType&& value) noexcept // NOLINT
+ : value_(std::move(value)), is_value_(true) {}
- ErrorOr(const ValueType& value) : variant_{value} {} // NOLINT
- ErrorOr(ValueType&& value) noexcept // NOLINT
- : variant_{std::move(value)} {}
-
- ErrorOr(Error error) : variant_{std::move(error)} {} // NOLINT
- ErrorOr(Error::Code code) : variant_{code} {} // NOLINT
+ ErrorOr(const Error& error) : error_(error), is_value_(false) { // NOLINT
+ assert(error_.code() != Error::Code::kNone);
+ }
+ ErrorOr(Error&& error) noexcept // NOLINT
+ : error_(std::move(error)), is_value_(false) {
+ assert(error_.code() != Error::Code::kNone);
+ }
+ ErrorOr(Error::Code code) : error_(code), is_value_(false) { // NOLINT
+ assert(error_.code() != Error::Code::kNone);
+ }
ErrorOr(Error::Code code, std::string message)
- : variant_{Error{code, std::move(message)}} {}
+ : error_(code, std::move(message)), is_value_(false) {
+ assert(error_.code() != Error::Code::kNone);
+ }
- ~ErrorOr() = default;
+ ErrorOr(ErrorOr&& other) noexcept : is_value_(other.is_value_) {
+ // NB: Both |value_| and |error_| are uninitialized memory at this point!
+ // Unlike the other constructors, the compiler will not auto-generate
+ // constructor calls for either union member because neither appeared in
+ // this constructor's initializer list.
+ if (other.is_value_) {
+ new (&value_) ValueType(std::move(other.value_));
+ } else {
+ new (&error_) Error(std::move(other.error_));
+ }
+ }
- bool is_error() const { return absl::holds_alternative<Error>(variant_); }
- bool is_value() const { return !is_error(); }
+ ErrorOr& operator=(ErrorOr&& other) noexcept {
+ this->~ErrorOr<ValueType>();
+ new (this) ErrorOr<ValueType>(std::move(other));
+ return *this;
+ }
+
+ ~ErrorOr() {
+ // NB: |value_| or |error_| must be explicitly destroyed since the compiler
+ // will not auto-generate the destructor calls for union members.
+ if (is_value_) {
+ value_.~ValueType();
+ } else {
+ error_.~Error();
+ }
+ }
+
+ bool is_error() const { return !is_value_; }
+ bool is_value() const { return is_value_; }
// Unlike Error, we CAN provide an operator bool here, since it is
// more obvious to callers that ErrorOr<Foo> will be true if it's Foo.
- operator bool() const { return is_value(); }
+ operator bool() const { return is_value_; }
- const Error& error() const { return absl::get<Error>(variant_); }
- Error& error() { return absl::get<Error>(variant_); }
+ const Error& error() const {
+ assert(!is_value_);
+ return error_;
+ }
+ Error& error() {
+ assert(!is_value_);
+ return error_;
+ }
- const ValueType& value() const { return absl::get<ValueType>(variant_); }
- ValueType& value() { return absl::get<ValueType>(variant_); }
+ const ValueType& value() const {
+ assert(is_value_);
+ return value_;
+ }
+ ValueType& value() {
+ assert(is_value_);
+ return value_;
+ }
private:
- absl::variant<Error, ValueType> variant_;
+ // Only one of these is an active member, determined by |is_value_|. Since
+ // they are union'ed, they must be explicitly constructed and destroyed.
+ union {
+ ValueType value_;
+ Error error_;
+ };
+
+ // If true, |value_| is initialized and active. Otherwise, |error_| is
+ // initialized and active.
+ const bool is_value_;
};
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/base/interface_info.cc b/chromium/third_party/openscreen/src/platform/base/interface_info.cc
index b8baff97672..2ada91be943 100644
--- a/chromium/third_party/openscreen/src/platform/base/interface_info.cc
+++ b/chromium/third_party/openscreen/src/platform/base/interface_info.cc
@@ -4,8 +4,9 @@
#include "platform/base/interface_info.h"
+#include <algorithm>
+
namespace openscreen {
-namespace platform {
InterfaceInfo::InterfaceInfo() = default;
InterfaceInfo::InterfaceInfo(NetworkInterfaceIndex index,
@@ -27,6 +28,24 @@ IPSubnet::IPSubnet(IPAddress address, uint8_t prefix_length)
: address(std::move(address)), prefix_length(prefix_length) {}
IPSubnet::~IPSubnet() = default;
+IPAddress InterfaceInfo::GetIpAddressV4() const {
+ for (const auto& address : addresses) {
+ if (address.address.IsV4()) {
+ return address.address;
+ }
+ }
+ return IPAddress{};
+}
+
+IPAddress InterfaceInfo::GetIpAddressV6() const {
+ for (const auto& address : addresses) {
+ if (address.address.IsV6()) {
+ return address.address;
+ }
+ }
+ return IPAddress{};
+}
+
std::ostream& operator<<(std::ostream& out, const IPSubnet& subnet) {
if (subnet.address.IsV6()) {
out << '[';
@@ -38,23 +57,30 @@ std::ostream& operator<<(std::ostream& out, const IPSubnet& subnet) {
return out << '/' << std::dec << static_cast<int>(subnet.prefix_length);
}
-std::ostream& operator<<(std::ostream& out, const InterfaceInfo& info) {
- std::string media_type;
- switch (info.type) {
+std::ostream& operator<<(std::ostream& out, InterfaceInfo::Type type) {
+ switch (type) {
case InterfaceInfo::Type::kEthernet:
- media_type = "Ethernet";
+ out << "Ethernet";
break;
case InterfaceInfo::Type::kWifi:
- media_type = "Wifi";
+ out << "Wifi";
+ break;
+ case InterfaceInfo::Type::kLoopback:
+ out << "Loopback";
break;
case InterfaceInfo::Type::kOther:
- media_type = "Other";
+ out << "Other";
break;
}
+
+ return out;
+}
+
+std::ostream& operator<<(std::ostream& out, const InterfaceInfo& info) {
out << '{' << info.index << " (a.k.a. " << info.name
- << "); media_type=" << media_type << "; MAC=" << std::hex
+ << "); media_type=" << info.type << "; MAC=" << std::hex
<< static_cast<int>(info.hardware_address[0]);
- for (size_t i = 1; i < sizeof(info.hardware_address); ++i) {
+ for (size_t i = 1; i < info.hardware_address.size(); ++i) {
out << ':' << static_cast<int>(info.hardware_address[i]);
}
for (const IPSubnet& ip : info.addresses) {
@@ -63,5 +89,4 @@ std::ostream& operator<<(std::ostream& out, const InterfaceInfo& info) {
return out << '}';
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/base/interface_info.h b/chromium/third_party/openscreen/src/platform/base/interface_info.h
index af0c35b6caa..81686063ba7 100644
--- a/chromium/third_party/openscreen/src/platform/base/interface_info.h
+++ b/chromium/third_party/openscreen/src/platform/base/interface_info.h
@@ -13,7 +13,6 @@
#include "platform/base/ip_address.h"
namespace openscreen {
-namespace platform {
// Unique identifier, usually provided by the operating system, for identifying
// a specific network interface. This value is used with UdpSocket to join
@@ -40,18 +39,14 @@ struct IPSubnet {
};
struct InterfaceInfo {
- enum class Type {
- kEthernet = 0,
- kWifi,
- kOther,
- };
+ enum class Type : uint32_t { kEthernet = 0, kWifi, kLoopback, kOther };
// Interface index, typically as specified by the operating system,
// identifying this interface on the host machine.
NetworkInterfaceIndex index = kInvalidNetworkInterfaceIndex;
// MAC address of the interface. All 0s if unavailable.
- uint8_t hardware_address[6] = {};
+ std::array<uint8_t, 6> hardware_address = {};
// Interface name (e.g. eth0) if available.
std::string name;
@@ -62,6 +57,12 @@ struct InterfaceInfo {
// All IP addresses associated with the interface.
std::vector<IPSubnet> addresses;
+ // Returns an IPAddress of the given type associated with this network
+ // interface, or the false IPAddress if the associated address family is not
+ // supported on this interface.
+ IPAddress GetIpAddressV4() const;
+ IPAddress GetIpAddressV6() const;
+
InterfaceInfo();
InterfaceInfo(NetworkInterfaceIndex index,
const uint8_t hardware_address[6],
@@ -72,10 +73,10 @@ struct InterfaceInfo {
};
// Human-readable output (e.g., for logging).
+std::ostream& operator<<(std::ostream& out, InterfaceInfo::Type type);
std::ostream& operator<<(std::ostream& out, const IPSubnet& subnet);
std::ostream& operator<<(std::ostream& out, const InterfaceInfo& info);
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_BASE_INTERFACE_INFO_H_
diff --git a/chromium/third_party/openscreen/src/platform/base/ip_address.cc b/chromium/third_party/openscreen/src/platform/base/ip_address.cc
index 808b40be7e6..6fee6a3c9c1 100644
--- a/chromium/third_party/openscreen/src/platform/base/ip_address.cc
+++ b/chromium/third_party/openscreen/src/platform/base/ip_address.cc
@@ -4,20 +4,24 @@
#include "platform/base/ip_address.h"
+#include <algorithm>
#include <cassert>
+#include <cctype>
+#include <cinttypes>
+#include <cstdio>
#include <cstring>
#include <iomanip>
-
-#include "absl/types/optional.h"
+#include <iterator>
+#include <sstream>
+#include <utility>
namespace openscreen {
// static
-ErrorOr<IPAddress> IPAddress::Parse(const std::string& s) {
- ErrorOr<IPAddress> v4 = ParseV4(s);
+const IPAddress IPAddress::kV4LoopbackAddress{127, 0, 0, 1};
- return v4 ? std::move(v4) : ParseV6(s);
-} // namespace openscreen
+// static
+const IPAddress IPAddress::kV6LoopbackAddress{0, 0, 0, 0, 0, 0, 0, 1};
IPAddress::IPAddress() : version_(Version::kV4), bytes_({}) {}
IPAddress::IPAddress(const std::array<uint8_t, 4>& bytes)
@@ -35,31 +39,55 @@ IPAddress::IPAddress(Version version, const uint8_t* b) : version_(version) {
}
IPAddress::IPAddress(uint8_t b1, uint8_t b2, uint8_t b3, uint8_t b4)
: version_(Version::kV4), bytes_{{b1, b2, b3, b4}} {}
-IPAddress::IPAddress(const std::array<uint8_t, 16>& bytes)
- : version_(Version::kV6), bytes_(bytes) {}
-IPAddress::IPAddress(const uint8_t (&b)[16])
- : version_(Version::kV6),
- bytes_{{b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10],
- b[11], b[12], b[13], b[14], b[15]}} {}
-IPAddress::IPAddress(uint8_t b1,
- uint8_t b2,
- uint8_t b3,
- uint8_t b4,
- uint8_t b5,
- uint8_t b6,
- uint8_t b7,
- uint8_t b8,
- uint8_t b9,
- uint8_t b10,
- uint8_t b11,
- uint8_t b12,
- uint8_t b13,
- uint8_t b14,
- uint8_t b15,
- uint8_t b16)
+
+IPAddress::IPAddress(const std::array<uint16_t, 8>& hextets)
+ : IPAddress(hextets[0],
+ hextets[1],
+ hextets[2],
+ hextets[3],
+ hextets[4],
+ hextets[5],
+ hextets[6],
+ hextets[7]) {}
+
+IPAddress::IPAddress(const uint16_t (&hextets)[8])
+ : IPAddress(hextets[0],
+ hextets[1],
+ hextets[2],
+ hextets[3],
+ hextets[4],
+ hextets[5],
+ hextets[6],
+ hextets[7]) {}
+
+IPAddress::IPAddress(uint16_t h0,
+ uint16_t h1,
+ uint16_t h2,
+ uint16_t h3,
+ uint16_t h4,
+ uint16_t h5,
+ uint16_t h6,
+ uint16_t h7)
: version_(Version::kV6),
- bytes_{{b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15,
- b16}} {}
+ bytes_{{
+ static_cast<uint8_t>(h0 >> 8),
+ static_cast<uint8_t>(h0),
+ static_cast<uint8_t>(h1 >> 8),
+ static_cast<uint8_t>(h1),
+ static_cast<uint8_t>(h2 >> 8),
+ static_cast<uint8_t>(h2),
+ static_cast<uint8_t>(h3 >> 8),
+ static_cast<uint8_t>(h3),
+ static_cast<uint8_t>(h4 >> 8),
+ static_cast<uint8_t>(h4),
+ static_cast<uint8_t>(h5 >> 8),
+ static_cast<uint8_t>(h5),
+ static_cast<uint8_t>(h6 >> 8),
+ static_cast<uint8_t>(h6),
+ static_cast<uint8_t>(h7 >> 8),
+ static_cast<uint8_t>(h7),
+ }} {}
+
IPAddress::IPAddress(const IPAddress& o) noexcept = default;
IPAddress::IPAddress(IPAddress&& o) noexcept = default;
IPAddress& IPAddress::operator=(const IPAddress& o) noexcept = default;
@@ -101,115 +129,136 @@ void IPAddress::CopyToV6(uint8_t x[16]) const {
std::memcpy(x, bytes_.data(), 16);
}
-// static
-ErrorOr<IPAddress> IPAddress::ParseV4(const std::string& s) {
- if (s.size() > 0 && s[0] == '.')
+namespace {
+
+ErrorOr<IPAddress> ParseV4(const std::string& s) {
+ int octets[4];
+ int chars_scanned;
+ // Note: sscanf()'s parsing for %d allows leading whitespace; so the invalid
+ // presence of whitespace must be explicitly checked too.
+ if (std::any_of(s.begin(), s.end(), [](char c) { return std::isspace(c); }) ||
+ sscanf(s.c_str(), "%3d.%3d.%3d.%3d%n", &octets[0], &octets[1], &octets[2],
+ &octets[3], &chars_scanned) != 4 ||
+ chars_scanned != static_cast<int>(s.size()) ||
+ std::any_of(std::begin(octets), std::end(octets),
+ [](int octet) { return octet < 0 || octet > 255; })) {
return Error::Code::kInvalidIPV4Address;
+ }
+ return IPAddress(octets[0], octets[1], octets[2], octets[3]);
+}
- IPAddress address;
- uint16_t next_octet = 0;
- int i = 0;
- bool previous_dot = false;
- for (auto c : s) {
- if (c == '.') {
- if (previous_dot) {
- return Error::Code::kInvalidIPV4Address;
- }
- address.bytes_[i++] = static_cast<uint8_t>(next_octet);
- next_octet = 0;
- previous_dot = true;
- if (i > 3)
- return Error::Code::kInvalidIPV4Address;
-
- continue;
- }
- previous_dot = false;
- if (!std::isdigit(c))
- return Error::Code::kInvalidIPV4Address;
-
- next_octet = next_octet * 10 + (c - '0');
- if (next_octet > 255)
- return Error::Code::kInvalidIPV4Address;
+// Returns the zero-expansion of a double-colon in |s| if |s| is a
+// well-formatted IPv6 address. If |s| is ill-formatted, returns *any* string
+// that is ill-formatted.
+std::string ExpandIPv6DoubleColon(const std::string& s) {
+ constexpr char kDoubleColon[] = "::";
+ const size_t double_colon_position = s.find(kDoubleColon);
+ if (double_colon_position == std::string::npos) {
+ return s; // Nothing to expand.
+ }
+ if (double_colon_position != s.rfind(kDoubleColon)) {
+ return {}; // More than one occurrence of double colons is illegal.
}
- if (previous_dot)
- return Error::Code::kInvalidIPV4Address;
- if (i != 3)
- return Error::Code::kInvalidIPV4Address;
+ std::ostringstream expanded;
+ const int num_single_colons = std::count(s.begin(), s.end(), ':') - 2;
+ int num_zero_groups_to_insert = 8 - num_single_colons;
+ if (double_colon_position != 0) {
+ // abcd:0123:4567::f000:1
+ // ^^^^^^^^^^^^^^^
+ expanded << s.substr(0, double_colon_position + 1);
+ --num_zero_groups_to_insert;
+ }
+ if (double_colon_position != (s.size() - 2)) {
+ --num_zero_groups_to_insert;
+ }
+ while (--num_zero_groups_to_insert > 0) {
+ expanded << "0:";
+ }
+ expanded << '0';
+ if (double_colon_position != (s.size() - 2)) {
+ // abcd:0123:4567::f000:1
+ // ^^^^^^^
+ expanded << s.substr(double_colon_position + 1);
+ }
+ return expanded.str();
+}
- address.bytes_[i] = static_cast<uint8_t>(next_octet);
- address.version_ = Version::kV4;
- return address;
+ErrorOr<IPAddress> ParseV6(const std::string& s) {
+ const std::string scan_input = ExpandIPv6DoubleColon(s);
+ uint16_t hextets[8];
+ int chars_scanned;
+ // Note: sscanf()'s parsing for %x allows leading whitespace; so the invalid
+ // presence of whitespace must be explicitly checked too.
+ if (std::any_of(s.begin(), s.end(), [](char c) { return std::isspace(c); }) ||
+ sscanf(scan_input.c_str(),
+ "%4" SCNx16 ":%4" SCNx16 ":%4" SCNx16 ":%4" SCNx16 ":%4" SCNx16
+ ":%4" SCNx16 ":%4" SCNx16 ":%4" SCNx16 "%n",
+ &hextets[0], &hextets[1], &hextets[2], &hextets[3], &hextets[4],
+ &hextets[5], &hextets[6], &hextets[7], &chars_scanned) != 8 ||
+ chars_scanned != static_cast<int>(scan_input.size())) {
+ return Error::Code::kInvalidIPV6Address;
+ }
+ return IPAddress(hextets);
}
+} // namespace
+
// static
-ErrorOr<IPAddress> IPAddress::ParseV6(const std::string& s) {
- if (s.size() > 1 && s[0] == ':' && s[1] != ':')
- return Error::Code::kInvalidIPV6Address;
+ErrorOr<IPAddress> IPAddress::Parse(const std::string& s) {
+ ErrorOr<IPAddress> v4 = ParseV4(s);
- uint16_t next_value = 0;
- uint8_t values[16];
- int i = 0;
- int num_previous_colons = 0;
- absl::optional<int> double_colon_index = absl::nullopt;
- for (auto c : s) {
- if (c == ':') {
- ++num_previous_colons;
- if (num_previous_colons == 2) {
- if (double_colon_index) {
- return Error::Code::kInvalidIPV6Address;
- }
- double_colon_index = i;
- } else if (i >= 15 || num_previous_colons > 2) {
- return Error::Code::kInvalidIPV6Address;
- } else {
- values[i++] = static_cast<uint8_t>(next_value >> 8);
- values[i++] = static_cast<uint8_t>(next_value & 0xff);
- next_value = 0;
- }
- } else {
- num_previous_colons = 0;
- uint8_t x = 0;
- if (c >= '0' && c <= '9') {
- x = c - '0';
- } else if (c >= 'a' && c <= 'f') {
- x = c - 'a' + 10;
- } else if (c >= 'A' && c <= 'F') {
- x = c - 'A' + 10;
- } else {
- return Error::Code::kInvalidIPV6Address;
- }
- if (next_value & 0xf000) {
- return Error::Code::kInvalidIPV6Address;
- } else {
- next_value = static_cast<uint16_t>(next_value * 16 + x);
- }
- }
- }
- if (num_previous_colons == 1)
- return Error::Code::kInvalidIPV6Address;
+ return v4 ? std::move(v4) : ParseV6(s);
+}
- if (i >= 15)
- return Error::Code::kInvalidIPV6Address;
+IPEndpoint::operator bool() const {
+ return address || port;
+}
- values[i++] = static_cast<uint8_t>(next_value >> 8);
- values[i] = static_cast<uint8_t>(next_value & 0xff);
- if (!((i == 15 && !double_colon_index) || (i < 14 && double_colon_index))) {
- return Error::Code::kInvalidIPV6Address;
+// static
+ErrorOr<IPEndpoint> IPEndpoint::Parse(const std::string& s) {
+ // Look for the colon that separates the IP address from the port number. Note
+ // that this check also guards against the case where |s| is the empty string.
+ const auto colon_pos = s.rfind(':');
+ if (colon_pos == std::string::npos) {
+ return Error(Error::Code::kParseError, "missing colon separator");
+ }
+ // The colon cannot be the first nor the last character in |s| because that
+ // would mean there is no address part or port part.
+ if (colon_pos == 0) {
+ return Error(Error::Code::kParseError, "missing address before colon");
+ }
+ if (colon_pos == (s.size() - 1)) {
+ return Error(Error::Code::kParseError, "missing port after colon");
}
- IPAddress address;
- for (int j = 15; j >= 0;) {
- if (double_colon_index && (i == double_colon_index)) {
- address.bytes_[j--] = values[i--];
- while (j > i)
- address.bytes_[j--] = 0;
- } else {
- address.bytes_[j--] = values[i--];
- }
+ ErrorOr<IPAddress> address(Error::Code::kParseError);
+ if (s[0] == '[' && s[colon_pos - 1] == ']') {
+ // [abcd:beef:1:1::2600]:8080
+ // ^^^^^^^^^^^^^^^^^^^^^
+ address = ParseV6(s.substr(1, colon_pos - 2));
+ } else {
+ // 127.0.0.1:22
+ // ^^^^^^^^^
+ address = ParseV4(s.substr(0, colon_pos));
+ }
+ if (address.is_error()) {
+ return Error(Error::Code::kParseError, "invalid address part");
+ }
+
+ const char* const port_part = s.c_str() + colon_pos + 1;
+ int port, chars_scanned;
+ // Note: sscanf()'s parsing for %d allows leading whitespace. Thus, if the
+ // first char is not whitespace, a successful sscanf() parse here can only
+ // mean numerical chars contributed to the parsed integer.
+ if (std::isspace(port_part[0]) ||
+ sscanf(port_part, "%d%n", &port, &chars_scanned) != 1 ||
+ port_part[chars_scanned] != '\0' || port < 0 ||
+ port > std::numeric_limits<uint16_t>::max()) {
+ return Error(Error::Code::kParseError, "invalid port part");
}
- address.version_ = Version::kV6;
- return address;
+
+ return IPEndpoint{address.value(), static_cast<uint16_t>(port)};
}
bool operator==(const IPEndpoint& a, const IPEndpoint& b) {
@@ -220,19 +269,23 @@ bool operator!=(const IPEndpoint& a, const IPEndpoint& b) {
return !(a == b);
}
-bool IPEndpointComparator::operator()(const IPEndpoint& a,
- const IPEndpoint& b) const {
- if (a.address.version() != b.address.version())
- return a.address.version() < b.address.version();
- if (a.address.IsV4()) {
- int ret = memcmp(a.address.bytes_.data(), b.address.bytes_.data(), 4);
- if (ret != 0)
- return ret < 0;
+bool IPAddress::operator<(const IPAddress& other) const {
+ if (version() != other.version()) {
+ return version() < other.version();
+ }
+
+ if (IsV4()) {
+ return memcmp(bytes_.data(), other.bytes_.data(), 4) < 0;
} else {
- int ret = memcmp(a.address.bytes_.data(), b.address.bytes_.data(), 16);
- if (ret != 0)
- return ret < 0;
+ return memcmp(bytes_.data(), other.bytes_.data(), 16) < 0;
}
+}
+
+bool operator<(const IPEndpoint& a, const IPEndpoint& b) {
+ if (a.address != b.address) {
+ return a.address < b.address;
+ }
+
return a.port < b.port;
}
@@ -275,4 +328,10 @@ std::ostream& operator<<(std::ostream& out, const IPEndpoint& endpoint) {
return out << ':' << std::dec << static_cast<int>(endpoint.port);
}
+std::string IPEndpoint::ToString() const {
+ std::ostringstream name;
+ name << this;
+ return name.str();
+}
+
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/base/ip_address.h b/chromium/third_party/openscreen/src/platform/base/ip_address.h
index c4dff5b0c5d..c37054d9e83 100644
--- a/chromium/third_party/openscreen/src/platform/base/ip_address.h
+++ b/chromium/third_party/openscreen/src/platform/base/ip_address.h
@@ -22,37 +22,35 @@ class IPAddress {
kV6,
};
+ static const IPAddress kV4LoopbackAddress;
+ static const IPAddress kV6LoopbackAddress;
+
static constexpr size_t kV4Size = 4;
static constexpr size_t kV6Size = 16;
- // Parses a text representation of an IPv4 address (e.g. "192.168.0.1") or an
- // IPv6 address (e.g. "abcd::1234") and puts the result into |address|.
- static ErrorOr<IPAddress> Parse(const std::string& s);
-
IPAddress();
+
+ // |bytes| contains 4 octets for IPv4, or 8 hextets (16 bytes of big-endian
+ // shorts) for IPv6.
+ IPAddress(Version version, const uint8_t* bytes);
+
+ // IPv4 constructors (IPAddress from 4 octets).
explicit IPAddress(const std::array<uint8_t, 4>& bytes);
explicit IPAddress(const uint8_t (&b)[4]);
- explicit IPAddress(Version version, const uint8_t* b);
IPAddress(uint8_t b1, uint8_t b2, uint8_t b3, uint8_t b4);
- explicit IPAddress(const std::array<uint8_t, 16>& bytes);
- explicit IPAddress(const uint8_t (&b)[16]);
- IPAddress(uint8_t b1,
- uint8_t b2,
- uint8_t b3,
- uint8_t b4,
- uint8_t b5,
- uint8_t b6,
- uint8_t b7,
- uint8_t b8,
- uint8_t b9,
- uint8_t b10,
- uint8_t b11,
- uint8_t b12,
- uint8_t b13,
- uint8_t b14,
- uint8_t b15,
- uint8_t b16);
+ // IPv6 constructors (IPAddress from 8 hextets).
+ explicit IPAddress(const std::array<uint16_t, 8>& hextets);
+ explicit IPAddress(const uint16_t (&hextets)[8]);
+ IPAddress(uint16_t h1,
+ uint16_t h2,
+ uint16_t h3,
+ uint16_t h4,
+ uint16_t h5,
+ uint16_t h6,
+ uint16_t h7,
+ uint16_t h8);
+
IPAddress(const IPAddress& o) noexcept;
IPAddress(IPAddress&& o) noexcept;
~IPAddress() = default;
@@ -62,6 +60,11 @@ class IPAddress {
bool operator==(const IPAddress& o) const;
bool operator!=(const IPAddress& o) const;
+
+ bool operator<(const IPAddress& other) const;
+ bool operator>(const IPAddress& other) const { return other < *this; }
+ bool operator<=(const IPAddress& other) const { return !(other < *this); }
+ bool operator>=(const IPAddress& other) const { return !(*this < other); }
explicit operator bool() const;
Version version() const { return version_; }
@@ -78,12 +81,11 @@ class IPAddress {
// in order to avoid making multiple copies.
const uint8_t* bytes() const { return bytes_.data(); }
- private:
- static ErrorOr<IPAddress> ParseV4(const std::string& s);
- static ErrorOr<IPAddress> ParseV6(const std::string& s);
-
- friend class IPEndpointComparator;
+ // Parses a text representation of an IPv4 address (e.g. "192.168.0.1") or an
+ // IPv6 address (e.g. "abcd::1234").
+ static ErrorOr<IPAddress> Parse(const std::string& s);
+ private:
Version version_;
std::array<uint8_t, 16> bytes_;
};
@@ -91,16 +93,30 @@ class IPAddress {
struct IPEndpoint {
public:
IPAddress address;
- uint16_t port;
+ uint16_t port = 0;
+
+ explicit operator bool() const;
+
+ // Parses a text representation of an IPv4/IPv6 address and port (e.g.
+ // "192.168.0.1:8080" or "[abcd::1234]:8080").
+ static ErrorOr<IPEndpoint> Parse(const std::string& s);
+
+ std::string ToString() const;
};
bool operator==(const IPEndpoint& a, const IPEndpoint& b);
bool operator!=(const IPEndpoint& a, const IPEndpoint& b);
-class IPEndpointComparator {
- public:
- bool operator()(const IPEndpoint& a, const IPEndpoint& b) const;
-};
+bool operator<(const IPEndpoint& a, const IPEndpoint& b);
+inline bool operator>(const IPEndpoint& a, const IPEndpoint& b) {
+ return b < a;
+}
+inline bool operator<=(const IPEndpoint& a, const IPEndpoint& b) {
+ return !(b > a);
+}
+inline bool operator>=(const IPEndpoint& a, const IPEndpoint& b) {
+ return !(a > b);
+}
// Outputs a string of the form:
// 123.234.34.56
diff --git a/chromium/third_party/openscreen/src/platform/base/ip_address_unittest.cc b/chromium/third_party/openscreen/src/platform/base/ip_address_unittest.cc
index 1bbbe8ec0c7..03a3c7537e5 100644
--- a/chromium/third_party/openscreen/src/platform/base/ip_address_unittest.cc
+++ b/chromium/third_party/openscreen/src/platform/base/ip_address_unittest.cc
@@ -103,14 +103,16 @@ TEST(IPAddressTest, V4ParseFailures) {
TEST(IPAddressTest, V6Constructors) {
uint8_t bytes[16] = {};
- IPAddress address1(std::array<uint8_t, 16>{
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}});
+ IPAddress address1(std::array<uint16_t, 8>{
+ {0x0102, 0x0304, 0x0506, 0x0708, 0x090a, 0x0b0c, 0x0d0e, 0x0f10}});
address1.CopyToV6(bytes);
EXPECT_THAT(bytes, ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16}));
const uint8_t x[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
- IPAddress address2(x);
+ const uint16_t hextets[] = {0x0102, 0x0304, 0x0506, 0x0708,
+ 0x090a, 0x0b0c, 0x0d0e, 0x0f10};
+ IPAddress address2(hextets);
address2.CopyToV6(bytes);
EXPECT_THAT(bytes, ElementsAreArray(x));
@@ -118,7 +120,8 @@ TEST(IPAddressTest, V6Constructors) {
address3.CopyToV6(bytes);
EXPECT_THAT(bytes, ElementsAreArray(x));
- IPAddress address4(16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1);
+ IPAddress address4(0x100f, 0x0e0d, 0x0c0b, 0x0a09, 0x0807, 0x0605, 0x0403,
+ 0x0201);
address4.CopyToV6(bytes);
EXPECT_THAT(bytes, ElementsAreArray({16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6,
5, 4, 3, 2, 1}));
@@ -135,11 +138,11 @@ TEST(IPAddressTest, V6ComparisonAndBoolean) {
EXPECT_FALSE(address1);
uint8_t x[] = {16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
- IPAddress address2(x);
+ IPAddress address2(IPAddress::Version::kV6, x);
EXPECT_NE(address1, address2);
EXPECT_TRUE(address2);
- IPAddress address3(x);
+ IPAddress address3(IPAddress::Version::kV6, x);
EXPECT_EQ(address2, address3);
EXPECT_TRUE(address3);
@@ -225,6 +228,12 @@ TEST(IPAddressTest, V6ParseFailures) {
<< "too few values should fail to parse";
EXPECT_FALSE(IPAddress::Parse("a:b:c:d:e:f:0:1:2:3:4:5:6:7:8:9:a"))
<< "too many values should fail to parse";
+ EXPECT_FALSE(IPAddress::Parse("1:2:3:4:5:6:7::8"))
+ << "too many values around double-colon should fail to parse";
+ EXPECT_FALSE(IPAddress::Parse("1:2:3:4:5:6:7:8::"))
+ << "too many values before double-colon should fail to parse";
+ EXPECT_FALSE(IPAddress::Parse("::1:2:3:4:5:6:7:8"))
+ << "too many values after double-colon should fail to parse";
EXPECT_FALSE(IPAddress::Parse("abcd1::dbca"))
<< "value > 0xffff should fail to parse";
EXPECT_FALSE(IPAddress::Parse("::abcd::dbca"))
@@ -248,4 +257,64 @@ TEST(IPAddressTest, V6ParseThreeDigitValue) {
0x01, 0x23}));
}
+TEST(IPAddressTest, IPEndpointBoolOperator) {
+ IPEndpoint endpoint;
+ if (endpoint) {
+ FAIL();
+ }
+
+ endpoint = IPEndpoint{{192, 168, 0, 1}, 80};
+ if (!endpoint) {
+ FAIL();
+ }
+
+ endpoint = IPEndpoint{{192, 168, 0, 1}, 0};
+ if (!endpoint) {
+ FAIL();
+ }
+
+ endpoint = IPEndpoint{{}, 80};
+ if (!endpoint) {
+ FAIL();
+ }
+}
+
+TEST(IPAddressTest, IPEndpointParse) {
+ IPEndpoint expected{IPAddress(std::array<uint8_t, 4>{{1, 2, 3, 4}}), 5678};
+ ErrorOr<IPEndpoint> result = IPEndpoint::Parse("1.2.3.4:5678");
+ ASSERT_TRUE(result.is_value()) << result.error();
+ EXPECT_EQ(expected, result.value());
+
+ expected = IPEndpoint{
+ IPAddress(std::array<uint16_t, 8>{{0xabcd, 0, 0, 0, 0, 0, 0, 1}}), 99};
+ result = IPEndpoint::Parse("[abcd::1]:99");
+ ASSERT_TRUE(result.is_value()) << result.error();
+ EXPECT_EQ(expected, result.value());
+
+ expected = IPEndpoint{
+ IPAddress(std::array<uint16_t, 8>{{0, 0, 0, 0, 0, 0, 0, 0}}), 5791};
+ result = IPEndpoint::Parse("[::]:5791");
+ ASSERT_TRUE(result.is_value()) << result.error();
+ EXPECT_EQ(expected, result.value());
+
+ EXPECT_FALSE(IPEndpoint::Parse("")); // Empty string.
+ EXPECT_FALSE(IPEndpoint::Parse("beef")); // Random word.
+ EXPECT_FALSE(IPEndpoint::Parse("localhost:99")); // We don't do DNS.
+ EXPECT_FALSE(IPEndpoint::Parse(":80")); // Missing address.
+ EXPECT_FALSE(IPEndpoint::Parse("[]:22")); // Missing address.
+ EXPECT_FALSE(IPEndpoint::Parse("1.2.3.4")); // Missing port after IPv4.
+ EXPECT_FALSE(IPEndpoint::Parse("[abcd::1]")); // Missing port after IPv6.
+ EXPECT_FALSE(IPEndpoint::Parse("abcd::1:8080")); // Missing square brackets.
+
+ // No extra whitespace is allowed.
+ EXPECT_FALSE(IPEndpoint::Parse(" 1.2.3.4:5678"));
+ EXPECT_FALSE(IPEndpoint::Parse("1.2.3.4 :5678"));
+ EXPECT_FALSE(IPEndpoint::Parse("1.2.3.4: 5678"));
+ EXPECT_FALSE(IPEndpoint::Parse("1.2.3.4:5678 "));
+ EXPECT_FALSE(IPEndpoint::Parse(" [abcd::1]:99"));
+ EXPECT_FALSE(IPEndpoint::Parse("[abcd::1] :99"));
+ EXPECT_FALSE(IPEndpoint::Parse("[abcd::1]: 99"));
+ EXPECT_FALSE(IPEndpoint::Parse("[abcd::1]:99 "));
+}
+
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/base/location.cc b/chromium/third_party/openscreen/src/platform/base/location.cc
index f9ec010815a..ad5a3abd178 100644
--- a/chromium/third_party/openscreen/src/platform/base/location.cc
+++ b/chromium/third_party/openscreen/src/platform/base/location.cc
@@ -4,7 +4,8 @@
#include "platform/base/location.h"
-#include "absl/strings/str_cat.h"
+#include <sstream>
+
#include "platform/base/macros.h"
namespace openscreen {
@@ -24,7 +25,9 @@ std::string Location::ToString() const {
return "pc:NULL";
}
- return absl::StrCat("pc:0x", absl::Hex(program_counter_));
+ std::ostringstream oss;
+ oss << "pc:" << program_counter_;
+ return oss.str();
}
#if defined(__GNUC__)
diff --git a/chromium/third_party/openscreen/src/platform/base/socket_state.h b/chromium/third_party/openscreen/src/platform/base/socket_state.h
index bd82dd292d7..7672c4f5086 100644
--- a/chromium/third_party/openscreen/src/platform/base/socket_state.h
+++ b/chromium/third_party/openscreen/src/platform/base/socket_state.h
@@ -10,7 +10,6 @@
#include <string>
namespace openscreen {
-namespace platform {
// SocketState should be used by TCP and TLS sockets for indicating
// current state. NOTE: socket state transitions should only happen in
@@ -30,7 +29,6 @@ enum class SocketState {
kClosed
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_BASE_SOCKET_STATE_H_
diff --git a/chromium/third_party/openscreen/src/platform/base/tls_connect_options.h b/chromium/third_party/openscreen/src/platform/base/tls_connect_options.h
index 12a0cd269b0..dcaebe0ae1d 100644
--- a/chromium/third_party/openscreen/src/platform/base/tls_connect_options.h
+++ b/chromium/third_party/openscreen/src/platform/base/tls_connect_options.h
@@ -8,7 +8,6 @@
#include "platform/base/macros.h"
namespace openscreen {
-namespace platform {
struct TlsConnectOptions {
// This option allows TLS connections to devices without
@@ -17,7 +16,6 @@ struct TlsConnectOptions {
bool unsafely_skip_certificate_validation;
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_BASE_TLS_CONNECT_OPTIONS_H_
diff --git a/chromium/third_party/openscreen/src/platform/base/tls_credentials.cc b/chromium/third_party/openscreen/src/platform/base/tls_credentials.cc
index 9493f597f86..7f7af41e37e 100644
--- a/chromium/third_party/openscreen/src/platform/base/tls_credentials.cc
+++ b/chromium/third_party/openscreen/src/platform/base/tls_credentials.cc
@@ -5,7 +5,6 @@
#include "platform/base/tls_credentials.h"
namespace openscreen {
-namespace platform {
TlsCredentials::TlsCredentials() = default;
@@ -18,5 +17,4 @@ TlsCredentials::TlsCredentials(std::vector<uint8_t> der_rsa_private_key,
TlsCredentials::~TlsCredentials() = default;
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/base/tls_credentials.h b/chromium/third_party/openscreen/src/platform/base/tls_credentials.h
index 8c5162dc10f..285e27057cc 100644
--- a/chromium/third_party/openscreen/src/platform/base/tls_credentials.h
+++ b/chromium/third_party/openscreen/src/platform/base/tls_credentials.h
@@ -10,7 +10,6 @@
#include <vector>
namespace openscreen {
-namespace platform {
struct TlsCredentials {
TlsCredentials();
@@ -29,7 +28,6 @@ struct TlsCredentials {
std::vector<uint8_t> der_x509_cert;
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_BASE_TLS_CREDENTIALS_H_
diff --git a/chromium/third_party/openscreen/src/platform/base/tls_listen_options.h b/chromium/third_party/openscreen/src/platform/base/tls_listen_options.h
index 32cab160ba2..da47b661b58 100644
--- a/chromium/third_party/openscreen/src/platform/base/tls_listen_options.h
+++ b/chromium/third_party/openscreen/src/platform/base/tls_listen_options.h
@@ -10,13 +10,11 @@
#include "platform/base/macros.h"
namespace openscreen {
-namespace platform {
struct TlsListenOptions {
uint32_t backlog_size;
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_BASE_TLS_LISTEN_OPTIONS_H_
diff --git a/chromium/third_party/openscreen/src/platform/base/trace_logging_activation.cc b/chromium/third_party/openscreen/src/platform/base/trace_logging_activation.cc
index dd7ac6c3443..16a893a0dd2 100644
--- a/chromium/third_party/openscreen/src/platform/base/trace_logging_activation.cc
+++ b/chromium/third_party/openscreen/src/platform/base/trace_logging_activation.cc
@@ -4,29 +4,69 @@
#include "platform/base/trace_logging_activation.h"
+#include <atomic>
#include <cassert>
+#include <thread>
namespace openscreen {
-namespace platform {
namespace {
-TraceLoggingPlatform* g_current_destination = nullptr;
-} // namespace
-TraceLoggingPlatform* GetTracingDestination() {
- return g_current_destination;
+// If tracing is active, this is a valid pointer to an object that implements
+// the TraceLoggingPlatform interface. If tracing is not active, this is
+// nullptr.
+std::atomic<TraceLoggingPlatform*> g_current_destination{};
+
+// The count of threads currently calling into the current TraceLoggingPlatform.
+std::atomic<int> g_use_count{};
+
+inline TraceLoggingPlatform* PinCurrentDestination() {
+ // NOTE: It's important to increment the global use count *before* loading the
+ // pointer, to ensure the referent is pinned-down (i.e., any thread executing
+ // StopTracing() stays blocked) until CurrentTracingDestination's destructor
+ // calls UnpinCurrentDestination().
+ g_use_count.fetch_add(1);
+ return g_current_destination.load(std::memory_order_relaxed);
}
+inline void UnpinCurrentDestination() {
+ g_use_count.fetch_sub(1);
+}
+
+} // namespace
+
void StartTracing(TraceLoggingPlatform* destination) {
- // TODO(crbug.com/openscreen/85): Need to revisit this to ensure thread-safety
- // around the sequencing of starting and stopping tracing.
- assert(!g_current_destination);
- g_current_destination = destination;
+ assert(destination);
+ auto* const old_destination = g_current_destination.exchange(destination);
+ (void)old_destination; // Prevent "unused variable" compiler warnings.
+ assert(old_destination == nullptr || old_destination == destination);
}
void StopTracing() {
- g_current_destination = nullptr;
+ auto* const old_destination = g_current_destination.exchange(nullptr);
+ if (!old_destination) {
+ return; // Already stopped.
+ }
+
+ // Block the current thread until the global use count goes to zero. At that
+ // point, there can no longer be any dangling references. Theoretically, this
+ // loop may never terminate; but in practice, that should never happen. If it
+ // did happen, that would mean one or more CPU cores are continuously spending
+ // most of their time executing the TraceLoggingPlatform methods, yet those
+ // methods are supposed to be super-cheap and take near-zero time to execute!
+ int iters = 0;
+ while (g_use_count.load(std::memory_order_relaxed) != 0) {
+ assert(iters < 1024);
+ std::this_thread::yield();
+ ++iters;
+ }
+}
+
+CurrentTracingDestination::CurrentTracingDestination()
+ : destination_(PinCurrentDestination()) {}
+
+CurrentTracingDestination::~CurrentTracingDestination() {
+ UnpinCurrentDestination();
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/base/trace_logging_activation.h b/chromium/third_party/openscreen/src/platform/base/trace_logging_activation.h
index 05be734387d..d53d9498fdc 100644
--- a/chromium/third_party/openscreen/src/platform/base/trace_logging_activation.h
+++ b/chromium/third_party/openscreen/src/platform/base/trace_logging_activation.h
@@ -6,20 +6,49 @@
#define PLATFORM_BASE_TRACE_LOGGING_ACTIVATION_H_
namespace openscreen {
-namespace platform {
class TraceLoggingPlatform;
// Start or Stop trace logging. It is illegal to call StartTracing() a second
// time without having called StopTracing() to stop the prior tracing session.
+//
+// Note that StopTracing() may block until all threads have returned from any
+// in-progress calls into the TraceLoggingPlatform's methods.
void StartTracing(TraceLoggingPlatform* destination);
void StopTracing();
-// If tracing is active, returns the current destination. Otherwise, returns
-// nullptr.
-TraceLoggingPlatform* GetTracingDestination();
+// An immutable, non-copyable and non-movable smart pointer that references the
+// current trace logging destination. If tracing was active when this class was
+// intantiated, the pointer is valid for the life of the instance, and can be
+// used to directly invoke the methods of the TraceLoggingPlatform API. If
+// tracing was not active when this class was intantiated, the pointer is null
+// for the life of the instance and must not be dereferenced.
+//
+// An instance should be short-lived, as a platform's call to StopTracing() will
+// be blocked until there are no instances remaining.
+//
+// NOTE: This is generally not used directly, but instead via the
+// util/trace_logging macros.
+class CurrentTracingDestination {
+ public:
+ CurrentTracingDestination();
+ ~CurrentTracingDestination();
+
+ explicit operator bool() const noexcept { return !!destination_; }
+ TraceLoggingPlatform* operator->() const noexcept { return destination_; }
+
+ private:
+ CurrentTracingDestination(const CurrentTracingDestination&) = delete;
+ CurrentTracingDestination(CurrentTracingDestination&&) = delete;
+ CurrentTracingDestination& operator=(const CurrentTracingDestination&) =
+ delete;
+ CurrentTracingDestination& operator=(CurrentTracingDestination&&) = delete;
+
+ // The destination at the time this class was constructed, and is valid for
+ // the lifetime of this class. This is nullptr if tracing was inactive.
+ TraceLoggingPlatform* const destination_;
+};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_BASE_TRACE_LOGGING_ACTIVATION_H_
diff --git a/chromium/third_party/openscreen/src/platform/base/trace_logging_types.h b/chromium/third_party/openscreen/src/platform/base/trace_logging_types.h
index ce33edb42da..0ee9dfb18a8 100644
--- a/chromium/third_party/openscreen/src/platform/base/trace_logging_types.h
+++ b/chromium/third_party/openscreen/src/platform/base/trace_logging_types.h
@@ -5,10 +5,11 @@
#ifndef PLATFORM_BASE_TRACE_LOGGING_TYPES_H_
#define PLATFORM_BASE_TRACE_LOGGING_TYPES_H_
-#include "absl/types/optional.h"
+#include <stdint.h>
+
+#include <limits>
namespace openscreen {
-namespace platform {
// Define TraceId type here since other TraceLogging files import it.
using TraceId = uint64_t;
@@ -51,18 +52,17 @@ inline bool operator!=(const TraceIdHierarchy& lhs,
// BitFlags to represent the supported tracing categories.
// NOTE: These are currently placeholder values and later changes should feel
// free to edit them.
-// TODO(rwkeane): Rename SSL to either Ssl or kSsl
struct TraceCategory {
enum Value : uint64_t {
- Any = std::numeric_limits<uint64_t>::max(),
- mDNS = 0x01 << 0,
- Quic = 0x01 << 1,
- SSL = 0x01 << 2,
- Presentation = 0x01 << 3,
+ kAny = std::numeric_limits<uint64_t>::max(),
+ kMdns = 0x01 << 0,
+ kQuic = 0x01 << 1,
+ kSsl = 0x01 << 2,
+ kPresentation = 0x01 << 3,
+ kStandaloneReceiver = 0x01 << 4
};
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_BASE_TRACE_LOGGING_TYPES_H_
diff --git a/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.cc b/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.cc
index 3be1c783268..96f1fa3dde0 100644
--- a/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.cc
+++ b/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.cc
@@ -7,15 +7,32 @@
namespace openscreen {
std::ostream& operator<<(std::ostream& os,
- const platform::TrivialClockTraits::duration& d) {
- constexpr char kUnits[] = "\u03BCs"; // Greek Mu + "s"
+ const TrivialClockTraits::duration& d) {
+ constexpr char kUnits[] = " \u03BCs"; // Greek Mu + "s"
return os << d.count() << kUnits;
}
std::ostream& operator<<(std::ostream& os,
- const platform::TrivialClockTraits::time_point& tp) {
- constexpr char kUnits[] = "\u03BCs-ticks"; // Greek Mu + "s-ticks"
+ const TrivialClockTraits::time_point& tp) {
+ constexpr char kUnits[] = " \u03BCs-ticks"; // Greek Mu + "s-ticks"
return os << tp.time_since_epoch().count() << kUnits;
}
+std::ostream& operator<<(std::ostream& out, const std::chrono::hours& hrs) {
+ return (out << hrs.count() << " hours");
+}
+
+std::ostream& operator<<(std::ostream& out, const std::chrono::minutes& mins) {
+ return (out << mins.count() << " minutes");
+}
+
+std::ostream& operator<<(std::ostream& out, const std::chrono::seconds& secs) {
+ return (out << secs.count() << " seconds");
+}
+
+std::ostream& operator<<(std::ostream& out,
+ const std::chrono::milliseconds& millis) {
+ return (out << millis.count() << " ms");
+}
+
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.h b/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.h
index b5392c43400..426a2b8a9ce 100644
--- a/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.h
+++ b/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.h
@@ -9,7 +9,6 @@
#include <ostream>
namespace openscreen {
-namespace platform {
// The Open Screen monotonic clock traits description, providing all the C++14
// requirements of a TrivialClock, for use with STL <chrono>.
@@ -44,16 +43,22 @@ class TrivialClockTraits {
// &Clock::now versus something else for testing).
using ClockNowFunctionPtr = TrivialClockTraits::time_point (*)();
-} // namespace platform
-
// Logging convenience for durations. Outputs a string of the form "123µs".
std::ostream& operator<<(std::ostream& os,
- const platform::TrivialClockTraits::duration& d);
+ const TrivialClockTraits::duration& d);
// Logging convenience for time points. Outputs a string of the form
// "123µs-ticks".
std::ostream& operator<<(std::ostream& os,
- const platform::TrivialClockTraits::time_point& tp);
+ const TrivialClockTraits::time_point& tp);
+
+// Logging (and gtest pretty-printing) for several commonly-used chrono types.
+std::ostream& operator<<(std::ostream& out, const std::chrono::hours&);
+std::ostream& operator<<(std::ostream& out, const std::chrono::minutes&);
+std::ostream& operator<<(std::ostream& out, const std::chrono::seconds&);
+std::ostream& operator<<(std::ostream& out, const std::chrono::milliseconds&);
+// Note: The ostream output operator for std::chrono::microseconds is handled by
+// the one for TrivialClockTraits::duration above since they are the same type.
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/base/udp_packet.cc b/chromium/third_party/openscreen/src/platform/base/udp_packet.cc
index d8ad75e421f..6cbb0da2f83 100644
--- a/chromium/third_party/openscreen/src/platform/base/udp_packet.cc
+++ b/chromium/third_party/openscreen/src/platform/base/udp_packet.cc
@@ -7,14 +7,9 @@
#include <cassert>
namespace openscreen {
-namespace platform {
UdpPacket::UdpPacket() : std::vector<uint8_t>() {}
-UdpPacket::UdpPacket(size_type size) : std::vector<uint8_t>(size) {
- assert(size <= kUdpMaxPacketSize);
-}
-
UdpPacket::UdpPacket(size_type size, uint8_t fill_value)
: std::vector<uint8_t>(size, fill_value) {
assert(size <= kUdpMaxPacketSize);
@@ -31,5 +26,4 @@ UdpPacket::~UdpPacket() = default;
UdpPacket& UdpPacket::operator=(UdpPacket&& other) = default;
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/base/udp_packet.h b/chromium/third_party/openscreen/src/platform/base/udp_packet.h
index 479cee594bf..a8fcce045a1 100644
--- a/chromium/third_party/openscreen/src/platform/base/udp_packet.h
+++ b/chromium/third_party/openscreen/src/platform/base/udp_packet.h
@@ -7,14 +7,12 @@
#include <stdint.h>
-#include <algorithm>
#include <utility>
#include <vector>
#include "platform/base/ip_address.h"
namespace openscreen {
-namespace platform {
class UdpSocket;
@@ -25,13 +23,9 @@ class UdpPacket : public std::vector<uint8_t> {
public:
// C++14 vector constructors, sans Allocator foo, and no copy ctor.
UdpPacket();
- explicit UdpPacket(size_type size);
- explicit UdpPacket(size_type size, uint8_t fill_value);
+ explicit UdpPacket(size_type size, uint8_t fill_value = {});
template <typename InputIt>
- UdpPacket(InputIt first, InputIt last)
- : UdpPacket(std::distance(first, last)) {
- std::copy(first, last, begin());
- }
+ UdpPacket(InputIt first, InputIt last) : std::vector<uint8_t>(first, last) {}
UdpPacket(UdpPacket&& other);
UdpPacket(std::initializer_list<uint8_t> init);
@@ -60,7 +54,6 @@ class UdpPacket : public std::vector<uint8_t> {
OSP_DISALLOW_COPY_AND_ASSIGN(UdpPacket);
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_BASE_UDP_PACKET_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/logging.h b/chromium/third_party/openscreen/src/platform/impl/logging.h
index 033691f82b8..85d6abcd319 100644
--- a/chromium/third_party/openscreen/src/platform/impl/logging.h
+++ b/chromium/third_party/openscreen/src/platform/impl/logging.h
@@ -8,7 +8,6 @@
#include "util/logging.h"
namespace openscreen {
-namespace platform {
// Direct all logging output to a named FIFO having the given |filename|. If the
// file does not exist, an attempt is made to auto-create it. If unsuccessful,
@@ -20,7 +19,6 @@ void SetLogFifoOrDie(const char* filename);
// default.
void SetLogLevel(LogLevel level);
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_LOGGING_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/logging_posix.cc b/chromium/third_party/openscreen/src/platform/impl/logging_posix.cc
index ab921168468..412876ce9bc 100644
--- a/chromium/third_party/openscreen/src/platform/impl/logging_posix.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/logging_posix.cc
@@ -16,7 +16,6 @@
#include "util/trace_logging.h"
namespace openscreen {
-namespace platform {
namespace {
int g_log_fd = STDERR_FILENO;
@@ -80,22 +79,22 @@ void SetLogLevel(LogLevel level) {
g_log_level = level;
}
-bool IsLoggingOn(LogLevel level, absl::string_view file) {
+bool IsLoggingOn(LogLevel level, const char* file) {
// Possible future enhancement: Use glob patterns passed on the command-line
// to use a different logging level for certain files, like in Chromium.
return level >= g_log_level;
}
void LogWithLevel(LogLevel level,
- absl::string_view file,
+ const char* file,
int line,
- absl::string_view msg) {
+ std::stringstream message) {
if (level < g_log_level)
return;
std::stringstream ss;
ss << "[" << level << ":" << file << "(" << line << "):T" << std::hex
- << TRACE_CURRENT_ID << "] " << msg << '\n';
+ << TRACE_CURRENT_ID << "] " << message.rdbuf() << '\n';
const auto ss_str = ss.str();
const auto bytes_written = write(g_log_fd, ss_str.c_str(), ss_str.size());
OSP_DCHECK(bytes_written);
@@ -109,5 +108,4 @@ void Break() {
#endif
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/network_interface.cc b/chromium/third_party/openscreen/src/platform/impl/network_interface.cc
new file mode 100644
index 00000000000..17e240f2e45
--- /dev/null
+++ b/chromium/third_party/openscreen/src/platform/impl/network_interface.cc
@@ -0,0 +1,44 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "platform/impl/network_interface.h"
+
+namespace openscreen {
+
+std::vector<InterfaceInfo> GetNetworkInterfaces() {
+ std::vector<InterfaceInfo> interfaces = GetAllInterfaces();
+
+ const auto new_end = std::remove_if(
+ interfaces.begin(), interfaces.end(), [](const InterfaceInfo& info) {
+ return info.type != InterfaceInfo::Type::kEthernet &&
+ info.type != InterfaceInfo::Type::kWifi &&
+ info.type != InterfaceInfo::Type::kOther;
+ });
+ interfaces.erase(new_end, interfaces.end());
+
+ return interfaces;
+}
+
+// Returns an InterfaceInfo associated with the system's loopback interface.
+absl::optional<InterfaceInfo> GetLoopbackInterfaceForTesting() {
+ const std::vector<InterfaceInfo> interfaces = GetAllInterfaces();
+ auto it = std::find_if(
+ interfaces.begin(), interfaces.end(), [](const InterfaceInfo& info) {
+ return info.type == InterfaceInfo::Type::kLoopback &&
+ std::find_if(
+ info.addresses.begin(), info.addresses.end(),
+ [](const IPSubnet& subnet) {
+ return subnet.address == IPAddress::kV4LoopbackAddress ||
+ subnet.address == IPAddress::kV6LoopbackAddress;
+ }) != info.addresses.end();
+ });
+
+ if (it == interfaces.end()) {
+ return absl::nullopt;
+ } else {
+ return *it;
+ }
+}
+
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/network_interface.h b/chromium/third_party/openscreen/src/platform/impl/network_interface.h
new file mode 100644
index 00000000000..8682cffb21a
--- /dev/null
+++ b/chromium/third_party/openscreen/src/platform/impl/network_interface.h
@@ -0,0 +1,26 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef PLATFORM_IMPL_NETWORK_INTERFACE_H_
+#define PLATFORM_IMPL_NETWORK_INTERFACE_H_
+
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "platform/base/interface_info.h"
+
+namespace openscreen {
+
+// The below functions are responsible for returning the network interfaces
+// provided of the current machine. GetAllInterfaces() returns all interfaces,
+// real or virtual. GetLoopbackInterfaceForTesting() returns one such interface
+// which is associated with the machine's loopback interface, while
+// GetNetworkInterfaces() returns all non-loopback interfaces.
+std::vector<InterfaceInfo> GetAllInterfaces();
+absl::optional<InterfaceInfo> GetLoopbackInterfaceForTesting();
+std::vector<InterfaceInfo> GetNetworkInterfaces();
+
+} // namespace openscreen
+
+#endif // PLATFORM_IMPL_NETWORK_INTERFACE_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/network_interface_linux.cc b/chromium/third_party/openscreen/src/platform/impl/network_interface_linux.cc
index 5f432265662..d58e5d9ce66 100644
--- a/chromium/third_party/openscreen/src/platform/impl/network_interface_linux.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/network_interface_linux.cc
@@ -25,11 +25,11 @@
#include "absl/types/optional.h"
#include "platform/api/network_interface.h"
#include "platform/base/ip_address.h"
+#include "platform/impl/network_interface.h"
#include "platform/impl/scoped_pipe.h"
#include "util/logging.h"
namespace openscreen {
-namespace platform {
namespace {
constexpr int kNetlinkRecvmsgBufSize = 8192;
@@ -74,8 +74,9 @@ InterfaceInfo::Type GetInterfaceType(const std::string& ifname) {
wr.ifr_name[IFNAMSIZ - 1] = 0;
strncpy(ifr.ifr_name, ifname.c_str(), IFNAMSIZ - 1);
ifr.ifr_data = &ecmd;
- if (ioctl(s.get(), SIOCETHTOOL, &ifr) != -1)
+ if (ioctl(s.get(), SIOCETHTOOL, &ifr) != -1) {
return InterfaceInfo::Type::kEthernet;
+ }
return InterfaceInfo::Type::kOther;
}
@@ -86,6 +87,7 @@ InterfaceInfo::Type GetInterfaceType(const std::string& ifname) {
// pointed to by |rta|.
void GetInterfaceAttributes(struct rtattr* rta,
unsigned int attrlen,
+ bool is_loopback,
InterfaceInfo* info) {
for (; RTA_OK(rta, attrlen); rta = RTA_NEXT(rta, attrlen)) {
if (rta->rta_type == IFLA_IFNAME) {
@@ -93,12 +95,16 @@ void GetInterfaceAttributes(struct rtattr* rta,
GetInterfaceName(reinterpret_cast<const char*>(RTA_DATA(rta)));
} else if (rta->rta_type == IFLA_ADDRESS) {
OSP_CHECK_EQ(sizeof(info->hardware_address), RTA_PAYLOAD(rta));
- std::memcpy(info->hardware_address, RTA_DATA(rta),
+ std::memcpy(info->hardware_address.data(), RTA_DATA(rta),
sizeof(info->hardware_address));
}
}
- info->type = GetInterfaceType(info->name);
+ if (is_loopback) {
+ info->type = InterfaceInfo::Type::kLoopback;
+ } else {
+ info->type = GetInterfaceType(info->name);
+ }
}
// Reads the IPv4 or IPv6 address that comes from an RTM_NEWADDR message and
@@ -220,16 +226,17 @@ std::vector<InterfaceInfo> GetLinkInfo() {
struct ifinfomsg* interface_info =
static_cast<struct ifinfomsg*>(NLMSG_DATA(netlink_header));
- // Only process non-loopback interfaces which are active (up).
- if ((interface_info->ifi_flags & IFF_LOOPBACK) ||
- ((interface_info->ifi_flags & IFF_UP) == 0)) {
+ // Only process interfaces which are active (up).
+ if (!(interface_info->ifi_flags & IFF_UP)) {
continue;
}
+
info_list.emplace_back();
InterfaceInfo& info = info_list.back();
info.index = interface_info->ifi_index;
GetInterfaceAttributes(IFLA_RTA(interface_info),
- IFLA_PAYLOAD(netlink_header), &info);
+ IFLA_PAYLOAD(netlink_header),
+ interface_info->ifi_flags & IFF_LOOPBACK, &info);
}
}
}
@@ -351,11 +358,10 @@ void PopulateSubnetsOrClearList(std::vector<InterfaceInfo>* info_list) {
} // namespace
-std::vector<InterfaceInfo> GetNetworkInterfaces() {
+std::vector<InterfaceInfo> GetAllInterfaces() {
std::vector<InterfaceInfo> interfaces = GetLinkInfo();
PopulateSubnetsOrClearList(&interfaces);
return interfaces;
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/network_interface_mac.cc b/chromium/third_party/openscreen/src/platform/impl/network_interface_mac.cc
index 138d4918b57..da219bd7c32 100644
--- a/chromium/third_party/openscreen/src/platform/impl/network_interface_mac.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/network_interface_mac.cc
@@ -19,11 +19,11 @@
#include "platform/api/network_interface.h"
#include "platform/base/ip_address.h"
+#include "platform/impl/network_interface.h"
#include "platform/impl/scoped_pipe.h"
#include "util/logging.h"
namespace openscreen {
-namespace platform {
namespace {
@@ -68,13 +68,12 @@ std::vector<InterfaceInfo> ProcessInterfacesList(ifaddrs* interfaces) {
// Socket used for querying interface media types.
const ScopedFd ioctl_socket(socket(AF_INET6, SOCK_DGRAM, 0));
- // Walk the |interfaces| linked list, creating the hierarchial structure.
+ // Walk the |interfaces| linked list, creating the hierarchical structure.
std::vector<InterfaceInfo> results;
for (ifaddrs* cur = interfaces; cur; cur = cur->ifa_next) {
- // Skip: 1) loopback interfaces, 2) interfaces that are down, 3) interfaces
- // with no address configured.
- if ((IFF_LOOPBACK & cur->ifa_flags) || !(IFF_RUNNING & cur->ifa_flags) ||
- !cur->ifa_addr) {
+ // Skip: 1) interfaces that are down, 2) interfaces with no address
+ // configured.
+ if (!(IFF_RUNNING & cur->ifa_flags) || !cur->ifa_addr) {
continue;
}
@@ -107,6 +106,9 @@ std::vector<InterfaceInfo> ProcessInterfacesList(ifaddrs* interfaces) {
if (ifmr.ifm_current & IFM_ETHER) {
type = InterfaceInfo::Type::kEthernet;
}
+ if (cur->ifa_flags & IFF_LOOPBACK) {
+ type = InterfaceInfo::Type::kLoopback;
+ }
// Start with an unknown hardware ethernet address, which should be
// updated as the linked list is walked.
@@ -163,7 +165,7 @@ std::vector<InterfaceInfo> ProcessInterfacesList(ifaddrs* interfaces) {
} // namespace
-std::vector<InterfaceInfo> GetNetworkInterfaces() {
+std::vector<InterfaceInfo> GetAllInterfaces() {
std::vector<InterfaceInfo> results;
ifaddrs* interfaces;
if (getifaddrs(&interfaces) == 0) {
@@ -173,5 +175,4 @@ std::vector<InterfaceInfo> GetNetworkInterfaces() {
return results;
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/platform_client_posix.cc b/chromium/third_party/openscreen/src/platform/impl/platform_client_posix.cc
index f31f19aad79..d1415258bd7 100644
--- a/chromium/third_party/openscreen/src/platform/impl/platform_client_posix.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/platform_client_posix.cc
@@ -4,48 +4,23 @@
#include "platform/impl/platform_client_posix.h"
-#include <mutex>
+#include <functional>
+#include <vector>
#include "platform/impl/udp_socket_reader_posix.h"
namespace openscreen {
-namespace platform {
// static
PlatformClientPosix* PlatformClientPosix::instance_ = nullptr;
-PlatformClientPosix::PlatformClientPosix(
- Clock::duration networking_operation_timeout,
- Clock::duration networking_loop_interval)
- : networking_loop_(networking_operations(),
- networking_operation_timeout,
- networking_loop_interval),
- owned_task_runner_(Clock::now),
- networking_loop_thread_(&OperationLoop::RunUntilStopped,
- &networking_loop_),
- task_runner_thread_(&TaskRunnerImpl::RunUntilStopped,
- &owned_task_runner_.value()) {}
-
-PlatformClientPosix::PlatformClientPosix(
- Clock::duration networking_operation_timeout,
- Clock::duration networking_loop_interval,
- std::unique_ptr<TaskRunner> task_runner)
- : networking_loop_(networking_operations(),
- networking_operation_timeout,
- networking_loop_interval),
- caller_provided_task_runner_(std::move(task_runner)),
- networking_loop_thread_(&OperationLoop::RunUntilStopped,
- &networking_loop_) {}
-
-PlatformClientPosix::~PlatformClientPosix() {
- networking_loop_.RequestStopSoon();
- networking_loop_thread_.join();
- if (owned_task_runner_.has_value()) {
- owned_task_runner_.value().RequestStopSoon();
- }
- if (task_runner_thread_.joinable()) {
- task_runner_thread_.join();
- }
+// static
+void PlatformClientPosix::Create(Clock::duration networking_operation_timeout,
+ Clock::duration networking_loop_interval,
+ std::unique_ptr<TaskRunnerImpl> task_runner) {
+ SetInstance(new PlatformClientPosix(networking_operation_timeout,
+ networking_loop_interval,
+ std::move(task_runner)));
}
// static
@@ -56,21 +31,19 @@ void PlatformClientPosix::Create(Clock::duration networking_operation_timeout,
}
// static
-void PlatformClientPosix::SetInstance(PlatformClientPosix* instance) {
- OSP_DCHECK(!instance_);
- instance_ = instance;
-}
-
-// static
void PlatformClientPosix::ShutDown() {
OSP_DCHECK(instance_);
delete instance_;
instance_ = nullptr;
}
-TaskRunner* PlatformClientPosix::GetTaskRunner() {
- return owned_task_runner_.has_value() ? &owned_task_runner_.value()
- : caller_provided_task_runner_.get();
+TlsDataRouterPosix* PlatformClientPosix::tls_data_router() {
+ std::call_once(tls_data_router_initialization_, [this]() {
+ tls_data_router_ =
+ std::make_unique<TlsDataRouterPosix>(socket_handle_waiter());
+ tls_data_router_created_.store(true);
+ });
+ return tls_data_router_.get();
}
UdpSocketReaderPosix* PlatformClientPosix::udp_socket_reader() {
@@ -81,18 +54,56 @@ UdpSocketReaderPosix* PlatformClientPosix::udp_socket_reader() {
return udp_socket_reader_.get();
}
-TlsDataRouterPosix* PlatformClientPosix::tls_data_router() {
- std::call_once(tls_data_router_initialization_, [this]() {
- tls_data_router_ =
- std::make_unique<TlsDataRouterPosix>(socket_handle_waiter());
- tls_data_router_created_.store(true);
- });
- return tls_data_router_.get();
+TaskRunner* PlatformClientPosix::GetTaskRunner() {
+ return task_runner_.get();
}
+PlatformClientPosix::~PlatformClientPosix() {
+ OSP_DVLOG << "Shutting down the Task Runner...";
+ task_runner_->RequestStopSoon();
+ if (task_runner_thread_ && task_runner_thread_->joinable()) {
+ task_runner_thread_->join();
+ OSP_DVLOG << "\tTask Runner shutdown complete!";
+ }
+
+ OSP_DVLOG << "Shutting down network operations...";
+ networking_loop_.RequestStopSoon();
+ networking_loop_thread_.join();
+ OSP_DVLOG << "\tNetwork operation shutdown complete!";
+}
+
+// static
+void PlatformClientPosix::SetInstance(PlatformClientPosix* instance) {
+ OSP_DCHECK(!instance_);
+ instance_ = instance;
+}
+
+PlatformClientPosix::PlatformClientPosix(
+ Clock::duration networking_operation_timeout,
+ Clock::duration networking_loop_interval)
+ : networking_loop_(networking_operations(),
+ networking_operation_timeout,
+ networking_loop_interval),
+ task_runner_(new TaskRunnerImpl(Clock::now)),
+ networking_loop_thread_(&OperationLoop::RunUntilStopped,
+ &networking_loop_),
+ task_runner_thread_(
+ std::thread(&TaskRunnerImpl::RunUntilStopped, task_runner_.get())) {}
+
+PlatformClientPosix::PlatformClientPosix(
+ Clock::duration networking_operation_timeout,
+ Clock::duration networking_loop_interval,
+ std::unique_ptr<TaskRunnerImpl> task_runner)
+ : networking_loop_(networking_operations(),
+ networking_operation_timeout,
+ networking_loop_interval),
+ task_runner_(std::move(task_runner)),
+ networking_loop_thread_(&OperationLoop::RunUntilStopped,
+ &networking_loop_) {}
+
SocketHandleWaiterPosix* PlatformClientPosix::socket_handle_waiter() {
std::call_once(waiter_initialization_, [this]() {
- waiter_ = std::make_unique<SocketHandleWaiterPosix>();
+ waiter_ = std::make_unique<SocketHandleWaiterPosix>(&Clock::now);
waiter_created_.store(true);
});
return waiter_.get();
@@ -107,23 +118,11 @@ void PlatformClientPosix::PerformSocketHandleWaiterActions(
socket_handle_waiter()->ProcessHandles(timeout);
}
-void PlatformClientPosix::PerformTlsDataRouterActions(Clock::duration timeout) {
- if (!tls_data_router_created_.load()) {
- return;
- }
-
- tls_data_router()->PerformNetworkingOperations(timeout);
-}
-
std::vector<std::function<void(Clock::duration)>>
PlatformClientPosix::networking_operations() {
return {[this](Clock::duration timeout) {
- PerformSocketHandleWaiterActions(timeout);
- },
- [this](Clock::duration timeout) {
- PerformTlsDataRouterActions(timeout);
- }};
+ PerformSocketHandleWaiterActions(timeout);
+ }};
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/platform_client_posix.h b/chromium/third_party/openscreen/src/platform/impl/platform_client_posix.h
index 57b0cf75c02..6bda424b168 100644
--- a/chromium/third_party/openscreen/src/platform/impl/platform_client_posix.h
+++ b/chromium/third_party/openscreen/src/platform/impl/platform_client_posix.h
@@ -10,6 +10,7 @@
#include <mutex>
#include <thread>
+#include "absl/types/optional.h"
#include "platform/api/time.h"
#include "platform/base/macros.h"
#include "platform/impl/socket_handle_waiter_posix.h"
@@ -18,43 +19,57 @@
#include "util/operation_loop.h"
namespace openscreen {
-namespace platform {
class UdpSocketReaderPosix;
-// A PlatformClientPosix is an access point for all singletons in a standalone
-// application. The static SetInstance method is to be called before library use
-// begins, and the ShutDown() method should be called to deallocate the platform
-// library's global singletons (for example to save memory when libcast is not
-// in use).
+// Creates and provides access to singletons used by the default platform
+// implementation. An instance must be created before an application uses any
+// public modules in the Open Screen Library.
+//
+// ShutDown() should be called to destroy the PlatformClientPosix's singletons
+// and TaskRunner to save resources when library APIs are not in use.
+// ShutDown() calls TaskRunner::RunUntilStopped() to run any pending cleanup
+// tasks.
+//
+// Create and ShutDown must be called in the same sequence.
+//
+// FIXME: Remove Create and Shutdown and use the ctor/dtor directly.
class PlatformClientPosix {
public:
- // This method is expected to be called before the library is used.
- // The networking_loop_interval parameter here represents the minimum amount
- // of time that should pass between iterations of the loop used to handle
- // networking operations. Higher values will result in less time being spent
- // on these operations, but also potentially less performant networking
- // operations. The networking_operation_timeout parameter refers to how much
- // time may be spent on a single networking operation type.
- // NOTE: This method is NOT thread safe and should only be called from the
- // embedder thread.
+ // Initializes the platform implementation.
+ //
+ // |networking_loop_interval| sets the minimum amount of time that should pass
+ // between iterations of the loop used to handle networking operations. Higher
+ // values will result in less time being spent on these operations, but also
+ // potentially less performant networking operations.
+ //
+ // |networking_operation_timeout| sets how much time may be spent on a
+ // single networking operation type.
+ //
+ // |task_runner| is a client-provided TaskRunner implementation.
+ static void Create(Clock::duration networking_operation_timeout,
+ Clock::duration networking_loop_interval,
+ std::unique_ptr<TaskRunnerImpl> task_runner);
+
+ // Initializes the platform implementation and creates a new TaskRunner (which
+ // starts a new thread).
static void Create(Clock::duration networking_operation_timeout,
Clock::duration networking_loop_interval);
// Shuts down and deletes the PlatformClient instance currently stored as a
// singleton. This method is expected to be called before program exit. After
// calling this method, if the client wishes to continue using the platform
- // library, a new singleton must be created.
- // NOTE: This method is NOT thread safe and should only be called from the
- // embedder thread.
+ // library, Create() must be called again.
static void ShutDown();
static PlatformClientPosix* GetInstance() { return instance_; }
// This method is thread-safe.
+ // FIXME: Rename to GetTlsDataRouter()
TlsDataRouterPosix* tls_data_router();
// This method is thread-safe.
+ // FIXME: Rename to GetUdpSocketReader()
UdpSocketReaderPosix* udp_socket_reader();
// Returns the TaskRunner associated with this PlatformClient.
@@ -62,29 +77,19 @@ class PlatformClientPosix {
TaskRunner* GetTaskRunner();
protected:
- // The TaskRunner parameter here is a user-provided TaskRunner to be used
- // instead of the one normally created within PlatformClientPosix. Ownership
- // of the TaskRunner is transferred to this class.
- PlatformClientPosix(Clock::duration networking_operation_timeout,
- Clock::duration networking_loop_interval,
- std::unique_ptr<TaskRunner> task_runner);
-
- // This method is expected to be called in order to set the singleton instance
- // (typically from the Create() method). It should only be called from the
- // embedder thread. Client should be a new instance create via 'new' and
- // ownership of this instance will be transferred to this class.
- // NOTE: This method is NOT thread safe and should only be called from the
- // embedder thread.
- static void SetInstance(PlatformClientPosix* client);
-
// Called by ShutDown().
~PlatformClientPosix();
+ static void SetInstance(PlatformClientPosix* client);
+
private:
- // Called by Create().
PlatformClientPosix(Clock::duration networking_operation_timeout,
Clock::duration networking_loop_interval);
+ PlatformClientPosix(Clock::duration networking_operation_timeout,
+ Clock::duration networking_loop_interval,
+ std::unique_ptr<TaskRunnerImpl> task_runner);
+
// This method is thread-safe.
SocketHandleWaiterPosix* socket_handle_waiter();
@@ -97,8 +102,8 @@ class PlatformClientPosix {
// Instance objects with threads are created at object-creation time.
// NOTE: Delayed instantiation of networking_loop_ may be useful in future.
OperationLoop networking_loop_;
- absl::optional<TaskRunnerImpl> owned_task_runner_;
- std::unique_ptr<TaskRunner> caller_provided_task_runner_;
+
+ std::unique_ptr<TaskRunnerImpl> task_runner_;
// Track whether the associated instance variable has been created yet.
std::atomic_bool waiter_created_{false};
@@ -118,14 +123,13 @@ class PlatformClientPosix {
// Threads for running TaskRunner and OperationLoop instances.
// NOTE: These must be declared last to avoid nondterministic failures.
std::thread networking_loop_thread_;
- std::thread task_runner_thread_;
+ absl::optional<std::thread> task_runner_thread_;
static PlatformClientPosix* instance_;
OSP_DISALLOW_COPY_AND_ASSIGN(PlatformClientPosix);
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_PLATFORM_CLIENT_POSIX_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.cc b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.cc
new file mode 100644
index 00000000000..c99a2dd928b
--- /dev/null
+++ b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.cc
@@ -0,0 +1,56 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "platform/impl/scoped_wake_lock_linux.h"
+
+#include "platform/api/task_runner.h"
+#include "platform/impl/platform_client_posix.h"
+#include "util/logging.h"
+
+namespace openscreen {
+
+int ScopedWakeLockLinux::reference_count_ = 0;
+
+std::unique_ptr<ScopedWakeLock> ScopedWakeLock::Create() {
+ return std::make_unique<ScopedWakeLockLinux>();
+}
+
+namespace {
+
+TaskRunner* GetTaskRunner() {
+ auto* const platform_client = PlatformClientPosix::GetInstance();
+ OSP_DCHECK(platform_client);
+ auto* const task_runner = platform_client->GetTaskRunner();
+ OSP_DCHECK(task_runner);
+ return task_runner;
+}
+
+} // namespace
+
+ScopedWakeLockLinux::ScopedWakeLockLinux() : ScopedWakeLock() {
+ OSP_DCHECK(GetTaskRunner()->IsRunningOnTaskRunner());
+ if (reference_count_++ == 0) {
+ AcquireWakeLock();
+ }
+}
+
+ScopedWakeLockLinux::~ScopedWakeLockLinux() {
+ GetTaskRunner()->PostTask([] {
+ if (--reference_count_ == 0) {
+ ReleaseWakeLock();
+ }
+ });
+}
+
+// static
+void ScopedWakeLockLinux::AcquireWakeLock() {
+ OSP_UNIMPLEMENTED();
+}
+
+// static
+void ScopedWakeLockLinux::ReleaseWakeLock() {
+ OSP_UNIMPLEMENTED();
+}
+
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.h b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.h
new file mode 100644
index 00000000000..81117619ea4
--- /dev/null
+++ b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.h
@@ -0,0 +1,27 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef PLATFORM_IMPL_SCOPED_WAKE_LOCK_LINUX_H_
+#define PLATFORM_IMPL_SCOPED_WAKE_LOCK_LINUX_H_
+
+#include "platform/api/scoped_wake_lock.h"
+
+namespace openscreen {
+
+class ScopedWakeLockLinux : public ScopedWakeLock {
+ public:
+ ScopedWakeLockLinux();
+ ~ScopedWakeLockLinux() override;
+
+ private:
+ // TODO(jophba): implement linux wake lock.
+ static void AcquireWakeLock();
+ static void ReleaseWakeLock();
+
+ static int reference_count_;
+};
+
+} // namespace openscreen
+
+#endif // PLATFORM_IMPL_SCOPED_WAKE_LOCK_LINUX_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.cc b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.cc
index 70b9e2ddc8d..3cd7f11ae8c 100644
--- a/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.cc
@@ -11,7 +11,6 @@
#include "util/logging.h"
namespace openscreen {
-namespace platform {
ScopedWakeLockMac::LockState ScopedWakeLockMac::lock_state_{};
@@ -70,5 +69,4 @@ void ScopedWakeLockMac::ReleaseWakeLock() {
OSP_DCHECK_EQ(result, kIOReturnSuccess);
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.h b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.h
index 4e7d2031266..1d1a2f416d9 100644
--- a/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.h
+++ b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.h
@@ -10,7 +10,6 @@
#include "platform/api/scoped_wake_lock.h"
namespace openscreen {
-namespace platform {
class ScopedWakeLockMac : public ScopedWakeLock {
public:
@@ -29,7 +28,6 @@ class ScopedWakeLockMac : public ScopedWakeLock {
static LockState lock_state_;
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_SCOPED_WAKE_LOCK_MAC_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_address_posix.cc b/chromium/third_party/openscreen/src/platform/impl/socket_address_posix.cc
index d91975c6214..b45b4055873 100644
--- a/chromium/third_party/openscreen/src/platform/impl/socket_address_posix.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/socket_address_posix.cc
@@ -4,24 +4,20 @@
#include "platform/impl/socket_address_posix.h"
+#include <cstring>
#include <vector>
#include "util/logging.h"
namespace openscreen {
-namespace platform {
SocketAddressPosix::SocketAddressPosix(const struct sockaddr& address) {
if (address.sa_family == AF_INET) {
memcpy(&internal_address_, &address, sizeof(struct sockaddr_in));
- endpoint_.address = IPAddress(IPAddress::Version::kV4,
- reinterpret_cast<const uint8_t*>(
- &internal_address_.v4.sin_addr.s_addr));
- endpoint_.port = ntohs(internal_address_.v4.sin_port);
+ RecomputeEndpoint(IPAddress::Version::kV4);
} else if (address.sa_family == AF_INET6) {
memcpy(&internal_address_, &address, sizeof(struct sockaddr_in6));
- endpoint_.address = IPAddress(internal_address_.v6.sin6_addr.s6_addr);
- endpoint_.port = ntohs(internal_address_.v6.sin6_port);
+ RecomputeEndpoint(IPAddress::Version::kV6);
} else {
OSP_NOTREACHED() << "Unknown address type";
}
@@ -55,7 +51,7 @@ struct sockaddr* SocketAddressPosix::address() {
default:
OSP_NOTREACHED();
return nullptr;
- };
+ }
}
const struct sockaddr* SocketAddressPosix::address() const {
@@ -81,5 +77,25 @@ socklen_t SocketAddressPosix::size() const {
return 0;
}
}
-} // namespace platform
+
+void SocketAddressPosix::RecomputeEndpoint() {
+ RecomputeEndpoint(endpoint_.address.version());
+}
+
+void SocketAddressPosix::RecomputeEndpoint(IPAddress::Version version) {
+ switch (version) {
+ case IPAddress::Version::kV4:
+ endpoint_.address = IPAddress(IPAddress::Version::kV4,
+ reinterpret_cast<const uint8_t*>(
+ &internal_address_.v4.sin_addr.s_addr));
+ endpoint_.port = ntohs(internal_address_.v4.sin_port);
+ break;
+ case IPAddress::Version::kV6:
+ endpoint_.address = IPAddress(IPAddress::Version::kV6,
+ internal_address_.v6.sin6_addr.s6_addr);
+ endpoint_.port = ntohs(internal_address_.v6.sin6_port);
+ break;
+ }
+}
+
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_address_posix.h b/chromium/third_party/openscreen/src/platform/impl/socket_address_posix.h
index 17f0e82a891..07e49940d86 100644
--- a/chromium/third_party/openscreen/src/platform/impl/socket_address_posix.h
+++ b/chromium/third_party/openscreen/src/platform/impl/socket_address_posix.h
@@ -17,7 +17,6 @@
#include "platform/base/ip_address.h"
namespace openscreen {
-namespace platform {
class SocketAddressPosix {
public:
@@ -35,7 +34,13 @@ class SocketAddressPosix {
IPAddress::Version version() const { return endpoint_.address.version(); }
IPEndpoint endpoint() const { return endpoint_; }
+ // Recomputes |endpoint_| if |internal_address_| is written to directly, e.g.
+ // by a system call.
+ void RecomputeEndpoint();
+
private:
+ void RecomputeEndpoint(IPAddress::Version version);
+
// The way the sockaddr_* family works in POSIX is pretty unintuitive. The
// sockaddr_in and sockaddr_in6 structs can be reinterpreted as type
// sockaddr, however they don't have a common parent--the types are unrelated.
@@ -50,7 +55,6 @@ class SocketAddressPosix {
IPEndpoint endpoint_;
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_SOCKET_ADDRESS_POSIX_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_address_posix_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/socket_address_posix_unittest.cc
index 7d79c0c8f04..e8b81c0962a 100644
--- a/chromium/third_party/openscreen/src/platform/impl/socket_address_posix_unittest.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/socket_address_posix_unittest.cc
@@ -8,7 +8,6 @@
#include "gtest/gtest.h"
namespace openscreen {
-namespace platform {
TEST(SocketAddressPosixTest, IPv4SocketAddressConvertsSuccessfully) {
const SocketAddressPosix address(IPEndpoint{{1, 2, 3, 4}, 80});
@@ -26,8 +25,8 @@ TEST(SocketAddressPosixTest, IPv4SocketAddressConvertsSuccessfully) {
}
TEST(SocketAddressPosixTest, IPv6SocketAddressConvertsSuccessfully) {
- const SocketAddressPosix address(
- IPEndpoint{{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, 80});
+ const SocketAddressPosix address(IPEndpoint{
+ {0x0102, 0x0304, 0x0506, 0x0708, 0x090a, 0x0b0c, 0x0d0e, 0x0f10}, 80});
const sockaddr_in6* v6_address =
reinterpret_cast<const sockaddr_in6*>(address.address());
@@ -85,9 +84,8 @@ TEST(SocketAddressPosixTest, IPv6ConvertsSuccessfully) {
EXPECT_THAT(v6_address->sin6_addr.s6_addr,
testing::ElementsAreArray(kExpectedAddress));
IPEndpoint expected_address{
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, 80};
+ {0x0102, 0x0304, 0x0506, 0x0708, 0x090a, 0x0b0c, 0x0d0e, 0x0f10}, 80};
EXPECT_EQ(address_posix.endpoint(), expected_address);
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_handle.h b/chromium/third_party/openscreen/src/platform/impl/socket_handle.h
index 6a775f1a1bf..fca8192e941 100644
--- a/chromium/third_party/openscreen/src/platform/impl/socket_handle.h
+++ b/chromium/third_party/openscreen/src/platform/impl/socket_handle.h
@@ -8,7 +8,6 @@
#include <cstdlib>
namespace openscreen {
-namespace platform {
// A SocketHandle is the handle used to access a Socket by the underlying
// platform.
@@ -23,7 +22,6 @@ inline bool operator!=(const SocketHandle& lhs, const SocketHandle& rhs) {
return !(lhs == rhs);
}
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_SOCKET_HANDLE_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.cc b/chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.cc
index 7d99ca32ab3..dfc9c5bb34c 100644
--- a/chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.cc
@@ -8,7 +8,6 @@
#include <functional>
namespace openscreen {
-namespace platform {
SocketHandle::SocketHandle(int descriptor) : fd(descriptor) {}
@@ -20,5 +19,4 @@ size_t SocketHandleHash::operator()(const SocketHandle& handle) const {
return std::hash<int>()(handle.fd);
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.h b/chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.h
index 59ff670d913..13f977c33d8 100644
--- a/chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.h
+++ b/chromium/third_party/openscreen/src/platform/impl/socket_handle_posix.h
@@ -8,14 +8,12 @@
#include "platform/impl/socket_handle.h"
namespace openscreen {
-namespace platform {
struct SocketHandle {
explicit SocketHandle(int descriptor);
int fd;
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_SOCKET_HANDLE_POSIX_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.cc b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.cc
index e3731b34850..1560aa6c6dd 100644
--- a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.cc
@@ -11,7 +11,9 @@
#include "util/logging.h"
namespace openscreen {
-namespace platform {
+
+SocketHandleWaiter::SocketHandleWaiter(ClockNowFunctionPtr now_function)
+ : now_function_(now_function) {}
void SocketHandleWaiter::Subscribe(Subscriber* subscriber,
SocketHandleRef handle) {
@@ -51,6 +53,7 @@ void SocketHandleWaiter::OnHandleDeletion(Subscriber* subscriber,
if (!disable_locking_for_testing) {
handles_being_deleted_.push_back(handle);
+ OSP_DVLOG << "Starting to block for handle deletion";
// This code will allow us to block completion of the socket destructor
// (and subsequent invalidation of pointers to this socket) until we no
// longer are waiting on a SELECT(...) call to it, since we only signal
@@ -60,25 +63,36 @@ void SocketHandleWaiter::OnHandleDeletion(Subscriber* subscriber,
handles_being_deleted_.end(),
handle) == handles_being_deleted_.end();
});
+ OSP_DVLOG << "\tDone blocking for handle deletion!";
}
}
}
void SocketHandleWaiter::ProcessReadyHandles(
- const std::vector<SocketHandleRef>& handles) {
- std::lock_guard<std::mutex> lock(mutex_);
- for (const SocketHandleRef& handle : handles) {
- auto iterator = handle_mappings_.find(handle);
- if (iterator == handle_mappings_.end()) {
- // This is OK: SocketHandle was deleted in the meantime.
- continue;
+ const std::vector<HandleWithSubscriber>& handles,
+ Clock::duration timeout) {
+ Clock::time_point start_time = now_function_();
+ bool processed_one = false;
+ // TODO(btolsch): Track explicit or implicit time since last handled on each
+ // watched handle so we can sort by it here for better fairness.
+ for (const HandleWithSubscriber& handle : handles) {
+ Clock::time_point current_time = now_function_();
+ if (processed_one && (current_time - start_time) > timeout) {
+ return;
}
- iterator->second->ProcessReadyHandle(handle);
+ processed_one = true;
+ handle.subscriber->ProcessReadyHandle(handle.handle);
+
+ current_time = now_function_();
+ if ((current_time - start_time) > timeout) {
+ return;
+ }
}
}
Error SocketHandleWaiter::ProcessHandles(Clock::duration timeout) {
+ Clock::time_point start_time = now_function_();
std::vector<SocketHandleRef> handles;
{
std::lock_guard<std::mutex> lock(mutex_);
@@ -90,22 +104,37 @@ Error SocketHandleWaiter::ProcessHandles(Clock::duration timeout) {
}
}
+ Clock::time_point current_time = now_function_();
+ Clock::duration remaining_timeout = timeout - (current_time - start_time);
ErrorOr<std::vector<SocketHandleRef>> changed_handles =
- AwaitSocketsReadable(handles, timeout);
+ AwaitSocketsReadable(handles, remaining_timeout);
+ std::vector<HandleWithSubscriber> ready_handles;
{
std::lock_guard<std::mutex> lock(mutex_);
handles_being_deleted_.clear();
handle_deletion_block_.notify_all();
- }
+ if (changed_handles) {
+ auto& ch = changed_handles.value();
+ ready_handles.reserve(ch.size());
+ for (const auto& handle : ch) {
+ auto mapping_it = handle_mappings_.find(handle);
+ if (mapping_it != handle_mappings_.end()) {
+ ready_handles.push_back(
+ HandleWithSubscriber{handle, mapping_it->second});
+ }
+ }
+ }
- if (changed_handles.is_error()) {
- return changed_handles.error();
- }
+ if (changed_handles.is_error()) {
+ return changed_handles.error();
+ }
- ProcessReadyHandles(changed_handles.value());
+ current_time = now_function_();
+ remaining_timeout = timeout - (current_time - start_time);
+ ProcessReadyHandles(ready_handles, remaining_timeout);
+ }
return Error::None();
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.h b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.h
index 1b97abd4f7e..1ad81526058 100644
--- a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.h
+++ b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter.h
@@ -18,7 +18,6 @@
#include "platform/impl/socket_handle.h"
namespace openscreen {
-namespace platform {
// The class responsible for calling platform-level method to watch UDP sockets
// for available read data. Reading from these sockets is handled at a higher
@@ -36,7 +35,7 @@ class SocketHandleWaiter {
virtual void ProcessReadyHandle(SocketHandleRef handle) = 0;
};
- SocketHandleWaiter() = default;
+ explicit SocketHandleWaiter(ClockNowFunctionPtr now_function);
virtual ~SocketHandleWaiter() = default;
// Start notifying |subscriber| whenever |handle| has an event. May be called
@@ -73,8 +72,15 @@ class SocketHandleWaiter {
const Clock::duration& timeout) = 0;
private:
- // Call the subscriber associated with each changed handle.
- void ProcessReadyHandles(const std::vector<SocketHandleRef>& handles);
+ struct HandleWithSubscriber {
+ SocketHandleRef handle;
+ Subscriber* subscriber;
+ };
+
+ // Call the subscriber associated with each changed handle. Handles are only
+ // processed until |timeout| is exceeded. Must be called with |mutex_| held.
+ void ProcessReadyHandles(const std::vector<HandleWithSubscriber>& handles,
+ Clock::duration timeout);
// Guards against concurrent access to all other class data members.
std::mutex mutex_;
@@ -90,9 +96,10 @@ class SocketHandleWaiter {
// that is watching them.
std::unordered_map<SocketHandleRef, Subscriber*, SocketHandleHash>
handle_mappings_;
+
+ const ClockNowFunctionPtr now_function_;
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_SOCKET_HANDLE_WAITER_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.cc b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.cc
index 0cfca99bdf6..38762bf618d 100644
--- a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.cc
@@ -16,9 +16,10 @@
#include "util/logging.h"
namespace openscreen {
-namespace platform {
-SocketHandleWaiterPosix::SocketHandleWaiterPosix() = default;
+SocketHandleWaiterPosix::SocketHandleWaiterPosix(
+ ClockNowFunctionPtr now_function)
+ : SocketHandleWaiter(now_function) {}
SocketHandleWaiterPosix::~SocketHandleWaiterPosix() = default;
@@ -27,9 +28,14 @@ SocketHandleWaiterPosix::AwaitSocketsReadable(
const std::vector<SocketHandleRef>& socket_handles,
const Clock::duration& timeout) {
int max_fd = -1;
- FD_ZERO(&read_handles_);
+ fd_set read_handles;
+ fd_set write_handles;
+
+ FD_ZERO(&read_handles);
+ FD_ZERO(&write_handles);
for (const SocketHandle& handle : socket_handles) {
- FD_SET(handle.fd, &read_handles_);
+ FD_SET(handle.fd, &read_handles);
+ FD_SET(handle.fd, &write_handles);
max_fd = std::max(max_fd, handle.fd);
}
if (max_fd < 0) {
@@ -37,10 +43,13 @@ SocketHandleWaiterPosix::AwaitSocketsReadable(
}
struct timeval tv = ToTimeval(timeout);
- // This value is set to 'max_fd + 1' by convention. For more information, see:
+ // This value is set to 'max_fd + 1' by convention. Also, select() is
+ // level-triggered so incomplete reads/writes by the caller are fine and will
+ // be picked up again on the next select() call. For more information, see:
// http://man7.org/linux/man-pages/man2/select.2.html
int max_fd_to_watch = max_fd + 1;
- const int rv = select(max_fd_to_watch, &read_handles_, nullptr, nullptr, &tv);
+ const int rv =
+ select(max_fd_to_watch, &read_handles, &write_handles, nullptr, &tv);
if (rv == -1) {
// This is the case when an error condition is hit within the select(...)
// command.
@@ -52,7 +61,9 @@ SocketHandleWaiterPosix::AwaitSocketsReadable(
std::vector<SocketHandleRef> changed_handles;
for (const SocketHandleRef& handle : socket_handles) {
- if (FD_ISSET(handle.get().fd, &read_handles_)) {
+ if (FD_ISSET(handle.get().fd, &read_handles) ||
+ FD_ISSET(handle.get().fd, &write_handles)) {
+ // TODO(btolsch): Distinguish between read and write.
changed_handles.push_back(handle);
}
}
@@ -74,5 +85,4 @@ void SocketHandleWaiterPosix::RequestStopSoon() {
is_running_.store(false);
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.h b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.h
index 2a8b868300a..e2b78e5d8e0 100644
--- a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.h
+++ b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.h
@@ -13,13 +13,12 @@
#include "platform/impl/socket_handle_waiter.h"
namespace openscreen {
-namespace platform {
class SocketHandleWaiterPosix : public SocketHandleWaiter {
public:
using SocketHandleRef = SocketHandleWaiter::SocketHandleRef;
- SocketHandleWaiterPosix();
+ explicit SocketHandleWaiterPosix(ClockNowFunctionPtr now_function);
~SocketHandleWaiterPosix() override;
// Runs the Wait function in a loop until the below RequestStopSoon function
@@ -35,13 +34,10 @@ class SocketHandleWaiterPosix : public SocketHandleWaiter {
const Clock::duration& timeout) override;
private:
- fd_set read_handles_;
-
// Atomic so that we can perform atomic exchanges.
std::atomic_bool is_running_;
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_SOCKET_HANDLE_WAITER_POSIX_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix_unittest.cc
index fcc249e53ee..bcd647070bc 100644
--- a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix_unittest.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix_unittest.cc
@@ -14,9 +14,9 @@
#include "gtest/gtest.h"
#include "platform/impl/socket_handle_posix.h"
#include "platform/impl/timeval_posix.h"
+#include "platform/test/fake_clock.h"
namespace openscreen {
-namespace platform {
namespace {
using namespace ::testing;
@@ -32,10 +32,14 @@ class TestingSocketHandleWaiter : public SocketHandleWaiter {
public:
using SocketHandleRef = SocketHandleWaiter::SocketHandleRef;
+ TestingSocketHandleWaiter() : SocketHandleWaiter(&FakeClock::now) {}
+
MOCK_METHOD2(
AwaitSocketsReadable,
ErrorOr<std::vector<SocketHandleRef>>(const std::vector<SocketHandleRef>&,
const Clock::duration&));
+
+ FakeClock fake_clock{Clock::time_point{Clock::duration{1234567}}};
};
} // namespace
@@ -88,5 +92,4 @@ TEST(SocketHandleWaiterTest, WatchedSocketsReturnedToCorrectSubscribers) {
waiter.ProcessHandles(Clock::duration{0});
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/stream_socket.h b/chromium/third_party/openscreen/src/platform/impl/stream_socket.h
index 7278e05b8dc..b4536e274bd 100644
--- a/chromium/third_party/openscreen/src/platform/impl/stream_socket.h
+++ b/chromium/third_party/openscreen/src/platform/impl/stream_socket.h
@@ -17,7 +17,6 @@
#include "platform/impl/socket_handle.h"
namespace openscreen {
-namespace platform {
// StreamSocket is an incomplete abstraction of synchronous platform methods for
// creating, initializing, and closing stream sockets. Callers can use this
@@ -68,7 +67,6 @@ class StreamSocket {
virtual IPAddress::Version version() const = 0;
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_STREAM_SOCKET_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.cc b/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.cc
index b60e82ae0d9..c753344dd9b 100644
--- a/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.cc
@@ -7,12 +7,12 @@
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/ip.h>
+#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
namespace openscreen {
-namespace platform {
namespace {
constexpr int kDefaultMaxBacklogSize = 64;
@@ -38,10 +38,15 @@ StreamSocketPosix::StreamSocketPosix(const IPEndpoint& local_endpoint)
local_address_(local_endpoint) {}
StreamSocketPosix::StreamSocketPosix(SocketAddressPosix local_address,
+ IPEndpoint remote_address,
int file_descriptor)
: handle_(file_descriptor),
version_(local_address.version()),
- local_address_(local_address) {}
+ local_address_(local_address),
+ remote_address_(remote_address),
+ state_(SocketState::kConnected) {
+ EnsureInitialized();
+}
StreamSocketPosix::~StreamSocketPosix() {
if (state_ == SocketState::kConnected) {
@@ -74,11 +79,14 @@ ErrorOr<std::unique_ptr<StreamSocket>> StreamSocketPosix::Accept() {
const int new_file_descriptor =
accept(handle_.fd, new_remote_address.address(), &remote_address_size);
if (new_file_descriptor == kUnsetHandleFd) {
- return CloseOnError(Error::Code::kSocketAcceptFailure);
+ return CloseOnError(
+ Error(Error::Code::kSocketAcceptFailure, strerror(errno)));
}
+ new_remote_address.RecomputeEndpoint();
return ErrorOr<std::unique_ptr<StreamSocket>>(
- std::make_unique<StreamSocketPosix>(new_remote_address,
+ std::make_unique<StreamSocketPosix>(local_address_.value(),
+ new_remote_address.endpoint(),
new_file_descriptor));
}
@@ -97,7 +105,8 @@ Error StreamSocketPosix::Bind() {
if (bind(handle_.fd, local_address_.value().address(),
local_address_.value().size()) != 0) {
- return CloseOnError(Error::Code::kSocketBindFailure);
+ return CloseOnError(
+ Error(Error::Code::kSocketBindFailure, strerror(errno)));
}
is_bound_ = true;
@@ -134,8 +143,10 @@ Error StreamSocketPosix::Connect(const IPEndpoint& remote_endpoint) {
}
SocketAddressPosix address(remote_endpoint);
- if (connect(handle_.fd, address.address(), address.size()) != 0) {
- return CloseOnError(Error::Code::kSocketConnectFailure);
+ int ret = connect(handle_.fd, address.address(), address.size());
+ if (ret != 0 && errno != EINPROGRESS) {
+ return CloseOnError(
+ Error(Error::Code::kSocketConnectFailure, strerror(errno)));
}
if (!is_bound_) {
@@ -143,13 +154,14 @@ Error StreamSocketPosix::Connect(const IPEndpoint& remote_endpoint) {
return CloseOnError(Error::Code::kSocketInvalidState);
}
- struct sockaddr address;
+ struct sockaddr_in6 address;
socklen_t size = sizeof(address);
- if (getsockname(handle_.fd, &address, &size) != 0) {
+ if (getsockname(handle_.fd, reinterpret_cast<struct sockaddr*>(&address),
+ &size) != 0) {
return CloseOnError(Error::Code::kSocketConnectFailure);
}
- local_address_.emplace(address);
+ local_address_.emplace(reinterpret_cast<struct sockaddr&>(address));
is_bound_ = true;
}
@@ -168,7 +180,8 @@ Error StreamSocketPosix::Listen(int max_backlog_size) {
}
if (listen(handle_.fd, max_backlog_size) != 0) {
- return CloseOnError(Error::Code::kSocketListenFailure);
+ return CloseOnError(
+ Error(Error::Code::kSocketListenFailure, strerror(errno)));
}
return Error::None();
@@ -201,7 +214,7 @@ bool StreamSocketPosix::EnsureInitialized() {
return Initialize() == Error::None();
}
- return false;
+ return handle_.fd != kUnsetHandleFd && is_initialized_;
}
Error StreamSocketPosix::Initialize() {
@@ -209,40 +222,43 @@ Error StreamSocketPosix::Initialize() {
return Error::Code::kItemAlreadyExists;
}
- int domain;
- switch (version_) {
- case IPAddress::Version::kV4:
- domain = AF_INET;
- break;
- case IPAddress::Version::kV6:
- domain = AF_INET6;
- break;
- }
+ int fd = handle_.fd;
+ if (fd == kUnsetHandleFd) {
+ int domain;
+ switch (version_) {
+ case IPAddress::Version::kV4:
+ domain = AF_INET;
+ break;
+ case IPAddress::Version::kV6:
+ domain = AF_INET6;
+ break;
+ }
- const int file_descriptor = socket(domain, SOCK_STREAM, 0);
- if (file_descriptor == kUnsetHandleFd) {
- last_error_code_ = Error::Code::kSocketInvalidState;
- return Error::Code::kSocketInvalidState;
+ fd = socket(domain, SOCK_STREAM, 0);
+ if (fd == kUnsetHandleFd) {
+ last_error_code_ = Error::Code::kSocketInvalidState;
+ return Error::Code::kSocketInvalidState;
+ }
}
- const int current_flags = fcntl(file_descriptor, F_GETFL, 0);
- if (fcntl(file_descriptor, F_SETFL, current_flags | O_NONBLOCK) == -1) {
- close(file_descriptor);
+ const int current_flags = fcntl(fd, F_GETFL, 0);
+ if (fcntl(fd, F_SETFL, current_flags | O_NONBLOCK) == -1) {
+ close(fd);
last_error_code_ = Error::Code::kSocketInvalidState;
return Error::Code::kSocketInvalidState;
}
- handle_.fd = file_descriptor;
+ handle_.fd = fd;
is_initialized_ = true;
// last_error_code_ should still be Error::None().
return Error::None();
}
-Error StreamSocketPosix::CloseOnError(Error::Code error_code) {
- last_error_code_ = error_code;
+Error StreamSocketPosix::CloseOnError(Error error) {
+ last_error_code_ = error.code();
Close();
state_ = SocketState::kClosed;
- return error_code;
+ return error;
}
// If is_open is false, the socket has either not been initialized
@@ -251,5 +267,4 @@ Error StreamSocketPosix::ReportSocketClosedError() {
last_error_code_ = Error::Code::kSocketClosedFailure;
return Error::Code::kSocketClosedFailure;
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.h b/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.h
index a990ca2b2d6..93f81744329 100644
--- a/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.h
+++ b/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.h
@@ -15,16 +15,17 @@
#include "platform/impl/socket_address_posix.h"
#include "platform/impl/socket_handle_posix.h"
#include "platform/impl/stream_socket.h"
-#include "platform/impl/weak_ptr.h"
+#include "util/weak_ptr.h"
namespace openscreen {
-namespace platform {
class StreamSocketPosix : public StreamSocket {
public:
StreamSocketPosix(IPAddress::Version version);
StreamSocketPosix(const IPEndpoint& local_endpoint);
- StreamSocketPosix(SocketAddressPosix local_address, int file_descriptor);
+ StreamSocketPosix(SocketAddressPosix local_address,
+ IPEndpoint remote_address,
+ int file_descriptor);
// StreamSocketPosix is non-copyable, due to directly managing the file
// descriptor.
@@ -58,7 +59,7 @@ class StreamSocketPosix : public StreamSocket {
bool EnsureInitialized();
Error Initialize();
- Error CloseOnError(Error::Code error_code);
+ Error CloseOnError(Error error);
Error ReportSocketClosedError();
constexpr static int kUnsetHandleFd = -1;
@@ -81,7 +82,6 @@ class StreamSocketPosix : public StreamSocket {
WeakPtrFactory<StreamSocketPosix> weak_factory_{this};
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_STREAM_SOCKET_POSIX_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/task_runner.cc b/chromium/third_party/openscreen/src/platform/impl/task_runner.cc
index 8ff92b34b2e..ac4a49a86f1 100644
--- a/chromium/third_party/openscreen/src/platform/impl/task_runner.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/task_runner.cc
@@ -4,14 +4,33 @@
#include "platform/impl/task_runner.h"
+#include <csignal>
#include <thread>
#include "util/logging.h"
namespace openscreen {
-namespace platform {
-TaskRunnerImpl::TaskRunnerImpl(platform::ClockNowFunctionPtr now_function,
+namespace {
+
+// This is mutated by the signal handler installed by RunUntilSignaled(), and is
+// checked by RunUntilStopped().
+//
+// Per the C++14 spec, passing visible changes to memory between a signal
+// handler and a program thread must be done through a volatile variable.
+volatile enum {
+ kNotRunning,
+ kNotSignaled,
+ kSignaled
+} g_signal_state = kNotRunning;
+
+void OnReceivedSignal(int signal) {
+ g_signal_state = kSignaled;
+}
+
+} // namespace
+
+TaskRunnerImpl::TaskRunnerImpl(ClockNowFunctionPtr now_function,
TaskWaiter* event_waiter,
Clock::duration waiter_timeout)
: now_function_(now_function),
@@ -37,8 +56,12 @@ void TaskRunnerImpl::PostPackagedTask(Task task) {
void TaskRunnerImpl::PostPackagedTaskWithDelay(Task task,
Clock::duration delay) {
std::lock_guard<std::mutex> lock(task_mutex_);
- delayed_tasks_.emplace(
- std::make_pair(now_function_() + delay, std::move(task)));
+ if (delay <= Clock::duration::zero()) {
+ tasks_.emplace_back(std::move(task));
+ } else {
+ delayed_tasks_.emplace(
+ std::make_pair(now_function_() + delay, std::move(task)));
+ }
if (task_waiter_) {
task_waiter_->OnTaskPosted();
} else {
@@ -56,12 +79,16 @@ void TaskRunnerImpl::RunUntilStopped() {
is_running_ = true;
// Main loop: Run until the |is_running_| flag is set back to false by the
- // "quit task" posted by RequestStopSoon().
+ // "quit task" posted by RequestStopSoon(), or the process received a
+ // termination signal.
while (is_running_) {
ScheduleDelayedTasks();
if (GrabMoreRunnableTasks()) {
RunRunnableTasks();
}
+ if (g_signal_state == kSignaled) {
+ is_running_ = false;
+ }
}
// Flushing phase: Ensure all immediately-runnable tasks are run before
@@ -80,6 +107,20 @@ void TaskRunnerImpl::RunUntilStopped() {
task_runner_thread_id_ = std::thread::id();
}
+void TaskRunnerImpl::RunUntilSignaled() {
+ OSP_CHECK_EQ(g_signal_state, kNotRunning)
+ << __func__ << " may not be invoked concurrently.";
+ g_signal_state = kNotSignaled;
+ const auto old_sigint_handler = std::signal(SIGINT, &OnReceivedSignal);
+ const auto old_sigterm_handler = std::signal(SIGTERM, &OnReceivedSignal);
+
+ RunUntilStopped();
+
+ std::signal(SIGINT, old_sigint_handler);
+ std::signal(SIGTERM, old_sigterm_handler);
+ g_signal_state = kNotRunning;
+}
+
void TaskRunnerImpl::RequestStopSoon() {
PostTask([this]() { is_running_ = false; });
}
@@ -143,5 +184,4 @@ bool TaskRunnerImpl::GrabMoreRunnableTasks() {
return false;
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/task_runner.h b/chromium/third_party/openscreen/src/platform/impl/task_runner.h
index 35da181b9ae..629e23cb7b2 100644
--- a/chromium/third_party/openscreen/src/platform/impl/task_runner.h
+++ b/chromium/third_party/openscreen/src/platform/impl/task_runner.h
@@ -22,7 +22,6 @@
#include "util/trace_logging.h"
namespace openscreen {
-namespace platform {
class TaskRunnerImpl final : public TaskRunner {
public:
@@ -49,7 +48,7 @@ class TaskRunnerImpl final : public TaskRunner {
};
explicit TaskRunnerImpl(
- platform::ClockNowFunctionPtr now_function,
+ ClockNowFunctionPtr now_function,
TaskWaiter* event_waiter = nullptr,
Clock::duration waiter_timeout = std::chrono::milliseconds(100));
@@ -64,6 +63,11 @@ class TaskRunnerImpl final : public TaskRunner {
// called.
void RunUntilStopped();
+ // Blocks the current thread, executing tasks from the queue with the desired
+ // timing; and does not return until some time after the current process is
+ // signaled with SIGINT or SIGTERM, or after RequestStopSoon() is called.
+ void RunUntilSignaled();
+
// Thread-safe method for requesting the TaskRunner to stop running after all
// non-delayed tasks in the queue have run. This behavior allows final
// clean-up tasks to be executed before the TaskRunner stops.
@@ -83,7 +87,7 @@ class TaskRunnerImpl final : public TaskRunner {
// used. This simplifies switching between 'Task' and 'TaskWithMetadata'
// based on the compilation flag.
TaskWithMetadata(Task task)
- : task_(std::move(task)), trace_ids_(TRACE_HIERARCHY){};
+ : task_(std::move(task)), trace_ids_(TRACE_HIERARCHY) {}
void operator()() {
TRACE_SET_HIERARCHY(trace_ids_);
@@ -111,7 +115,7 @@ class TaskRunnerImpl final : public TaskRunner {
// transferred.
bool GrabMoreRunnableTasks();
- const platform::ClockNowFunctionPtr now_function_;
+ const ClockNowFunctionPtr now_function_;
// Flag that indicates whether the task runner loop should continue. This is
// only meant to be read/written on the thread executing RunUntilStopped().
@@ -141,7 +145,6 @@ class TaskRunnerImpl final : public TaskRunner {
OSP_DISALLOW_COPY_AND_ASSIGN(TaskRunnerImpl);
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_TASK_RUNNER_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/task_runner_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/task_runner_unittest.cc
index ac4331964fc..d7535cf34b2 100644
--- a/chromium/third_party/openscreen/src/platform/impl/task_runner_unittest.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/task_runner_unittest.cc
@@ -13,7 +13,6 @@
#include "platform/test/fake_clock.h"
namespace openscreen {
-namespace platform {
namespace {
using namespace ::testing;
@@ -31,7 +30,7 @@ void WaitUntilCondition(std::function<bool()> predicate) {
class FakeTaskWaiter final : public TaskRunnerImpl::TaskWaiter {
public:
- explicit FakeTaskWaiter(platform::ClockNowFunctionPtr now_function)
+ explicit FakeTaskWaiter(ClockNowFunctionPtr now_function)
: now_function_(now_function) {}
~FakeTaskWaiter() override = default;
@@ -60,7 +59,7 @@ class FakeTaskWaiter final : public TaskRunnerImpl::TaskWaiter {
}
private:
- const platform::ClockNowFunctionPtr now_function_;
+ const ClockNowFunctionPtr now_function_;
TaskRunnerImpl* task_runner_;
std::atomic<bool> has_event_{false};
std::atomic<bool> waiting_{false};
@@ -69,7 +68,7 @@ class FakeTaskWaiter final : public TaskRunnerImpl::TaskWaiter {
class TaskRunnerWithWaiterFactory {
public:
static std::unique_ptr<TaskRunnerImpl> Create(
- platform::ClockNowFunctionPtr now_function) {
+ ClockNowFunctionPtr now_function) {
fake_waiter = std::make_unique<FakeTaskWaiter>(now_function);
auto runner = std::make_unique<TaskRunnerImpl>(
now_function, fake_waiter.get(), std::chrono::hours(1));
@@ -86,7 +85,7 @@ std::unique_ptr<FakeTaskWaiter> TaskRunnerWithWaiterFactory::fake_waiter;
} // anonymous namespace
TEST(TaskRunnerImplTest, TaskRunnerExecutesTaskAndStops) {
- FakeClock fake_clock{platform::Clock::time_point(milliseconds(1337))};
+ FakeClock fake_clock{Clock::time_point(milliseconds(1337))};
TaskRunnerImpl runner(&fake_clock.now);
std::string ran_tasks = "";
@@ -98,7 +97,7 @@ TEST(TaskRunnerImplTest, TaskRunnerExecutesTaskAndStops) {
}
TEST(TaskRunnerImplTest, TaskRunnerRunsDelayedTasksInOrder) {
- FakeClock fake_clock{platform::Clock::time_point(milliseconds(1337))};
+ FakeClock fake_clock{Clock::time_point(milliseconds(1337))};
TaskRunnerImpl runner(&fake_clock.now);
std::thread t([&runner] { runner.RunUntilStopped(); });
@@ -126,7 +125,7 @@ TEST(TaskRunnerImplTest, TaskRunnerRunsDelayedTasksInOrder) {
}
TEST(TaskRunnerImplTest, SingleThreadedTaskRunnerRunsSequentially) {
- FakeClock fake_clock{platform::Clock::time_point(milliseconds(1337))};
+ FakeClock fake_clock{Clock::time_point(milliseconds(1337))};
TaskRunnerImpl runner(&fake_clock.now);
std::string ran_tasks;
@@ -149,7 +148,7 @@ TEST(TaskRunnerImplTest, SingleThreadedTaskRunnerRunsSequentially) {
}
TEST(TaskRunnerImplTest, RunsAllImmediateTasksBeforeStopping) {
- FakeClock fake_clock{platform::Clock::time_point(milliseconds(1337))};
+ FakeClock fake_clock{Clock::time_point(milliseconds(1337))};
TaskRunnerImpl runner(&fake_clock.now);
std::string result;
@@ -181,7 +180,7 @@ TEST(TaskRunnerImplTest, RunsAllImmediateTasksBeforeStopping) {
}
TEST(TaskRunnerImplTest, TaskRunnerIsStableWithLotsOfTasks) {
- FakeClock fake_clock{platform::Clock::time_point(milliseconds(1337))};
+ FakeClock fake_clock{Clock::time_point(milliseconds(1337))};
TaskRunnerImpl runner(&fake_clock.now);
const int kNumberOfTasks = 500;
@@ -200,7 +199,7 @@ TEST(TaskRunnerImplTest, TaskRunnerIsStableWithLotsOfTasks) {
}
TEST(TaskRunnerImplTest, TaskRunnerDelayedTasksDontBlockImmediateTasks) {
- TaskRunnerImpl runner(platform::Clock::now);
+ TaskRunnerImpl runner(Clock::now);
std::string ran_tasks;
const auto task = [&ran_tasks] { ran_tasks += "1"; };
@@ -282,5 +281,4 @@ class RepeatedClass {
std::atomic<int> execution_count{0};
};
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.cc b/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.cc
index 573369a68c6..79597e1d655 100644
--- a/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.cc
@@ -9,7 +9,6 @@
#include "util/logging.h"
namespace openscreen {
-namespace platform {
bool TextTraceLoggingPlatform::IsTraceLoggingEnabled(
TraceCategory::Value category) {
@@ -19,12 +18,10 @@ bool TextTraceLoggingPlatform::IsTraceLoggingEnabled(
}
TextTraceLoggingPlatform::TextTraceLoggingPlatform() {
- OSP_DCHECK(!GetTracingDestination());
StartTracing(this);
}
TextTraceLoggingPlatform::~TextTraceLoggingPlatform() {
- OSP_DCHECK_EQ(GetTracingDestination(), this);
StopTracing();
}
@@ -69,5 +66,4 @@ void TextTraceLoggingPlatform::LogAsyncEnd(const uint32_t line,
<< timestamp << ") " << error;
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.h b/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.h
index 3d2a93c6d4f..c9155b6b580 100644
--- a/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.h
+++ b/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.h
@@ -8,7 +8,6 @@
#include "platform/api/trace_logging_platform.h"
namespace openscreen {
-namespace platform {
class TextTraceLoggingPlatform : public TraceLoggingPlatform {
public:
@@ -38,7 +37,6 @@ class TextTraceLoggingPlatform : public TraceLoggingPlatform {
Error::Code error) override;
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_TEXT_TRACE_LOGGING_PLATFORM_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/time.cc b/chromium/third_party/openscreen/src/platform/impl/time.cc
index 5604b84a2f1..6ad93145ec8 100644
--- a/chromium/third_party/openscreen/src/platform/impl/time.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/time.cc
@@ -17,7 +17,6 @@ using std::chrono::steady_clock;
using std::chrono::system_clock;
namespace openscreen {
-namespace platform {
Clock::time_point Clock::now() noexcept {
constexpr bool can_use_steady_clock =
@@ -66,5 +65,4 @@ std::chrono::seconds GetWallTimeSinceUnixEpoch() noexcept {
return std::chrono::seconds(since_epoch);
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/time_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/time_unittest.cc
index d784dc591ac..321bb7b32b5 100644
--- a/chromium/third_party/openscreen/src/platform/impl/time_unittest.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/time_unittest.cc
@@ -11,7 +11,6 @@
using std::chrono::seconds;
namespace openscreen {
-namespace platform {
#if __cplusplus < 202000L
// Before C++20, the standard does not guarantee that time_t is the number of
@@ -51,5 +50,4 @@ TEST(TimeTest, TimeTMeetsTheCpp20Standard) {
}
#endif
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/timeval_posix.cc b/chromium/third_party/openscreen/src/platform/impl/timeval_posix.cc
index dd16532421a..28c25ffe288 100644
--- a/chromium/third_party/openscreen/src/platform/impl/timeval_posix.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/timeval_posix.cc
@@ -7,7 +7,6 @@
#include <chrono>
namespace openscreen {
-namespace platform {
struct timeval ToTimeval(const Clock::duration& timeout) {
struct timeval tv;
@@ -21,5 +20,4 @@ struct timeval ToTimeval(const Clock::duration& timeout) {
return tv;
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/timeval_posix.h b/chromium/third_party/openscreen/src/platform/impl/timeval_posix.h
index 3679682de07..8f42129ab26 100644
--- a/chromium/third_party/openscreen/src/platform/impl/timeval_posix.h
+++ b/chromium/third_party/openscreen/src/platform/impl/timeval_posix.h
@@ -10,11 +10,9 @@
#include "platform/api/time.h"
namespace openscreen {
-namespace platform {
struct timeval ToTimeval(const Clock::duration& timeout);
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_TIMEVAL_POSIX_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/timeval_posix_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/timeval_posix_unittest.cc
index 3b679366c96..db82d7ad9c1 100644
--- a/chromium/third_party/openscreen/src/platform/impl/timeval_posix_unittest.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/timeval_posix_unittest.cc
@@ -7,7 +7,6 @@
#include "gtest/gtest.h"
namespace openscreen {
-namespace platform {
TEST(TimevalPosixTest, ToTimeval) {
auto timespan = Clock::duration::zero();
@@ -36,5 +35,4 @@ TEST(TimevalPosixTest, ToTimeval) {
EXPECT_EQ(timeval.tv_usec, 10);
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.cc b/chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.cc
index 8b270dabffa..c9749f728a2 100644
--- a/chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.cc
@@ -29,13 +29,13 @@
#include "util/trace_logging.h"
namespace openscreen {
-namespace platform {
namespace {
ErrorOr<std::vector<uint8_t>> GetDEREncodedPeerCertificate(const SSL& ssl) {
X509* const peer_cert = SSL_get_peer_certificate(&ssl);
- ErrorOr<std::vector<uint8_t>> der_peer_cert = ExportCertificate(*peer_cert);
+ ErrorOr<std::vector<uint8_t>> der_peer_cert =
+ ExportX509CertificateToDer(*peer_cert);
X509_free(peer_cert);
return der_peer_cert;
}
@@ -60,16 +60,19 @@ TlsConnectionFactoryPosix::TlsConnectionFactoryPosix(
OSP_DCHECK(task_runner_);
}
-TlsConnectionFactoryPosix::~TlsConnectionFactoryPosix() = default;
+TlsConnectionFactoryPosix::~TlsConnectionFactoryPosix() {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+}
// TODO(rwkeane): Add support for resuming sessions.
// TODO(rwkeane): Integrate with Auth.
void TlsConnectionFactoryPosix::Connect(const IPEndpoint& remote_address,
const TlsConnectOptions& options) {
- TRACE_SCOPED(TraceCategory::SSL, "TlsConnectionFactoryPosix::Connect");
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+ TRACE_SCOPED(TraceCategory::kSsl, "TlsConnectionFactoryPosix::Connect");
IPAddress::Version version = remote_address.address.version();
std::unique_ptr<TlsConnectionPosix> connection(
- new TlsConnectionPosix(version, task_runner_, platform_client_));
+ new TlsConnectionPosix(version, task_runner_));
Error connect_error = connection->socket_->Connect(remote_address);
if (!connect_error.ok()) {
TRACE_SET_RESULT(connect_error);
@@ -89,39 +92,23 @@ void TlsConnectionFactoryPosix::Connect(const IPEndpoint& remote_address,
SSL_set_verify(connection->ssl_.get(), SSL_VERIFY_PEER, nullptr);
}
- const int connection_status = SSL_connect(connection->ssl_.get());
- if (connection_status != 1) {
- DispatchConnectionFailed(connection->GetRemoteEndpoint());
- TRACE_SET_RESULT(GetSSLError(connection->ssl_.get(), connection_status));
- return;
- }
-
- ErrorOr<std::vector<uint8_t>> der_peer_cert =
- GetDEREncodedPeerCertificate(*connection->ssl_);
- if (!der_peer_cert) {
- DispatchConnectionFailed(connection->GetRemoteEndpoint());
- TRACE_SET_RESULT(der_peer_cert.error());
- return;
- }
-
- task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
- der = std::move(der_peer_cert.value()),
- moved_connection = std::move(connection)]() mutable {
- if (auto* self = weak_this.get()) {
- self->client_->OnConnected(self, std::move(der),
- std::move(moved_connection));
- }
- });
+ Connect(std::move(connection));
}
void TlsConnectionFactoryPosix::SetListenCredentials(
const TlsCredentials& credentials) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
EnsureInitialized();
ErrorOr<bssl::UniquePtr<X509>> cert = ImportCertificate(
credentials.der_x509_cert.data(), credentials.der_x509_cert.size());
- if (!cert ||
- SSL_CTX_use_certificate(ssl_context_.get(), cert.value().get()) != 1) {
+ ErrorOr<bssl::UniquePtr<EVP_PKEY>> pkey =
+ ImportRSAPrivateKey(credentials.der_rsa_private_key.data(),
+ credentials.der_rsa_private_key.size());
+
+ if (!cert || !pkey ||
+ SSL_CTX_use_certificate(ssl_context_.get(), cert.value().get()) != 1 ||
+ SSL_CTX_use_PrivateKey(ssl_context_.get(), pkey.value().get()) != 1) {
DispatchError(Error::Code::kSocketListenFailure);
TRACE_SET_RESULT(Error::Code::kSocketListenFailure);
return;
@@ -132,15 +119,23 @@ void TlsConnectionFactoryPosix::SetListenCredentials(
void TlsConnectionFactoryPosix::Listen(const IPEndpoint& local_address,
const TlsListenOptions& options) {
+ OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
// Credentials must be set before Listen() is called.
OSP_DCHECK(listen_credentials_set_);
auto socket = std::make_unique<StreamSocketPosix>(local_address);
+ socket->Bind();
socket->Listen(options.backlog_size);
+ if (socket->state() == SocketState::kClosed) {
+ DispatchError(Error::Code::kSocketListenFailure);
+ TRACE_SET_RESULT(Error::Code::kSocketListenFailure);
+ return;
+ }
+ OSP_DCHECK(socket->state() == SocketState::kNotConnected);
OSP_DCHECK(platform_client_);
if (platform_client_) {
- platform_client_->tls_data_router()->RegisterSocketObserver(
+ platform_client_->tls_data_router()->RegisterAcceptObserver(
std::move(socket), this);
}
}
@@ -175,38 +170,16 @@ void TlsConnectionFactoryPosix::OnSocketAccepted(
std::unique_ptr<StreamSocket> socket) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
- TRACE_SCOPED(TraceCategory::SSL,
+ TRACE_SCOPED(TraceCategory::kSsl,
"TlsConnectionFactoryPosix::OnSocketAccepted");
- std::unique_ptr<TlsConnectionPosix> connection(new TlsConnectionPosix(
- std::move(socket), task_runner_, platform_client_));
+ std::unique_ptr<TlsConnectionPosix> connection(
+ new TlsConnectionPosix(std::move(socket), task_runner_));
if (!ConfigureSsl(connection.get())) {
return;
}
- const int connection_status = SSL_accept(connection->ssl_.get());
- if (connection_status != 1) {
- DispatchConnectionFailed(connection->GetRemoteEndpoint());
- TRACE_SET_RESULT(GetSSLError(connection->ssl_.get(), connection_status));
- return;
- }
-
- ErrorOr<std::vector<uint8_t>> der_peer_cert =
- GetDEREncodedPeerCertificate(*connection->ssl_);
- if (!der_peer_cert) {
- DispatchConnectionFailed(connection->GetRemoteEndpoint());
- TRACE_SET_RESULT(der_peer_cert.error());
- return;
- }
-
- task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
- der = std::move(der_peer_cert.value()),
- moved_connection = std::move(connection)]() mutable {
- if (auto* self = weak_this.get()) {
- self->client_->OnAccepted(self, std::move(der),
- std::move(moved_connection));
- }
- });
+ Accept(std::move(connection));
}
bool TlsConnectionFactoryPosix::ConfigureSsl(TlsConnectionPosix* connection) {
@@ -258,6 +231,86 @@ void TlsConnectionFactoryPosix::Initialize() {
ssl_context_.reset(context);
}
+void TlsConnectionFactoryPosix::Connect(
+ std::unique_ptr<TlsConnectionPosix> connection) {
+ OSP_DCHECK(connection->socket_->state() == SocketState::kConnected);
+ const int connection_status = SSL_connect(connection->ssl_.get());
+ if (connection_status != 1) {
+ Error error = GetSSLError(connection->ssl_.get(), connection_status);
+ if (error.code() == Error::Code::kAgain) {
+ task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
+ conn = std::move(connection)]() mutable {
+ if (auto* self = weak_this.get()) {
+ self->Connect(std::move(conn));
+ }
+ });
+ return;
+ } else {
+ OSP_DVLOG << "SSL_connect failed with error: " << error;
+ DispatchConnectionFailed(connection->GetRemoteEndpoint());
+ TRACE_SET_RESULT(error);
+ return;
+ }
+ }
+
+ ErrorOr<std::vector<uint8_t>> der_peer_cert =
+ GetDEREncodedPeerCertificate(*connection->ssl_);
+ if (!der_peer_cert) {
+ DispatchConnectionFailed(connection->GetRemoteEndpoint());
+ TRACE_SET_RESULT(der_peer_cert.error());
+ return;
+ }
+
+ connection->RegisterConnectionWithDataRouter(platform_client_);
+ task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
+ der = std::move(der_peer_cert.value()),
+ moved_connection = std::move(connection)]() mutable {
+ if (auto* self = weak_this.get()) {
+ self->client_->OnConnected(self, std::move(der),
+ std::move(moved_connection));
+ }
+ });
+}
+
+void TlsConnectionFactoryPosix::Accept(
+ std::unique_ptr<TlsConnectionPosix> connection) {
+ OSP_DCHECK(connection->socket_->state() == SocketState::kConnected);
+ const int connection_status = SSL_accept(connection->ssl_.get());
+ if (connection_status != 1) {
+ Error error = GetSSLError(connection->ssl_.get(), connection_status);
+ if (error.code() == Error::Code::kAgain) {
+ task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
+ conn = std::move(connection)]() mutable {
+ if (auto* self = weak_this.get()) {
+ self->Accept(std::move(conn));
+ }
+ });
+ return;
+ } else {
+ OSP_DVLOG << "SSL_accept failed with error: " << error;
+ DispatchConnectionFailed(connection->GetRemoteEndpoint());
+ TRACE_SET_RESULT(error);
+ return;
+ }
+ }
+
+ ErrorOr<std::vector<uint8_t>> der_peer_cert =
+ GetDEREncodedPeerCertificate(*connection->ssl_);
+ std::vector<uint8_t> der;
+ if (der_peer_cert) {
+ der = std::move(der_peer_cert.value());
+ }
+ connection->RegisterConnectionWithDataRouter(platform_client_);
+ task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
+ der = std::move(der),
+ moved_connection = std::move(connection)]() mutable {
+ if (auto* self = weak_this.get()) {
+ self->client_->OnAccepted(self, std::move(der),
+ std::move(moved_connection));
+ }
+ });
+}
+
void TlsConnectionFactoryPosix::DispatchConnectionFailed(
const IPEndpoint& remote_endpoint) {
task_runner_->PostTask(
@@ -277,5 +330,4 @@ void TlsConnectionFactoryPosix::DispatchError(Error error) {
});
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.h b/chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.h
index 46ba6e25a58..13ac6b6ddab 100644
--- a/chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.h
+++ b/chromium/third_party/openscreen/src/platform/impl/tls_connection_factory_posix.h
@@ -14,10 +14,9 @@
#include "platform/base/error.h"
#include "platform/impl/platform_client_posix.h"
#include "platform/impl/tls_data_router_posix.h"
-#include "platform/impl/weak_ptr.h"
+#include "util/weak_ptr.h"
namespace openscreen {
-namespace platform {
class StreamSocket;
@@ -62,6 +61,11 @@ class TlsConnectionFactoryPosix : public TlsConnectionFactory,
// factory.
void Initialize();
+ // Handles their respective SSL handshake calls. These will continue to be
+ // scheduled on |task_runner_| until the handshake completes.
+ void Connect(std::unique_ptr<TlsConnectionPosix> connection);
+ void Accept(std::unique_ptr<TlsConnectionPosix> connection);
+
// Called on any thread, to post a task to notify the Client that a connection
// failure or other error has occurred.
void DispatchConnectionFailed(const IPEndpoint& remote_endpoint);
@@ -86,7 +90,6 @@ class TlsConnectionFactoryPosix : public TlsConnectionFactory,
OSP_DISALLOW_COPY_AND_ASSIGN(TlsConnectionFactoryPosix);
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_TLS_CONNECTION_FACTORY_POSIX_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.cc b/chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.cc
index 3fa39037cd2..424f2ab49cc 100644
--- a/chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.cc
@@ -27,46 +27,26 @@
#include "util/logging.h"
namespace openscreen {
-namespace platform {
// TODO(jophba, rwkeane): implement write blocking/unblocking
TlsConnectionPosix::TlsConnectionPosix(IPEndpoint local_address,
- TaskRunner* task_runner,
- PlatformClientPosix* platform_client)
+ TaskRunner* task_runner)
: task_runner_(task_runner),
- platform_client_(platform_client),
- socket_(std::make_unique<StreamSocketPosix>(local_address)),
- buffer_(this) {
+ socket_(std::make_unique<StreamSocketPosix>(local_address)) {
OSP_DCHECK(task_runner_);
- if (platform_client_) {
- platform_client_->tls_data_router()->RegisterConnection(this);
- }
}
TlsConnectionPosix::TlsConnectionPosix(IPAddress::Version version,
- TaskRunner* task_runner,
- PlatformClientPosix* platform_client)
+ TaskRunner* task_runner)
: task_runner_(task_runner),
- platform_client_(platform_client),
- socket_(std::make_unique<StreamSocketPosix>(version)),
- buffer_(this) {
+ socket_(std::make_unique<StreamSocketPosix>(version)) {
OSP_DCHECK(task_runner_);
- if (platform_client_) {
- platform_client_->tls_data_router()->RegisterConnection(this);
- }
}
TlsConnectionPosix::TlsConnectionPosix(std::unique_ptr<StreamSocket> socket,
- TaskRunner* task_runner,
- PlatformClientPosix* platform_client)
- : task_runner_(task_runner),
- platform_client_(platform_client),
- socket_(std::move(socket)),
- buffer_(this) {
+ TaskRunner* task_runner)
+ : task_runner_(task_runner), socket_(std::move(socket)) {
OSP_DCHECK(task_runner_);
- if (platform_client_) {
- platform_client_->tls_data_router()->RegisterConnection(this);
- }
}
TlsConnectionPosix::~TlsConnectionPosix() {
@@ -76,47 +56,43 @@ TlsConnectionPosix::~TlsConnectionPosix() {
}
void TlsConnectionPosix::TryReceiveMessage() {
- const int bytes_available = SSL_pending(ssl_.get());
- if (bytes_available > 0) {
- // NOTE: the pending size of the data block available is not a guarantee
- // that it will receive only bytes_available or even
- // any data, since not all pending bytes are application data.
- std::vector<uint8_t> block(bytes_available);
-
- const int bytes_read = SSL_read(ssl_.get(), block.data(), bytes_available);
-
- // Read operator was not successful, either due to a closed connection,
- // an error occurred, or we have to take an action.
- if (bytes_read <= 0) {
- const Error error = GetSSLError(ssl_.get(), bytes_read);
- if (!error.ok() && (error != Error::Code::kAgain)) {
- DispatchError(error);
- }
- return;
+ OSP_DCHECK(ssl_);
+ constexpr int kMaxApplicationDataBytes = 4096;
+ std::vector<uint8_t> block(kMaxApplicationDataBytes);
+ const int bytes_read =
+ SSL_read(ssl_.get(), block.data(), kMaxApplicationDataBytes);
+
+ // Read operator was not successful, either due to a closed connection,
+ // no application data available, an error occurred, or we have to take an
+ // action.
+ if (bytes_read <= 0) {
+ const Error error = GetSSLError(ssl_.get(), bytes_read);
+ if (!error.ok() && (error != Error::Code::kAgain)) {
+ DispatchError(error);
}
+ return;
+ }
- block.resize(bytes_read);
+ block.resize(bytes_read);
- task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
- moved_block = std::move(block)]() mutable {
- if (auto* self = weak_this.get()) {
- if (auto* client = self->client_) {
- client->OnRead(self, std::move(moved_block));
- }
+ task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
+ moved_block = std::move(block)]() mutable {
+ if (auto* self = weak_this.get()) {
+ if (auto* client = self->client_) {
+ client->OnRead(self, std::move(moved_block));
}
- });
- }
+ }
+ });
}
void TlsConnectionPosix::SetClient(Client* client) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
client_ = client;
- notified_client_buffer_is_blocked_ = false;
}
-void TlsConnectionPosix::Write(const void* data, size_t len) {
+bool TlsConnectionPosix::Send(const void* data, size_t len) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
- buffer_.Write(data, len);
+ return buffer_.Push(data, len);
}
IPEndpoint TlsConnectionPosix::GetLocalEndpoint() const {
@@ -135,29 +111,11 @@ IPEndpoint TlsConnectionPosix::GetRemoteEndpoint() const {
return endpoint.value();
}
-void TlsConnectionPosix::NotifyWriteBufferFill(double fraction) {
- // WARNING: This method is called on multiple threads.
- //
- // The following is very subtle/complex behavior: Only "writes" can increase
- // the buffer fill, so we expect transitions into the "blocked" state to occur
- // on the |task_runner_| thread, and |client_| will be notified
- // *synchronously* when that happens. Likewise, only "reads" can cause
- // transitions to the "unblocked" state; but these will not occur on the
- // |task_runner_| thread. Thus, when unblocking, the |client_| will be
- // notified *asynchronously*; but, that should be acceptable because it's only
- // a race towards a buffer overrun that is of concern.
- //
- // TODO(rwkeane): Have Write() return a bool, and then none of this is needed.
- constexpr double kBlockBufferPercentage = 0.5;
- if (fraction > kBlockBufferPercentage &&
- !notified_client_buffer_is_blocked_) {
- NotifyClientOfWriteBlockStatusSequentially(true);
- } else if (fraction < kBlockBufferPercentage &&
- notified_client_buffer_is_blocked_) {
- NotifyClientOfWriteBlockStatusSequentially(false);
- } else if (fraction >= 0.99) {
- DispatchError(Error::Code::kInsufficientBuffer);
- }
+void TlsConnectionPosix::RegisterConnectionWithDataRouter(
+ PlatformClientPosix* platform_client) {
+ OSP_DCHECK(!platform_client_);
+ platform_client_ = platform_client;
+ platform_client_->tls_data_router()->RegisterConnection(this);
}
void TlsConnectionPosix::SendAvailableBytes() {
@@ -189,38 +147,4 @@ void TlsConnectionPosix::DispatchError(Error error) {
});
}
-void TlsConnectionPosix::NotifyClientOfWriteBlockStatusSequentially(
- bool is_blocked) {
- if (!task_runner_->IsRunningOnTaskRunner()) {
- task_runner_->PostTask(
- [weak_this = weak_factory_.GetWeakPtr(), is_blocked = is_blocked] {
- if (auto* self = weak_this.get()) {
- OSP_DCHECK(self->task_runner_->IsRunningOnTaskRunner());
- self->NotifyClientOfWriteBlockStatusSequentially(is_blocked);
- }
- });
- return;
- }
-
- OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
-
- if (!client_) {
- return;
- }
-
- // Check again, now that the block/unblock state change is happening
- // in-sequence (it originated from parallel executions).
- if (notified_client_buffer_is_blocked_ == is_blocked) {
- return;
- }
-
- notified_client_buffer_is_blocked_ = is_blocked;
- if (is_blocked) {
- client_->OnWriteBlocked(this);
- } else {
- client_->OnWriteUnblocked(this);
- }
-}
-
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.h b/chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.h
index c0089844163..c78bf5f14d3 100644
--- a/chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.h
+++ b/chromium/third_party/openscreen/src/platform/impl/tls_connection_posix.h
@@ -7,23 +7,20 @@
#include <openssl/ssl.h>
-#include <atomic>
#include <memory>
#include "platform/api/tls_connection.h"
#include "platform/impl/platform_client_posix.h"
#include "platform/impl/stream_socket_posix.h"
#include "platform/impl/tls_write_buffer.h"
-#include "platform/impl/weak_ptr.h"
+#include "util/weak_ptr.h"
namespace openscreen {
-namespace platform {
class TaskRunner;
class TlsConnectionFactoryPosix;
-class TlsConnectionPosix : public TlsConnection,
- public TlsWriteBuffer::Observer {
+class TlsConnectionPosix : public TlsConnection {
public:
~TlsConnectionPosix() override;
@@ -36,49 +33,37 @@ class TlsConnectionPosix : public TlsConnection,
// TlsConnection overrides.
void SetClient(Client* client) override;
- void Write(const void* data, size_t len) override;
+ bool Send(const void* data, size_t len) override;
IPEndpoint GetLocalEndpoint() const override;
IPEndpoint GetRemoteEndpoint() const override;
- // TlsWriteBuffer::Observer overrides.
- void NotifyWriteBufferFill(double fraction) override;
+ // Registers |this| with the platform TlsDataRouterPosix. This is called
+ // automatically by TlsConnectionFactoryPosix after the handshake completes.
+ void RegisterConnectionWithDataRouter(PlatformClientPosix* platform_client);
+
+ const SocketHandle& socket_handle() const { return socket_->socket_handle(); }
protected:
friend class TlsConnectionFactoryPosix;
- TlsConnectionPosix(IPEndpoint local_address,
- TaskRunner* task_runner,
- PlatformClientPosix* platform_client =
- PlatformClientPosix::GetInstance());
- TlsConnectionPosix(IPAddress::Version version,
- TaskRunner* task_runner,
- PlatformClientPosix* platform_client =
- PlatformClientPosix::GetInstance());
+ TlsConnectionPosix(IPEndpoint local_address, TaskRunner* task_runner);
+ TlsConnectionPosix(IPAddress::Version version, TaskRunner* task_runner);
TlsConnectionPosix(std::unique_ptr<StreamSocket> socket,
- TaskRunner* task_runner,
- PlatformClientPosix* platform_client =
- PlatformClientPosix::GetInstance());
+ TaskRunner* task_runner);
private:
// Called on any thread, to post a task to notify the Client that an |error|
// has occurred.
void DispatchError(Error error);
- // Helper to call OnWriteBlocked() or OnWriteUnblocked(). If this is not
- // called within a task run by |task_runner_|, it trampolines by posting a
- // task to call itself back via |task_runner_|. See comments in implementation
- // of NotifyWriteBufferFill() for further details.
- void NotifyClientOfWriteBlockStatusSequentially(bool is_blocked);
-
TaskRunner* const task_runner_;
- PlatformClientPosix* const platform_client_;
+ PlatformClientPosix* platform_client_ = nullptr;
Client* client_ = nullptr;
std::unique_ptr<StreamSocket> socket_;
bssl::UniquePtr<SSL> ssl_;
- std::atomic_bool notified_client_buffer_is_blocked_{false};
TlsWriteBuffer buffer_;
WeakPtrFactory<TlsConnectionPosix> weak_factory_{this};
@@ -86,7 +71,6 @@ class TlsConnectionPosix : public TlsConnection,
OSP_DISALLOW_COPY_AND_ASSIGN(TlsConnectionPosix);
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_TLS_CONNECTION_POSIX_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.cc b/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.cc
index 642d4a93ee0..288f1998aec 100644
--- a/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.cc
@@ -9,7 +9,6 @@
#include "util/logging.h"
namespace openscreen {
-namespace platform {
TlsDataRouterPosix::TlsDataRouterPosix(
SocketHandleWaiter* waiter,
@@ -21,30 +20,46 @@ TlsDataRouterPosix::~TlsDataRouterPosix() {
}
void TlsDataRouterPosix::RegisterConnection(TlsConnectionPosix* connection) {
- // TODO(jophba, rwkeane): implement this method.
- OSP_UNIMPLEMENTED();
+ {
+ std::lock_guard<std::mutex> lock(connections_mutex_);
+ OSP_DCHECK(std::find(connections_.begin(), connections_.end(),
+ connection) == connections_.end());
+ connections_.push_back(connection);
+ }
+
+ waiter_->Subscribe(this, connection->socket_handle());
}
void TlsDataRouterPosix::DeregisterConnection(TlsConnectionPosix* connection) {
- // TODO(jophba, rwkeane): implement this method.
- OSP_UNIMPLEMENTED();
+ {
+ std::lock_guard<std::mutex> lock(connections_mutex_);
+ auto it = std::remove_if(
+ connections_.begin(), connections_.end(),
+ [connection](TlsConnectionPosix* conn) { return conn == connection; });
+ if (it == connections_.end()) {
+ return;
+ }
+ connections_.erase(it, connections_.end());
+ }
+
+ waiter_->OnHandleDeletion(this, connection->socket_handle());
}
-void TlsDataRouterPosix::RegisterSocketObserver(
+void TlsDataRouterPosix::RegisterAcceptObserver(
std::unique_ptr<StreamSocketPosix> socket,
SocketObserver* observer) {
OSP_DCHECK(observer);
StreamSocketPosix* socket_ptr = socket.get();
{
- std::unique_lock<std::mutex> lock(socket_mutex_);
- watched_stream_sockets_.push_back(std::move(socket));
- socket_mappings_[socket_ptr] = observer;
+ std::unique_lock<std::mutex> lock(accept_socket_mutex_);
+ accept_stream_sockets_.push_back(std::move(socket));
+ accept_socket_mappings_[socket_ptr] = observer;
}
waiter_->Subscribe(this, socket_ptr->socket_handle());
}
-void TlsDataRouterPosix::DeregisterSocketObserver(StreamSocketPosix* socket) {
+void TlsDataRouterPosix::DeregisterAcceptObserver(StreamSocketPosix* socket) {
OnSocketDestroyed(socket, false);
}
@@ -56,8 +71,8 @@ void TlsDataRouterPosix::OnConnectionDestroyed(TlsConnectionPosix* connection) {
void TlsDataRouterPosix::OnSocketDestroyed(StreamSocketPosix* socket,
bool skip_locking_for_testing) {
{
- std::unique_lock<std::mutex> lock(socket_mutex_);
- if (!socket_mappings_.erase(socket)) {
+ std::unique_lock<std::mutex> lock(accept_socket_mutex_);
+ if (!accept_socket_mappings_.erase(socket)) {
return;
}
}
@@ -66,74 +81,38 @@ void TlsDataRouterPosix::OnSocketDestroyed(StreamSocketPosix* socket,
skip_locking_for_testing);
{
- std::unique_lock<std::mutex> lock(socket_mutex_);
+ std::unique_lock<std::mutex> lock(accept_socket_mutex_);
auto it = std::find_if(
- watched_stream_sockets_.begin(), watched_stream_sockets_.end(),
+ accept_stream_sockets_.begin(), accept_stream_sockets_.end(),
[socket](const std::unique_ptr<StreamSocketPosix>& ptr) {
return ptr.get() == socket;
});
- OSP_DCHECK(it != watched_stream_sockets_.end());
- watched_stream_sockets_.erase(it);
+ OSP_DCHECK(it != accept_stream_sockets_.end());
+ accept_stream_sockets_.erase(it);
}
}
void TlsDataRouterPosix::ProcessReadyHandle(
SocketHandleWaiter::SocketHandleRef handle) {
- std::unique_lock<std::mutex> lock(socket_mutex_);
- for (const auto& pair : socket_mappings_) {
- if (pair.first->socket_handle() == handle) {
- pair.second->OnConnectionPending(pair.first);
- break;
+ {
+ std::unique_lock<std::mutex> lock(accept_socket_mutex_);
+ for (const auto& pair : accept_socket_mappings_) {
+ if (pair.first->socket_handle() == handle) {
+ pair.second->OnConnectionPending(pair.first);
+ return;
+ }
}
}
-}
-
-void TlsDataRouterPosix::PerformNetworkingOperations(Clock::duration timeout) {
- Clock::time_point start_time = now_function_();
-
- // TODO(rwkeane): Minimize time locked based on how RegisterConnection and
- // DeregisterConnection are implimented.
- std::lock_guard<std::mutex> lock(connections_mutex_);
- if (connections_.empty()) {
- return;
- }
-
- NetworkingOperation current_operation = last_operation_;
- std::vector<TlsConnectionPosix*>::iterator current_connection =
- GetConnection(last_connection_processed_);
- last_connection_processed_ = *current_connection;
- do {
- // Get the next (connection, mode) pair to use for processing. This allows
- // for processing in a round-robin fashion, such that this call to
- // PerformNetworkingOperations() will pick up where the previous call left
- // off.
- current_operation = GetNextMode(current_operation);
- if (current_operation == NetworkingOperation::kReading) {
- current_connection++;
- if (current_connection == connections_.end()) {
- current_connection = connections_.begin();
+ {
+ std::lock_guard<std::mutex> lock(connections_mutex_);
+ for (TlsConnectionPosix* connection : connections_) {
+ if (connection->socket_handle() == handle) {
+ connection->TryReceiveMessage();
+ connection->SendAvailableBytes();
+ return;
}
}
-
- // Process the (connection, mode).
- switch (current_operation) {
- case NetworkingOperation::kReading:
- (*current_connection)->TryReceiveMessage();
- break;
- case NetworkingOperation::kWriting:
- (*current_connection)->SendAvailableBytes();
- break;
- }
-
- // If this (connection, mode) is where we started, exit.
- if (last_connection_processed_ == *current_connection &&
- last_operation_ == current_operation) {
- break;
- }
- } while (!HasTimedOut(start_time, timeout));
-
- last_connection_processed_ = *current_connection;
- last_operation_ = current_operation;
+ }
}
bool TlsDataRouterPosix::HasTimedOut(Clock::time_point start_time,
@@ -142,34 +121,16 @@ bool TlsDataRouterPosix::HasTimedOut(Clock::time_point start_time,
}
void TlsDataRouterPosix::RemoveWatchedSocket(StreamSocketPosix* socket) {
- std::unique_lock<std::mutex> lock(socket_mutex_);
- const auto it = socket_mappings_.find(socket);
- if (it != socket_mappings_.end()) {
- socket_mappings_.erase(it);
+ std::unique_lock<std::mutex> lock(accept_socket_mutex_);
+ const auto it = accept_socket_mappings_.find(socket);
+ if (it != accept_socket_mappings_.end()) {
+ accept_socket_mappings_.erase(it);
}
}
bool TlsDataRouterPosix::IsSocketWatched(StreamSocketPosix* socket) const {
- std::unique_lock<std::mutex> lock(socket_mutex_);
- return socket_mappings_.find(socket) != socket_mappings_.end();
-}
-
-TlsDataRouterPosix::NetworkingOperation TlsDataRouterPosix::GetNextMode(
- NetworkingOperation state) {
- return state == NetworkingOperation::kReading ? NetworkingOperation::kWriting
- : NetworkingOperation::kReading;
-}
-
-std::vector<TlsConnectionPosix*>::iterator TlsDataRouterPosix::GetConnection(
- TlsConnectionPosix* current) {
- auto current_pos =
- std::find(connections_.begin(), connections_.end(), current);
- if (current_pos == connections_.end()) {
- return connections_.begin();
- }
-
- return current_pos;
+ std::unique_lock<std::mutex> lock(accept_socket_mutex_);
+ return accept_socket_mappings_.find(socket) != accept_socket_mappings_.end();
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.h b/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.h
index fb8ac45c0d9..2549bd2e5e2 100644
--- a/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.h
+++ b/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix.h
@@ -14,7 +14,6 @@
#include "util/logging.h"
namespace openscreen {
-namespace platform {
class StreamSocketPosix;
class TlsConnectionPosix;
@@ -57,24 +56,19 @@ class TlsDataRouterPosix : public SocketHandleWaiter::Subscriber {
// Deregister a TlsConnection.
void DeregisterConnection(TlsConnectionPosix* connection);
- // Takes ownership of a StreamSocket and registers that should be watched for
- // incoming Tcp Connections with the SocketHandleWaiter.
- void RegisterSocketObserver(std::unique_ptr<StreamSocketPosix> socket,
+ // Takes ownership of a StreamSocket and registers that it should be watched
+ // for incoming TCP connections with the SocketHandleWaiter.
+ void RegisterAcceptObserver(std::unique_ptr<StreamSocketPosix> socket,
SocketObserver* observer);
- // Stops watching a Tcp Connections for incoming connections.
+ // Stops watching a TCP socket for incoming connections.
// NOTE: This will destroy the StreamSocket.
- virtual void DeregisterSocketObserver(StreamSocketPosix* socket);
+ virtual void DeregisterAcceptObserver(StreamSocketPosix* socket);
// Method to be executed on TlsConnection destruction. This is expected to
// block until the networking thread is not using the provided connection.
void OnConnectionDestroyed(TlsConnectionPosix* connection);
- // Performs Read and Write operations for all connections or until the timeout
- // has been hit, whichever is first. In the latter case, the following
- // iteration will continue from wherever the previous iteration left off.
- void PerformNetworkingOperations(Clock::duration timeout);
-
// SocketHandleWaiter::Subscriber overrides.
void ProcessReadyHandle(SocketHandleWaiter::SocketHandleRef handle) override;
@@ -91,48 +85,35 @@ class TlsDataRouterPosix : public SocketHandleWaiter::Subscriber {
friend class TestingDataRouter;
private:
- enum class NetworkingOperation { kReading, kWriting };
-
void OnSocketDestroyed(StreamSocketPosix* socket,
bool skip_locking_for_testing);
void RemoveWatchedSocket(StreamSocketPosix* socket);
- // Helper methods for PerformNetworkingOperations.
- NetworkingOperation GetNextMode(NetworkingOperation state);
- std::vector<TlsConnectionPosix*>::iterator GetConnection(
- TlsConnectionPosix* current);
-
SocketHandleWaiter* waiter_;
// Mutex guarding connections_ vector.
mutable std::mutex connections_mutex_;
- // Mutex guarding socket_mappings_.
- mutable std::mutex socket_mutex_;
+ // Mutex guarding |accept_socket_mappings_|.
+ mutable std::mutex accept_socket_mutex_;
// Function to get the current time.
std::function<Clock::time_point()> now_function_;
- // Information related to how much of PerformNetworkingOperations(...) was
- // completed before hitting the timeout.
- NetworkingOperation last_operation_ = NetworkingOperation::kReading;
- TlsConnectionPosix* last_connection_processed_ = nullptr;
-
// Mapping from all sockets to the observer that should be called when the
// socket recognizes an incoming connection.
- std::unordered_map<StreamSocketPosix*, SocketObserver*> socket_mappings_
- GUARDED_BY(socket_mutex_);
+ std::unordered_map<StreamSocketPosix*, SocketObserver*>
+ accept_socket_mappings_ GUARDED_BY(accept_socket_mutex_);
// Set of all TlsConnectionPosix objects currently registered.
std::vector<TlsConnectionPosix*> connections_ GUARDED_BY(connections_mutex_);
// StreamSockets currently owned by this object, being watched for
- std::vector<std::unique_ptr<StreamSocketPosix>> watched_stream_sockets_
- GUARDED_BY(connections_mutex_);
+ std::vector<std::unique_ptr<StreamSocketPosix>> accept_stream_sockets_
+ GUARDED_BY(accept_socket_mutex_);
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_TLS_DATA_ROUTER_POSIX_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix_unittest.cc
index 257ad08b3c8..f242900928e 100644
--- a/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix_unittest.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix_unittest.cc
@@ -4,6 +4,9 @@
#include "platform/impl/tls_data_router_posix.h"
+#include <memory>
+#include <utility>
+
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "platform/base/ip_address.h"
@@ -13,48 +16,51 @@
#include "platform/test/fake_task_runner.h"
namespace openscreen {
-namespace platform {
+namespace {
-using testing::_;
-using testing::Return;
+class MockNetworkWaiter final : public SocketHandleWaiter {
+ public:
+ MockNetworkWaiter() : SocketHandleWaiter(&FakeClock::now) {}
-class TestingDataRouter : public TlsDataRouterPosix {
+ MOCK_METHOD2(
+ AwaitSocketsReadable,
+ ErrorOr<std::vector<SocketHandleRef>>(const std::vector<SocketHandleRef>&,
+ const Clock::duration&));
+};
+
+class MockSocket : public StreamSocketPosix {
public:
- TestingDataRouter(SocketHandleWaiter* waiter) : TlsDataRouterPosix(waiter) {}
+ MockSocket(int fd) : StreamSocketPosix(IPAddress::Version::kV4), handle(fd) {}
- using TlsDataRouterPosix::IsSocketWatched;
- using TlsDataRouterPosix::NetworkingOperation;
+ const SocketHandle& socket_handle() const override { return handle; }
- void DeregisterSocketObserver(StreamSocketPosix* socket) override {
- TlsDataRouterPosix::OnSocketDestroyed(socket, true);
- }
+ SocketHandle handle;
+};
- bool AnySocketsWatched() {
- std::unique_lock<std::mutex> lock(socket_mutex_);
- return !watched_stream_sockets_.empty() && !socket_mappings_.empty();
- }
+class MockConnection : public TlsConnectionPosix {
+ public:
+ explicit MockConnection(int fd, TaskRunner* task_runner)
+ : TlsConnectionPosix(std::make_unique<MockSocket>(fd), task_runner) {}
+ MOCK_METHOD0(SendAvailableBytes, void());
+ MOCK_METHOD0(TryReceiveMessage, void());
+};
- MOCK_METHOD2(HasTimedOut, bool(Clock::time_point, Clock::duration));
+} // namespace
- void SetLastState(NetworkingOperation last_operation) {
- last_operation_ = last_operation;
- }
+class TestingDataRouter : public TlsDataRouterPosix {
+ public:
+ explicit TestingDataRouter(SocketHandleWaiter* waiter)
+ : TlsDataRouterPosix(waiter) {}
- void SetLastConnection(TlsConnectionPosix* last_connection) {
- last_connection_processed_ = last_connection;
- }
+ using TlsDataRouterPosix::IsSocketWatched;
- // TODO(rwkeane): Remove these methods once RegisterConnection and
- // DeregisterConnection are implimented in TlsDataRouterPosix.
- void AddConnectionForTesting(TlsConnectionPosix* connection) {
- std::lock_guard<std::mutex> lock(connections_mutex_);
- connections_.push_back(connection);
+ void DeregisterAcceptObserver(StreamSocketPosix* socket) override {
+ TlsDataRouterPosix::OnSocketDestroyed(socket, true);
}
- void RemoveConnectionForTesting(TlsConnectionPosix* connection) {
- std::lock_guard<std::mutex> lock(connections_mutex_);
- auto it = std::find(connections_.begin(), connections_.end(), connection);
- connections_.erase(it);
+ bool AnySocketsWatched() {
+ std::unique_lock<std::mutex> lock(accept_socket_mutex_);
+ return !accept_stream_sockets_.empty() && !accept_socket_mappings_.empty();
}
};
@@ -62,135 +68,66 @@ class MockObserver : public TestingDataRouter::SocketObserver {
MOCK_METHOD1(OnConnectionPending, void(StreamSocketPosix*));
};
-class MockNetworkWaiter final : public SocketHandleWaiter {
+class TlsNetworkingManagerPosixTest : public testing::Test {
public:
- MOCK_METHOD2(
- AwaitSocketsReadable,
- ErrorOr<std::vector<SocketHandleRef>>(const std::vector<SocketHandleRef>&,
- const Clock::duration&));
-};
+ TlsNetworkingManagerPosixTest()
+ : clock_(Clock::now()),
+ task_runner_(&clock_),
+ network_manager_(&network_waiter_) {}
-class MockConnection : public TlsConnectionPosix {
- public:
- MockConnection()
- : TlsConnectionPosix(IPAddress::Version::kV4, &task_runner),
- clock(Clock::now()),
- task_runner(&clock) {}
- MOCK_METHOD0(SendAvailableBytes, void());
- MOCK_METHOD0(TryReceiveMessage, void());
+ FakeTaskRunner* task_runner() { return &task_runner_; }
+ TestingDataRouter* network_manager() { return &network_manager_; }
private:
- FakeClock clock;
- FakeTaskRunner task_runner;
+ FakeClock clock_;
+ FakeTaskRunner task_runner_;
+ MockNetworkWaiter network_waiter_;
+ TestingDataRouter network_manager_;
};
-TEST(TlsNetworkingManagerPosixTest, SocketsWatchedCorrectly) {
- MockNetworkWaiter network_waiter;
- TestingDataRouter network_manager(&network_waiter);
+TEST_F(TlsNetworkingManagerPosixTest, SocketsWatchedCorrectly) {
auto socket = std::make_unique<StreamSocketPosix>(IPAddress::Version::kV4);
MockObserver observer;
auto* ptr = socket.get();
- ASSERT_FALSE(network_manager.IsSocketWatched(ptr));
+ ASSERT_FALSE(network_manager()->IsSocketWatched(ptr));
- network_manager.RegisterSocketObserver(std::move(socket), &observer);
- ASSERT_TRUE(network_manager.IsSocketWatched(ptr));
- ASSERT_TRUE(network_manager.AnySocketsWatched());
+ network_manager()->RegisterAcceptObserver(std::move(socket), &observer);
+ ASSERT_TRUE(network_manager()->IsSocketWatched(ptr));
+ ASSERT_TRUE(network_manager()->AnySocketsWatched());
- network_manager.DeregisterSocketObserver(ptr);
- ASSERT_FALSE(network_manager.IsSocketWatched(ptr));
- ASSERT_FALSE(network_manager.AnySocketsWatched());
+ network_manager()->DeregisterAcceptObserver(ptr);
+ ASSERT_FALSE(network_manager()->IsSocketWatched(ptr));
+ ASSERT_FALSE(network_manager()->AnySocketsWatched());
- network_manager.DeregisterSocketObserver(ptr);
- ASSERT_FALSE(network_manager.IsSocketWatched(ptr));
- ASSERT_FALSE(network_manager.AnySocketsWatched());
+ network_manager()->DeregisterAcceptObserver(ptr);
+ ASSERT_FALSE(network_manager()->IsSocketWatched(ptr));
+ ASSERT_FALSE(network_manager()->AnySocketsWatched());
}
-TEST(TlsNetworkingManagerPosixTest, ExitsAfterOneCall) {
- MockNetworkWaiter network_waiter;
- TestingDataRouter network_manager(&network_waiter);
- MockConnection connection1;
- MockConnection connection2;
- MockConnection connection3;
- network_manager.AddConnectionForTesting(&connection1);
- network_manager.AddConnectionForTesting(&connection2);
- network_manager.AddConnectionForTesting(&connection3);
-
- EXPECT_CALL(network_manager, HasTimedOut(_, _)).WillOnce(Return(true));
+TEST_F(TlsNetworkingManagerPosixTest, CallsReadySocket) {
+ MockConnection connection1(1, task_runner());
+ MockConnection connection2(2, task_runner());
+ MockConnection connection3(3, task_runner());
+ network_manager()->RegisterConnection(&connection1);
+ network_manager()->RegisterConnection(&connection2);
+ network_manager()->RegisterConnection(&connection3);
+
EXPECT_CALL(connection1, SendAvailableBytes()).Times(1);
- EXPECT_CALL(connection1, TryReceiveMessage()).Times(0);
+ EXPECT_CALL(connection1, TryReceiveMessage()).Times(1);
EXPECT_CALL(connection2, SendAvailableBytes()).Times(0);
EXPECT_CALL(connection2, TryReceiveMessage()).Times(0);
EXPECT_CALL(connection3, SendAvailableBytes()).Times(0);
EXPECT_CALL(connection3, TryReceiveMessage()).Times(0);
- network_manager.PerformNetworkingOperations(Clock::duration{0});
-}
+ network_manager()->ProcessReadyHandle(connection1.socket_handle());
-TEST(TlsNetworkingManagerPosixTest, StartsAfterPrevious) {
- MockNetworkWaiter network_waiter;
- TestingDataRouter network_manager(&network_waiter);
- MockConnection connection1;
- MockConnection connection2;
- MockConnection connection3;
- network_manager.AddConnectionForTesting(&connection1);
- network_manager.AddConnectionForTesting(&connection2);
- network_manager.AddConnectionForTesting(&connection3);
- network_manager.SetLastState(
- TestingDataRouter::NetworkingOperation::kReading);
- network_manager.SetLastConnection(&connection2);
-
- EXPECT_CALL(network_manager, HasTimedOut(_, _)).WillOnce(Return(true));
- EXPECT_CALL(connection1, TryReceiveMessage()).Times(0);
EXPECT_CALL(connection1, SendAvailableBytes()).Times(0);
- EXPECT_CALL(connection2, TryReceiveMessage()).Times(0);
- EXPECT_CALL(connection2, SendAvailableBytes()).Times(1);
- EXPECT_CALL(connection3, TryReceiveMessage()).Times(0);
- EXPECT_CALL(connection3, SendAvailableBytes()).Times(0);
- network_manager.PerformNetworkingOperations(Clock::duration{0});
-}
-
-TEST(TlsNetworkingManagerPosixTest, HitsAllCallsOnce) {
- MockNetworkWaiter network_waiter;
- TestingDataRouter network_manager(&network_waiter);
- MockConnection connection1;
- MockConnection connection2;
- MockConnection connection3;
- network_manager.AddConnectionForTesting(&connection1);
- network_manager.AddConnectionForTesting(&connection2);
- network_manager.AddConnectionForTesting(&connection3);
-
- EXPECT_CALL(network_manager, HasTimedOut(_, _)).WillRepeatedly(Return(false));
- EXPECT_CALL(connection1, TryReceiveMessage()).Times(1);
- EXPECT_CALL(connection1, SendAvailableBytes()).Times(1);
- EXPECT_CALL(connection2, TryReceiveMessage()).Times(1);
+ EXPECT_CALL(connection1, TryReceiveMessage()).Times(0);
EXPECT_CALL(connection2, SendAvailableBytes()).Times(1);
- EXPECT_CALL(connection3, TryReceiveMessage()).Times(1);
- EXPECT_CALL(connection3, SendAvailableBytes()).Times(1);
- network_manager.PerformNetworkingOperations(Clock::duration{0});
-}
-
-TEST(TlsNetworkingManagerPosixTest, HitsAllCallsOnceStartedInMiddle) {
- MockNetworkWaiter network_waiter;
- TestingDataRouter network_manager(&network_waiter);
- MockConnection connection1;
- MockConnection connection2;
- MockConnection connection3;
- network_manager.AddConnectionForTesting(&connection1);
- network_manager.AddConnectionForTesting(&connection2);
- network_manager.AddConnectionForTesting(&connection3);
- network_manager.SetLastState(
- TestingDataRouter::NetworkingOperation::kReading);
- network_manager.SetLastConnection(&connection2);
-
- EXPECT_CALL(network_manager, HasTimedOut(_, _)).WillRepeatedly(Return(false));
- EXPECT_CALL(connection1, TryReceiveMessage()).Times(1);
- EXPECT_CALL(connection1, SendAvailableBytes()).Times(1);
EXPECT_CALL(connection2, TryReceiveMessage()).Times(1);
- EXPECT_CALL(connection2, SendAvailableBytes()).Times(1);
- EXPECT_CALL(connection3, TryReceiveMessage()).Times(1);
- EXPECT_CALL(connection3, SendAvailableBytes()).Times(1);
- network_manager.PerformNetworkingOperations(Clock::duration{0});
+ EXPECT_CALL(connection3, SendAvailableBytes()).Times(0);
+ EXPECT_CALL(connection3, TryReceiveMessage()).Times(0);
+ network_manager()->ProcessReadyHandle(connection2.socket_handle());
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.cc b/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.cc
index f0370868906..b91891c36a9 100644
--- a/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.cc
@@ -11,16 +11,11 @@
#include "util/logging.h"
namespace openscreen {
-namespace platform {
+TlsWriteBuffer::TlsWriteBuffer() = default;
TlsWriteBuffer::~TlsWriteBuffer() = default;
-TlsWriteBuffer::TlsWriteBuffer(TlsWriteBuffer::Observer* observer)
- : observer_(observer) {
- OSP_DCHECK(observer_);
-}
-
-size_t TlsWriteBuffer::Write(const void* data, size_t len) {
+bool TlsWriteBuffer::Push(const void* data, size_t len) {
const size_t currently_written_bytes =
bytes_written_so_far_.load(std::memory_order_relaxed);
const size_t current_read_bytes =
@@ -29,34 +24,32 @@ size_t TlsWriteBuffer::Write(const void* data, size_t len) {
// Calculates the current size of the buffer.
const size_t bytes_currently_used =
currently_written_bytes - current_read_bytes;
+ OSP_DCHECK_LE(bytes_currently_used, kBufferSizeBytes);
+ if ((kBufferSizeBytes - bytes_currently_used) < len) {
+ return false;
+ }
- // Calculates how many bytes out of the requested |len| can be written without
- // causing a buffer overflow.
- const size_t write_len =
- std::min(kBufferSizeBytes - bytes_currently_used, len);
-
- // Calculates the number of bytes out of |write_len| to write in the first
- // memcpy operation, which is either all |write_len| or the number that can be
- // written before wrapping around to the beginning of the underlying array.
+ // Calculates the number of bytes out of |len| to write in the first memcpy
+ // operation, which is either all of |len| or the number that can be written
+ // before wrapping around to the beginning of the underlying array.
const size_t current_write_index = currently_written_bytes % kBufferSizeBytes;
const size_t first_write_len =
- std::min(write_len, kBufferSizeBytes - current_write_index);
+ std::min(len, kBufferSizeBytes - current_write_index);
memcpy(&buffer_[current_write_index], data, first_write_len);
- // If we didn't write all |write_len| bytes in the previous memcpy, copy any
+ // If fewer than |len| bytes were transferred in the previous memcpy, copy any
// remaining bytes to the array, starting at 0 (since the last write must have
// finished at the end of the array).
- if (first_write_len != write_len) {
+ if (first_write_len != len) {
const uint8_t* new_start =
static_cast<const uint8_t*>(data) + first_write_len;
- memcpy(buffer_, new_start, write_len - first_write_len);
+ memcpy(buffer_, new_start, len - first_write_len);
}
// Store and return updated values.
- const size_t new_write_size = currently_written_bytes + write_len;
+ const size_t new_write_size = currently_written_bytes + len;
bytes_written_so_far_.store(new_write_size, std::memory_order_release);
- NotifyWriteBufferFill(new_write_size, bytes_read_so_far_);
- return write_len;
+ return true;
}
absl::Span<const uint8_t> TlsWriteBuffer::GetReadableRegion() {
@@ -72,7 +65,7 @@ absl::Span<const uint8_t> TlsWriteBuffer::GetReadableRegion() {
// this additional level of complexity.
const size_t avail = currently_written_bytes - current_read_bytes;
const size_t begin = current_read_bytes % kBufferSizeBytes;
- const size_t end = std::min(begin + avail, kBufferSizeBytes - 1);
+ const size_t end = std::min(begin + avail, kBufferSizeBytes);
return absl::Span<const uint8_t>(&buffer_[begin], end - begin);
}
@@ -85,16 +78,9 @@ void TlsWriteBuffer::Consume(size_t byte_count) {
OSP_DCHECK_GE(currently_written_bytes - current_read_bytes, byte_count);
const size_t new_read_index = current_read_bytes + byte_count;
bytes_read_so_far_.store(new_read_index, std::memory_order_release);
-
- NotifyWriteBufferFill(currently_written_bytes, new_read_index);
}
-void TlsWriteBuffer::NotifyWriteBufferFill(size_t write_index,
- size_t read_index) {
- double fraction =
- static_cast<double>(write_index - read_index) / kBufferSizeBytes;
- observer_->NotifyWriteBufferFill(fraction);
-}
+// static
+constexpr size_t TlsWriteBuffer::kBufferSizeBytes;
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.h b/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.h
index 308b882dfc8..99ac45ac788 100644
--- a/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.h
+++ b/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer.h
@@ -12,31 +12,20 @@
#include "platform/base/macros.h"
namespace openscreen {
-namespace platform {
// This class is responsible for buffering TLS Write data. The approach taken by
// this class is to allow for a single thread to act as a publisher of data and
// for a separate thread to act as the consumer of that data. The data in
-// question is written to a lockless FIFO queue. Whenever the capacity of the
-// underlying array changes, Observer::NotifyWriteBufferFill() will be called.
+// question is written to a lockless FIFO queue.
class TlsWriteBuffer {
public:
- class Observer {
- public:
- virtual ~Observer() = default;
-
- // Signals that the write buffer has reached some percentage of being
- // filled. NOTE: This method may be called from multiple threads.
- virtual void NotifyWriteBufferFill(double fraction) = 0;
- };
-
- explicit TlsWriteBuffer(Observer* observer);
-
+ TlsWriteBuffer();
~TlsWriteBuffer();
- // Writes as much of the provided data as possible in the buffer, returning
- // the number of bytes written.
- size_t Write(const void* data, size_t len);
+ // Pushes the provided data into the buffer, returning true if successful.
+ // Returns false if there was insufficient space left. Either all or none of
+ // the data is pushed into the buffer.
+ bool Push(const void* data, size_t len);
// Returns a subset of the readable region of data. At time of reading, more
// data may be available for reading than what is represented in this Span.
@@ -46,13 +35,9 @@ class TlsWriteBuffer {
void Consume(size_t byte_count);
// The amount of space to allocate in the buffer.
- static constexpr size_t kBufferSizeBytes = 1 << 20; // 1 MB space.
+ static constexpr size_t kBufferSizeBytes = 1 << 19; // 0.5 MB.
private:
- // Signals that the write buffer has reached some percentage of being filled,
- // as calculated based on the provided write and read indices.
- void NotifyWriteBufferFill(size_t write_index, size_t read_index);
-
// Buffer where data to be written over the TLS connection is stored.
uint8_t buffer_[kBufferSizeBytes];
@@ -63,12 +48,9 @@ class TlsWriteBuffer {
std::atomic_size_t bytes_read_so_far_{0};
std::atomic_size_t bytes_written_so_far_{0};
- Observer* const observer_;
-
OSP_DISALLOW_COPY_AND_ASSIGN(TlsWriteBuffer);
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_TLS_WRITE_BUFFER_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer_unittest.cc
index 9e81543336e..57326869f72 100644
--- a/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer_unittest.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/tls_write_buffer_unittest.cc
@@ -10,75 +10,108 @@
#include "gtest/gtest.h"
namespace openscreen {
-namespace platform {
namespace {
-class MockObserver : public TlsWriteBuffer::Observer {
- public:
- MOCK_METHOD1(NotifyWriteBufferFill, void(double));
-};
-
-} // namespace
-
TEST(TlsWriteBufferTest, CheckBasicFunctionality) {
- MockObserver observer;
- TlsWriteBuffer buffer(&observer);
+ TlsWriteBuffer buffer;
constexpr size_t write_size = TlsWriteBuffer::kBufferSizeBytes / 2;
uint8_t write_buffer[write_size];
std::fill_n(write_buffer, write_size, uint8_t{1});
- EXPECT_CALL(observer, NotifyWriteBufferFill(0.5)).Times(1);
- EXPECT_EQ(buffer.Write(write_buffer, write_size), write_size);
+ EXPECT_TRUE(buffer.Push(write_buffer, write_size));
absl::Span<const uint8_t> readable_data = buffer.GetReadableRegion();
- EXPECT_EQ(readable_data.size(), write_size);
+ ASSERT_EQ(readable_data.size(), write_size);
for (size_t i = 0; i < readable_data.size(); i++) {
EXPECT_EQ(readable_data[i], 1);
}
- EXPECT_CALL(observer, NotifyWriteBufferFill(0.25)).Times(1);
buffer.Consume(write_size / 2);
readable_data = buffer.GetReadableRegion();
- EXPECT_EQ(readable_data.size(), write_size / 2);
+ ASSERT_EQ(readable_data.size(), write_size / 2);
for (size_t i = 0; i < readable_data.size(); i++) {
EXPECT_EQ(readable_data[i], 1);
}
- EXPECT_CALL(observer, NotifyWriteBufferFill(0)).Times(1);
buffer.Consume(write_size / 2);
readable_data = buffer.GetReadableRegion();
- EXPECT_EQ(readable_data.size(), size_t{0});
+ ASSERT_EQ(readable_data.size(), size_t{0});
+
+ // Test that the entire buffer can be used.
+ EXPECT_TRUE(buffer.Push(write_buffer, write_size));
+ EXPECT_TRUE(buffer.Push(write_buffer, write_size));
+ // The buffer should be 100% full at this point. Confirm that no more can be
+ // written.
+ EXPECT_FALSE(buffer.Push(write_buffer, write_size));
+ EXPECT_FALSE(buffer.Push(write_buffer, 1));
}
TEST(TlsWriteBufferTest, TestWrapAround) {
- MockObserver observer;
- TlsWriteBuffer buffer(&observer);
- constexpr size_t write_size = TlsWriteBuffer::kBufferSizeBytes;
- uint8_t write_buffer[write_size];
-
- EXPECT_CALL(observer, NotifyWriteBufferFill(0.75)).Times(1);
- constexpr size_t partial_write_size = write_size * 3 / 4;
- EXPECT_EQ(buffer.Write(write_buffer, partial_write_size), partial_write_size);
-
- EXPECT_CALL(observer, NotifyWriteBufferFill(0.25)).Times(1);
- buffer.Consume(write_size / 2);
-
- EXPECT_CALL(observer, NotifyWriteBufferFill(0.75)).Times(1);
- EXPECT_EQ(buffer.Write(write_buffer, write_size / 2), write_size / 2);
-
- absl::Span<const uint8_t> readable_data = buffer.GetReadableRegion();
-
- EXPECT_CALL(observer, NotifyWriteBufferFill(0.25)).Times(1);
- buffer.Consume(write_size / 2);
-
- readable_data = buffer.GetReadableRegion();
- EXPECT_EQ(readable_data.size(), write_size / 4);
-
- EXPECT_CALL(observer, NotifyWriteBufferFill(0)).Times(1);
- buffer.Consume(write_size / 4);
+ TlsWriteBuffer buffer;
+ constexpr size_t buffer_size = TlsWriteBuffer::kBufferSizeBytes;
+ uint8_t write_buffer[buffer_size];
+ std::fill_n(write_buffer, buffer_size, uint8_t{1});
+
+ constexpr size_t partial_buffer_size = buffer_size * 3 / 4;
+ EXPECT_TRUE(buffer.Push(write_buffer, partial_buffer_size));
+ // Buffer contents should now be: |111111111111····|
+ auto region = buffer.GetReadableRegion();
+ auto* const buffer_begin = region.data();
+ ASSERT_TRUE(buffer_begin);
+ EXPECT_EQ(region.size(), partial_buffer_size);
+ EXPECT_TRUE(std::all_of(region.begin(), region.end(),
+ [](uint8_t byte) { return byte == 1; }));
+
+ buffer.Consume(buffer_size / 2);
+ // Buffer contents should now be: |········1111····|
+ region = buffer.GetReadableRegion();
+ EXPECT_EQ(region.data(), buffer_begin + buffer_size / 2);
+ EXPECT_EQ(region.size(), buffer_size / 4);
+ EXPECT_TRUE(std::all_of(region.begin(), region.end(),
+ [](uint8_t byte) { return byte == 1; }));
+
+ std::fill_n(write_buffer, buffer_size, uint8_t{2});
+ EXPECT_TRUE(buffer.Push(write_buffer, buffer_size / 2));
+ // Buffer contents should now be: |2222····11112222|
+ // Readable region should just be the end part.
+ region = buffer.GetReadableRegion();
+ EXPECT_EQ(region.data(), buffer_begin + buffer_size / 2);
+ EXPECT_EQ(region.size(), buffer_size / 2);
+ EXPECT_TRUE(std::all_of(region.begin(), region.begin() + buffer_size / 4,
+ [](uint8_t byte) { return byte == 1; }));
+ EXPECT_TRUE(std::all_of(region.begin() + buffer_size / 4, region.end(),
+ [](uint8_t byte) { return byte == 2; }));
+
+ buffer.Consume(buffer_size / 2);
+ // Buffer contents should now be: |2222············|
+ region = buffer.GetReadableRegion();
+ EXPECT_EQ(region.data(), buffer_begin);
+ EXPECT_EQ(region.size(), buffer_size / 4);
+ EXPECT_TRUE(std::all_of(region.begin(), region.end(),
+ [](uint8_t byte) { return byte == 2; }));
+
+ std::fill_n(write_buffer, buffer_size, uint8_t{3});
+ // The following Push() fails (not enough room).
+ EXPECT_FALSE(buffer.Push(write_buffer, buffer_size));
+ // Buffer contents should still be: |2222············|
+ EXPECT_TRUE(buffer.Push(write_buffer, buffer_size * 3 / 4));
+ // Buffer contents should now be: |2222333333333333|
+ EXPECT_FALSE(buffer.Push(write_buffer, buffer_size)); // Not enough room.
+ EXPECT_FALSE(buffer.Push(write_buffer, 1)); // Not enough room.
+ region = buffer.GetReadableRegion();
+ EXPECT_EQ(region.data(), buffer_begin);
+ EXPECT_EQ(region.size(), buffer_size);
+ EXPECT_TRUE(std::all_of(region.begin(), region.begin() + buffer_size / 4,
+ [](uint8_t byte) { return byte == 2; }));
+ EXPECT_TRUE(std::all_of(region.begin() + buffer_size / 4, region.end(),
+ [](uint8_t byte) { return byte == 3; }));
+
+ buffer.Consume(buffer_size);
+ // Buffer contents should now be: |················|
+ EXPECT_TRUE(buffer.GetReadableRegion().empty());
}
-} // namespace platform
+} // namespace
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.cc b/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.cc
index 927ab3158e6..1be7c79aa08 100644
--- a/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.cc
@@ -18,6 +18,7 @@
#include <sstream>
#include <string>
#include <type_traits>
+#include <utility>
#include "absl/types/optional.h"
#include "platform/api/task_runner.h"
@@ -26,7 +27,6 @@
#include "util/logging.h"
namespace openscreen {
-namespace platform {
namespace {
constexpr bool IsPowerOf2(uint32_t x) {
@@ -84,9 +84,10 @@ const SocketHandle& UdpSocketPosix::GetHandle() const {
}
// static
-ErrorOr<UdpSocketUniquePtr> UdpSocket::Create(TaskRunner* task_runner,
- Client* client,
- const IPEndpoint& endpoint) {
+ErrorOr<std::unique_ptr<UdpSocket>> UdpSocket::Create(
+ TaskRunner* task_runner,
+ Client* client,
+ const IPEndpoint& endpoint) {
static std::atomic_bool in_create{false};
const bool in_create_local = in_create.exchange(true);
OSP_DCHECK_EQ(in_create_local, false)
@@ -112,8 +113,8 @@ ErrorOr<UdpSocketUniquePtr> UdpSocket::Create(TaskRunner* task_runner,
return fd.error();
}
- auto socket = UdpSocketUniquePtr(static_cast<UdpSocket*>(new UdpSocketPosix(
- task_runner, client, SocketHandle(fd.value()), endpoint)));
+ std::unique_ptr<UdpSocket> socket = std::make_unique<UdpSocketPosix>(
+ task_runner, client, SocketHandle(fd.value()), endpoint);
in_create = false;
return socket;
}
@@ -346,11 +347,11 @@ uint16_t GetPortFromFromSockAddr(const sockaddr_in& sa) {
}
IPAddress GetIPAddressFromSockAddr(const sockaddr_in6& sa) {
- return IPAddress(sa.sin6_addr.s6_addr);
+ return IPAddress(IPAddress::Version::kV6, sa.sin6_addr.s6_addr);
}
IPAddress GetIPAddressFromPktInfo(const in6_pktinfo& pktinfo) {
- return IPAddress(pktinfo.ipi6_addr.s6_addr);
+ return IPAddress(IPAddress::Version::kV6, pktinfo.ipi6_addr.s6_addr);
}
uint16_t GetPortFromFromSockAddr(const sockaddr_in6& sa) {
@@ -618,5 +619,4 @@ void UdpSocketPosix::Close() {
handle_.fd = -1;
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.h b/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.h
index ba44d61ab21..6dc7ae538ce 100644
--- a/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.h
+++ b/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.h
@@ -10,10 +10,9 @@
#include "platform/base/macros.h"
#include "platform/impl/platform_client_posix.h"
#include "platform/impl/socket_handle_posix.h"
-#include "platform/impl/weak_ptr.h"
+#include "util/weak_ptr.h"
namespace openscreen {
-namespace platform {
class UdpSocketReaderPosix;
@@ -87,7 +86,6 @@ class UdpSocketPosix : public UdpSocket {
OSP_DISALLOW_COPY_AND_ASSIGN(UdpSocketPosix);
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_UDP_SOCKET_POSIX_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.cc b/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.cc
index 3892081b5bd..4667df883c9 100644
--- a/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.cc
@@ -12,7 +12,6 @@
#include "util/logging.h"
namespace openscreen {
-namespace platform {
UdpSocketReaderPosix::UdpSocketReaderPosix(SocketHandleWaiter* waiter)
: waiter_(waiter) {}
@@ -34,9 +33,11 @@ void UdpSocketReaderPosix::ProcessReadyHandle(SocketHandleRef handle) {
}
void UdpSocketReaderPosix::OnCreate(UdpSocket* socket) {
- std::lock_guard<std::mutex> lock(mutex_);
UdpSocketPosix* read_socket = static_cast<UdpSocketPosix*>(socket);
- sockets_.push_back(read_socket);
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ sockets_.push_back(read_socket);
+ }
waiter_->Subscribe(this, std::cref(read_socket->GetHandle()));
}
@@ -64,5 +65,4 @@ bool UdpSocketReaderPosix::IsMappedReadForTesting(
return std::find(sockets_.begin(), sockets_.end(), socket) != sockets_.end();
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.h b/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.h
index 5b1f615b6fb..d9e065d9526 100644
--- a/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.h
+++ b/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.h
@@ -16,7 +16,6 @@
#include "platform/impl/udp_socket_posix.h"
namespace openscreen {
-namespace platform {
// This is the class responsible for watching sockets for readable data, then
// calling the function associated with these sockets once that data is read.
@@ -72,7 +71,6 @@ class UdpSocketReaderPosix : public SocketHandleWaiter::Subscriber {
friend class TestingUdpSocketReader;
};
-} // namespace platform
} // namespace openscreen
#endif // PLATFORM_IMPL_UDP_SOCKET_READER_POSIX_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix_unittest.cc
index a09fe29edf1..7faec354b7c 100644
--- a/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix_unittest.cc
+++ b/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix_unittest.cc
@@ -14,7 +14,7 @@
#include "platform/test/fake_udp_socket.h"
namespace openscreen {
-namespace platform {
+namespace {
using namespace ::testing;
using ::testing::_;
@@ -48,10 +48,15 @@ class MockUdpSocketPosix : public UdpSocketPosix {
// Mock event waiter
class MockNetworkWaiter final : public SocketHandleWaiter {
public:
+ MockNetworkWaiter() : SocketHandleWaiter(&FakeClock::now) {}
+ ~MockNetworkWaiter() override = default;
+
MOCK_METHOD2(
AwaitSocketsReadable,
ErrorOr<std::vector<SocketHandleRef>>(const std::vector<SocketHandleRef>&,
const Clock::duration&));
+
+ FakeClock fake_clock{Clock::time_point{Clock::duration{1234567}}};
};
// Mock Task Runner
@@ -78,6 +83,8 @@ class MockTaskRunner final : public TaskRunner {
uint32_t delayed_tasks_posted;
};
+} // namespace
+
// Class extending NetworkWaiter to allow for looking at protected data.
class TestingUdpSocketReader final : public UdpSocketReaderPosix {
public:
@@ -129,5 +136,4 @@ TEST(UdpSocketReaderTest, UnwatchReadableSucceeds) {
EXPECT_FALSE(network_waiter.IsMappedRead(socket.get()));
}
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/testing/libfuzzer/BUILD.gn b/chromium/third_party/openscreen/src/testing/libfuzzer/BUILD.gn
new file mode 100644
index 00000000000..9c28a2a86f9
--- /dev/null
+++ b/chromium/third_party/openscreen/src/testing/libfuzzer/BUILD.gn
@@ -0,0 +1,29 @@
+# Copyright 2019 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+# LibFuzzer is a LLVM tool for coverage-guided fuzz testing.
+# See http://www.chromium.org/developers/testing/libfuzzer
+
+import("//build_overrides/build.gni")
+
+source_set("fuzzing_engine_main") {
+ deps = [
+ "//third_party/libfuzzer",
+ ]
+ sources = []
+}
+
+# A config used by all fuzzer_tests.
+config("fuzzer_test_config") {
+ if (is_mac) {
+ ldflags = [
+ "-Wl,-U,_LLVMFuzzerCustomMutator",
+ "-Wl,-U,_LLVMFuzzerInitialize",
+ ]
+ }
+}
+
+# noop to tag seed corpus rules.
+source_set("seed_corpus") {
+}
diff --git a/chromium/third_party/openscreen/src/testing/libfuzzer/archive_corpus.py b/chromium/third_party/openscreen/src/testing/libfuzzer/archive_corpus.py
new file mode 100755
index 00000000000..e80848ddfa2
--- /dev/null
+++ b/chromium/third_party/openscreen/src/testing/libfuzzer/archive_corpus.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python2
+#
+# Copyright 2019 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Archive corpus file into zip and generate .d depfile.
+
+Invoked by GN from fuzzer_test.gni.
+"""
+
+from __future__ import print_function
+import argparse
+import os
+import sys
+import warnings
+import zipfile
+
+SEED_CORPUS_LIMIT_MB = 100
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Generate fuzzer config.")
+ parser.add_argument('corpus_directories', metavar='corpus_dir', type=str,
+ nargs='+')
+ parser.add_argument('--output', metavar='output_archive_name.zip',
+ required=True)
+
+ args = parser.parse_args()
+ corpus_files = []
+ seed_corpus_path = args.output
+
+ for directory in args.corpus_directories:
+ if not os.path.exists(directory):
+ raise Exception('The given seed_corpus directory (%s) does not exist.' %
+ directory)
+ for (dirpath, _, filenames) in os.walk(directory):
+ for filename in filenames:
+ full_filename = os.path.join(dirpath, filename)
+ corpus_files.append(full_filename)
+
+ with zipfile.ZipFile(seed_corpus_path, 'w') as z:
+ # Turn warnings into errors to interrupt the build: crbug.com/653920.
+ with warnings.catch_warnings():
+ warnings.simplefilter("error")
+ for i, corpus_file in enumerate(corpus_files):
+ # To avoid duplication of filenames inside the archive, use numbers.
+ arcname = '%016d' % i
+ z.write(corpus_file, arcname)
+
+ if os.path.getsize(seed_corpus_path) > SEED_CORPUS_LIMIT_MB * 1024 * 1024:
+ print('Seed corpus %s exceeds maximum allowed size (%d MB).' %
+ (seed_corpus_path, SEED_CORPUS_LIMIT_MB))
+ sys.exit(-1)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/chromium/third_party/openscreen/src/testing/libfuzzer/fuzzer_test.gni b/chromium/third_party/openscreen/src/testing/libfuzzer/fuzzer_test.gni
new file mode 100644
index 00000000000..4de38ac46e7
--- /dev/null
+++ b/chromium/third_party/openscreen/src/testing/libfuzzer/fuzzer_test.gni
@@ -0,0 +1,193 @@
+# Copyright 2015 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+# Defines fuzzer_test.
+#
+import("//build_overrides/build.gni")
+
+# fuzzer_test is used to define individual libfuzzer tests.
+#
+# Supported attributes:
+# - (required) sources - fuzzer test source files
+# - deps - test dependencies
+# - libs - Additional libraries to link.
+# - additional_configs - additional configs to be used for compilation
+# - dict - a dictionary file for the fuzzer.
+# - environment_variables - certain whitelisted environment variables for the
+# fuzzer (AFL_DRIVER_DONT_DEFER is the only one allowed currently).
+# - libfuzzer_options - options for the fuzzer (e.g. -close_fd_mask=N).
+# - asan_options - AddressSanitizer options (e.g. allow_user_segv_handler=1).
+# - msan_options - MemorySanitizer options.
+# - ubsan_options - UndefinedBehaviorSanitizer options.
+# - seed_corpus - a directory with seed corpus.
+# - seed_corpus_deps - dependencies for generating the seed corpus.
+#
+# If use_libfuzzer gn flag is defined, then proper fuzzer would be build.
+# Without use_libfuzzer or use_afl a unit-test style binary would be built on
+# linux and the whole target is a no-op otherwise.
+#
+# The template wraps test() target with appropriate dependencies.
+# If any test run-time options are present (dict or libfuzzer_options), then a
+# config (.options file) file would be generated or modified in root output
+# dir (next to test).
+template("openscreen_fuzzer_test") {
+ if (is_clang && !build_with_chromium) {
+ assert(defined(invoker.sources), "Need sources in $target_name.")
+
+ test_deps = [ "//testing/libfuzzer:fuzzing_engine_main" ]
+ test_data_deps = []
+
+ if (defined(invoker.deps)) {
+ test_deps += invoker.deps
+ }
+ if (defined(invoker.data_deps)) {
+ test_data_deps += invoker.data_deps
+ }
+
+ if (defined(invoker.seed_corpus) || defined(invoker.seed_corpuses)) {
+ assert(!(defined(invoker.seed_corpus) && defined(invoker.seed_corpuses)),
+ "Do not use both seed_corpus and seed_corpuses for $target_name.")
+
+ out = "$root_build_dir/$target_name" + "_seed_corpus.zip"
+
+ seed_corpus_deps = []
+
+ if (defined(invoker.seed_corpus_deps)) {
+ seed_corpus_deps += invoker.seed_corpus_deps
+ }
+
+ action(target_name + "_seed_corpus") {
+ script = "//testing/libfuzzer/archive_corpus.py"
+
+ args = [
+ "--output",
+ rebase_path(out, root_build_dir),
+ ]
+
+ if (defined(invoker.seed_corpus)) {
+ args += [ rebase_path(invoker.seed_corpus, root_build_dir) ]
+ }
+
+ if (defined(invoker.seed_corpuses)) {
+ foreach(seed_corpus_path, invoker.seed_corpuses) {
+ args += [ rebase_path(seed_corpus_path, root_build_dir) ]
+ }
+ }
+
+ outputs = [
+ out,
+ ]
+
+ deps = [ "//testing/libfuzzer:seed_corpus" ] + seed_corpus_deps
+ }
+
+ test_deps += [ ":" + target_name + "_seed_corpus" ]
+ }
+
+ if (defined(invoker.dict) || defined(invoker.libfuzzer_options) ||
+ defined(invoker.asan_options) || defined(invoker.msan_options) ||
+ defined(invoker.ubsan_options) ||
+ defined(invoker.environment_variables)) {
+ if (defined(invoker.dict)) {
+ # Copy dictionary to output.
+ copy(target_name + "_dict_copy") {
+ sources = [
+ invoker.dict,
+ ]
+ outputs = [
+ "$root_build_dir/" + target_name + ".dict",
+ ]
+ }
+ test_deps += [ ":" + target_name + "_dict_copy" ]
+ }
+
+ # Generate .options file.
+ config_file_name = target_name + ".options"
+ action(config_file_name) {
+ script = "//testing/libfuzzer/gen_fuzzer_config.py"
+ args = [
+ "--config",
+ rebase_path("$root_build_dir/" + config_file_name, root_build_dir),
+ ]
+
+ if (defined(invoker.dict)) {
+ args += [
+ "--dict",
+ rebase_path("$root_build_dir/" + invoker.target_name + ".dict",
+ root_build_dir),
+ ]
+ }
+
+ if (defined(invoker.libfuzzer_options)) {
+ args += [ "--libfuzzer_options" ]
+ args += invoker.libfuzzer_options
+ }
+
+ if (defined(invoker.asan_options)) {
+ args += [ "--asan_options" ]
+ args += invoker.asan_options
+ }
+
+ if (defined(invoker.msan_options)) {
+ args += [ "--msan_options" ]
+ args += invoker.msan_options
+ }
+
+ if (defined(invoker.ubsan_options)) {
+ args += [ "--ubsan_options" ]
+ args += invoker.ubsan_options
+ }
+
+ if (defined(invoker.environment_variables)) {
+ args += [ "--environment_variables" ]
+ args += invoker.environment_variables
+ }
+
+ outputs = [
+ "$root_build_dir/$config_file_name",
+ ]
+ }
+ test_deps += [ ":" + config_file_name ]
+ }
+
+ executable(target_name) {
+ forward_variables_from(invoker,
+ [
+ "cflags",
+ "cflags_cc",
+ "check_includes",
+ "defines",
+ "include_dirs",
+ "sources",
+ "libs",
+ ])
+ deps = test_deps
+ data_deps = test_data_deps
+
+ if (defined(invoker.additional_configs)) {
+ configs += invoker.additional_configs
+ }
+ configs += [ "//testing/libfuzzer:fuzzer_test_config" ]
+
+ if (defined(invoker.suppressed_configs)) {
+ configs -= invoker.suppressed_configs
+ }
+
+ if (defined(invoker.generated_sources)) {
+ sources += invoker.generated_sources
+ }
+
+ if (is_mac) {
+ sources += [ "//testing/libfuzzer/libfuzzer_exports.h" ]
+ }
+ }
+ } else {
+ # noop on unsupported platforms.
+ # mark attributes as used.
+ not_needed(invoker, "*")
+
+ group(target_name) {
+ }
+ }
+}
diff --git a/chromium/third_party/openscreen/src/testing/libfuzzer/gen_fuzzer_config.py b/chromium/third_party/openscreen/src/testing/libfuzzer/gen_fuzzer_config.py
new file mode 100755
index 00000000000..bde9e146ed2
--- /dev/null
+++ b/chromium/third_party/openscreen/src/testing/libfuzzer/gen_fuzzer_config.py
@@ -0,0 +1,86 @@
+#!/usr/bin/python2
+#
+# Copyright (c) 2015 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+"""Generate or update an existing config (.options file) for libfuzzer test.
+
+Invoked by GN from fuzzer_test.gni.
+"""
+
+import ConfigParser
+import argparse
+import os
+import sys
+
+
+def AddSectionOptions(config, section_name, options):
+ """Add |options| to the |section_name| section of |config|.
+
+ Throws an
+ assertion error if any option in |options| does not have exactly two
+ elements.
+ """
+ if not options:
+ return
+
+ config.add_section(section_name)
+ for option_and_value in options:
+ assert len(option_and_value) == 2, (
+ '%s is not an option, value pair' % option_and_value)
+
+ config.set(section_name, *option_and_value)
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Generate fuzzer config.')
+ parser.add_argument('--config', required=True)
+ parser.add_argument('--dict')
+ parser.add_argument('--libfuzzer_options', nargs='+', default=[])
+ parser.add_argument('--asan_options', nargs='+', default=[])
+ parser.add_argument('--msan_options', nargs='+', default=[])
+ parser.add_argument('--ubsan_options', nargs='+', default=[])
+ parser.add_argument(
+ '--environment_variables',
+ nargs='+',
+ default=[],
+ choices=['AFL_DRIVER_DONT_DEFER=1'])
+ args = parser.parse_args()
+
+ # Script shouldn't be invoked without any arguments, but just in case.
+ if not (args.dict or args.libfuzzer_options or args.environment_variables or
+ args.asan_options or args.msan_options or args.ubsan_options):
+ return
+
+ config = ConfigParser.ConfigParser()
+ libfuzzer_options = []
+ if args.dict:
+ libfuzzer_options.append(('dict', os.path.basename(args.dict)))
+ libfuzzer_options.extend(
+ option.split('=') for option in args.libfuzzer_options)
+
+ AddSectionOptions(config, 'libfuzzer', libfuzzer_options)
+
+ AddSectionOptions(config, 'asan',
+ [option.split('=') for option in args.asan_options])
+
+ AddSectionOptions(config, 'msan',
+ [option.split('=') for option in args.msan_options])
+
+ AddSectionOptions(config, 'ubsan',
+ [option.split('=') for option in args.ubsan_options])
+
+ AddSectionOptions(
+ config, 'env',
+ [option.split('=') for option in args.environment_variables])
+
+ # Generate .options file.
+ config_path = args.config
+ with open(config_path, 'w') as options_file:
+ options_file.write(
+ '# This is an automatically generated config for ClusterFuzz.\n')
+ config.write(options_file)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/chromium/third_party/openscreen/src/testing/util/BUILD.gn b/chromium/third_party/openscreen/src/testing/util/BUILD.gn
new file mode 100644
index 00000000000..b96a6b4955d
--- /dev/null
+++ b/chromium/third_party/openscreen/src/testing/util/BUILD.gn
@@ -0,0 +1,17 @@
+# Copyright 2019 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+source_set("util") {
+ testonly = true
+ sources = [
+ "read_file.cc",
+ "read_file.h",
+ ]
+
+ public_deps = [
+ "../../third_party/abseil",
+ ]
+
+ public_configs = [ "../../build:openscreen_include_dirs" ]
+}
diff --git a/chromium/third_party/openscreen/src/testing/util/read_file.cc b/chromium/third_party/openscreen/src/testing/util/read_file.cc
new file mode 100644
index 00000000000..be0df7c80da
--- /dev/null
+++ b/chromium/third_party/openscreen/src/testing/util/read_file.cc
@@ -0,0 +1,34 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "testing/util/read_file.h"
+
+#include <stdio.h>
+
+namespace openscreen {
+
+std::string ReadEntireFileToString(absl::string_view filename) {
+ FILE* file = fopen(filename.data(), "r");
+ if (file == nullptr) {
+ return {};
+ }
+ fseek(file, 0, SEEK_END);
+ long file_size = ftell(file);
+ fseek(file, 0, SEEK_SET);
+ std::string contents(file_size, 0);
+ int bytes_read = 0;
+ while (bytes_read < file_size) {
+ size_t ret = fread(&contents[bytes_read], 1, file_size - bytes_read, file);
+ if (ret == 0 && ferror(file)) {
+ return {};
+ } else {
+ bytes_read += ret;
+ }
+ }
+ fclose(file);
+
+ return contents;
+}
+
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/testing/util/read_file.h b/chromium/third_party/openscreen/src/testing/util/read_file.h
new file mode 100644
index 00000000000..7203518b47d
--- /dev/null
+++ b/chromium/third_party/openscreen/src/testing/util/read_file.h
@@ -0,0 +1,18 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef TESTING_UTIL_READ_FILE_H_
+#define TESTING_UTIL_READ_FILE_H_
+
+#include <string>
+
+#include "absl/strings/string_view.h"
+
+namespace openscreen {
+
+std::string ReadEntireFileToString(absl::string_view filename);
+
+} // namespace openscreen
+
+#endif // TESTING_UTIL_READ_FILE_H_
diff --git a/chromium/third_party/openscreen/src/third_party/abseil/BUILD.gn b/chromium/third_party/openscreen/src/third_party/abseil/BUILD.gn
index 4c512c8b8f3..1c2ea384e0e 100644
--- a/chromium/third_party/openscreen/src/third_party/abseil/BUILD.gn
+++ b/chromium/third_party/openscreen/src/third_party/abseil/BUILD.gn
@@ -72,6 +72,8 @@ if (build_with_chromium) {
"src/absl/strings/str_cat.cc",
"src/absl/strings/str_cat.h",
"src/absl/strings/str_join.h",
+ "src/absl/strings/str_replace.cc",
+ "src/absl/strings/str_replace.h",
"src/absl/strings/str_split.cc",
"src/absl/strings/str_split.h",
"src/absl/strings/string_view.cc",
diff --git a/chromium/third_party/openscreen/src/third_party/boringssl/BUILD.gn b/chromium/third_party/openscreen/src/third_party/boringssl/BUILD.gn
index 6bfebca586d..22e504fed83 100644
--- a/chromium/third_party/openscreen/src/third_party/boringssl/BUILD.gn
+++ b/chromium/third_party/openscreen/src/third_party/boringssl/BUILD.gn
@@ -32,18 +32,8 @@ if (build_with_chromium) {
"BORINGSSL_NO_STATIC_INITIALIZER",
"OPENSSL_SMALL",
]
- cflags = []
+ cflags = [ "-w" ] # Disable all warnings.
cflags_c = [ "-std=c99" ]
- cflags_cc = []
- if (is_clang) {
- cflags += [ "-Wno-extra-semi" ]
- cflags_cc += [ "-Wno-c++98-compat-extra-semi" ]
- }
-
- if (is_mac) {
- # Necessary since trybots have an old version of clang.
- cflags += [ "-Wno-unknown-warning-option" ]
- }
defines += [ "_XOPEN_SOURCE=700" ]
}
diff --git a/chromium/third_party/openscreen/src/third_party/chromium_quic/BUILD.gn b/chromium/third_party/openscreen/src/third_party/chromium_quic/BUILD.gn
index 2cf14c00389..e620c3a6a94 100644
--- a/chromium/third_party/openscreen/src/third_party/chromium_quic/BUILD.gn
+++ b/chromium/third_party/openscreen/src/third_party/chromium_quic/BUILD.gn
@@ -5,37 +5,7 @@
import("//third_party/protobuf/proto_library.gni")
config("chromium_quic_config") {
- cflags_cc = [
- "-Wno-error=attributes",
- "-Wno-unused-const-variable",
- ]
-
- if (is_clang) {
- cflags_cc += [
- "-Wno-defaulted-function-deleted",
- "-Wno-c++98-compat-extra-semi",
- "-Wno-extra-semi",
- ]
-
- # The clang version on the Mac OS X build bots is old, causing them to
- # not recognize the defaulted-function-deleted warning. This flag allows
- # us to build on both older and newer Mac OS X Clang toolchains.
- if (is_mac) {
- cflags_cc += [ "-Wno-unknown-warning-option" ]
- }
- }
-
- if (is_gcc) {
- cflags_cc += [
- "-Wno-dangling-else",
- "-Wno-return-type",
- "-Wno-unused-but-set-variable",
-
- # Don't warn about "maybe" uninitialized. Clang doesn't include this
- # in -Wall but gcc does, and it gives false positives.
- "-Wno-maybe-uninitialized",
- ]
- }
+ cflags = [ "-w" ] # Disable all warnings.
configs = [ "//third_party/protobuf:protobuf_config" ]
@@ -56,7 +26,7 @@ source_set("chromium_quic") {
]
}
-executable("quic_demo_client") {
+executable("quic_streaming_playback_controller") {
sources = [
"demo/client.cc",
"demo/delegates.cc",
diff --git a/chromium/third_party/openscreen/src/third_party/chromium_quic/build/base/BUILD.gn b/chromium/third_party/openscreen/src/third_party/chromium_quic/build/base/BUILD.gn
index aefbef77369..72d3975419e 100644
--- a/chromium/third_party/openscreen/src/third_party/chromium_quic/build/base/BUILD.gn
+++ b/chromium/third_party/openscreen/src/third_party/chromium_quic/build/base/BUILD.gn
@@ -547,7 +547,11 @@ source_set("base") {
"synchronization/synchronization_buildflags.h",
]
- if (is_posix || is_linux || is_mac) {
+ if (is_posix) {
+ if (target_cpu == "arm") {
+ cflags_cc = [ "-D_FILE_OFFSET_BITS=64" ]
+ }
+
sources += [
"../../src/base/base_paths_posix.h",
"../../src/base/debug/debugger_posix.cc",
@@ -617,7 +621,6 @@ source_set("base") {
"../../src/base/files/file_util_mac.mm",
"../../src/base/mac/authorization_util.h",
"../../src/base/mac/authorization_util.mm",
- "../../src/base/mac/availability.h",
"../../src/base/mac/bundle_locations.h",
"../../src/base/mac/bundle_locations.mm",
"../../src/base/mac/call_with_eh_frame.cc",
diff --git a/chromium/third_party/openscreen/src/third_party/chromium_quic/demo/client.cc b/chromium/third_party/openscreen/src/third_party/chromium_quic/demo/client.cc
index e57e97d34b2..d3eb2566708 100644
--- a/chromium/third_party/openscreen/src/third_party/chromium_quic/demo/client.cc
+++ b/chromium/third_party/openscreen/src/third_party/chromium_quic/demo/client.cc
@@ -35,7 +35,8 @@ int main(int argc, char** argv) {
if (argc < 2) {
dprintf(STDERR_FILENO,
- "Missing port number\nusage: demo_client <server-port>\n");
+ "Missing port number\nusage: streaming_playback_controller "
+ "<server-port>\n");
return 1;
}
int port = atoi(argv[1]);
diff --git a/chromium/third_party/openscreen/src/third_party/googletest/BUILD.gn b/chromium/third_party/openscreen/src/third_party/googletest/BUILD.gn
index f76f9e749a1..27006582301 100644
--- a/chromium/third_party/openscreen/src/third_party/googletest/BUILD.gn
+++ b/chromium/third_party/openscreen/src/third_party/googletest/BUILD.gn
@@ -77,6 +77,10 @@ if (build_with_chromium) {
":gtest_config",
]
+ public_deps = [
+ ":gtest",
+ ]
+
include_dirs = [ "src/googlemock" ]
}
diff --git a/chromium/third_party/openscreen/src/third_party/libfuzzer/BUILD.gn b/chromium/third_party/openscreen/src/third_party/libfuzzer/BUILD.gn
new file mode 100644
index 00000000000..a1a3db59394
--- /dev/null
+++ b/chromium/third_party/openscreen/src/third_party/libfuzzer/BUILD.gn
@@ -0,0 +1,44 @@
+# Copyright 2019 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+import("//build_overrides/build.gni")
+
+config("ignore_warnings") {
+ if (is_clang) {
+ cflags_cc = [
+ "-Wno-unused-result",
+ "-Wno-exit-time-destructors",
+ ]
+ }
+}
+
+source_set("libfuzzer") {
+ sources = [
+ "src/FuzzerCrossOver.cpp",
+ "src/FuzzerDataFlowTrace.cpp",
+ "src/FuzzerDriver.cpp",
+ "src/FuzzerExtFunctionsDlsym.cpp",
+ "src/FuzzerExtFunctionsWeak.cpp",
+ "src/FuzzerExtFunctionsWindows.cpp",
+ "src/FuzzerExtraCounters.cpp",
+ "src/FuzzerFork.cpp",
+ "src/FuzzerIO.cpp",
+ "src/FuzzerIOPosix.cpp",
+ "src/FuzzerIOWindows.cpp",
+ "src/FuzzerLoop.cpp",
+ "src/FuzzerMain.cpp",
+ "src/FuzzerMerge.cpp",
+ "src/FuzzerMutate.cpp",
+ "src/FuzzerSHA1.cpp",
+ "src/FuzzerTracePC.cpp",
+ "src/FuzzerUtil.cpp",
+ "src/FuzzerUtilDarwin.cpp",
+ "src/FuzzerUtilFuchsia.cpp",
+ "src/FuzzerUtilLinux.cpp",
+ "src/FuzzerUtilPosix.cpp",
+ "src/FuzzerUtilWindows.cpp",
+ ]
+
+ configs += [ ":ignore_warnings" ]
+}
diff --git a/chromium/third_party/openscreen/src/third_party/mDNSResponder/BUILD.gn b/chromium/third_party/openscreen/src/third_party/mDNSResponder/BUILD.gn
index 816e0906b6e..4e2fb8d4b0b 100644
--- a/chromium/third_party/openscreen/src/third_party/mDNSResponder/BUILD.gn
+++ b/chromium/third_party/openscreen/src/third_party/mDNSResponder/BUILD.gn
@@ -3,20 +3,9 @@
# found in the LICENSE file.
config("mdnsresponder_config") {
- cflags_c = [ "-Wno-array-bounds" ]
+ cflags = [ "-w" ] # Disable all warnings.
- if (is_gcc) {
- cflags_c += [
- "-Wno-unused-but-set-variable",
- "-Wno-unused-value",
- ]
- }
-
- if (is_clang) {
- cflags_c += [ "-Wno-address-of-packed-member" ]
- }
-
- cflags_c += [
+ cflags_c = [
# We need to rename some linked symbols in order to avoid multiple
# definitions.
"-DMD5_Update=MD5_Update_mDNS",
@@ -24,6 +13,10 @@ config("mdnsresponder_config") {
"-DMD5_Final=MD5_Final_mDNS",
"-DMD5_Transform=MD5_Transform_mDNS",
]
+
+ if (target_cpu == "arm" || target_cpu == "arm64") {
+ cflags_c += [ "-Dmd5_block_data_order=md5_block_data_order" ]
+ }
}
source_set("core") {
diff --git a/chromium/third_party/openscreen/src/third_party/mozilla/BUILD.gn b/chromium/third_party/openscreen/src/third_party/mozilla/BUILD.gn
new file mode 100644
index 00000000000..051917d2828
--- /dev/null
+++ b/chromium/third_party/openscreen/src/third_party/mozilla/BUILD.gn
@@ -0,0 +1,14 @@
+# Copyright 2020 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+source_set("mozilla") {
+ sources = [
+ "url_parse.cc",
+ "url_parse.h",
+ "url_parse_internal.cc",
+ "url_parse_internal.h",
+ ]
+
+ public_configs = [ "../../build:openscreen_include_dirs" ]
+}
diff --git a/chromium/third_party/openscreen/src/third_party/mozilla/LICENSE.txt b/chromium/third_party/openscreen/src/third_party/mozilla/LICENSE.txt
new file mode 100644
index 00000000000..ac40837824a
--- /dev/null
+++ b/chromium/third_party/openscreen/src/third_party/mozilla/LICENSE.txt
@@ -0,0 +1,65 @@
+Copyright 2007, Google Inc.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+-------------------------------------------------------------------------------
+
+The file url_parse.cc is based on nsURLParsers.cc from Mozilla. This file is
+licensed separately as follows:
+
+The contents of this file are subject to the Mozilla Public License Version
+1.1 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+http://www.mozilla.org/MPL/
+
+Software distributed under the License is distributed on an "AS IS" basis,
+WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
+for the specific language governing rights and limitations under the
+License.
+
+The Original Code is mozilla.org code.
+
+The Initial Developer of the Original Code is
+Netscape Communications Corporation.
+Portions created by the Initial Developer are Copyright (C) 1998
+the Initial Developer. All Rights Reserved.
+
+Contributor(s):
+ Darin Fisher (original author)
+
+Alternatively, the contents of this file may be used under the terms of
+either the GNU General Public License Version 2 or later (the "GPL"), or
+the GNU Lesser General Public License Version 2.1 or later (the "LGPL"),
+in which case the provisions of the GPL or the LGPL are applicable instead
+of those above. If you wish to allow use of your version of this file only
+under the terms of either the GPL or the LGPL, and not to allow others to
+use your version of this file under the terms of the MPL, indicate your
+decision by deleting the provisions above and replace them with the notice
+and other provisions required by the GPL or the LGPL. If you do not delete
+the provisions above, a recipient may use your version of this file under
+the terms of any one of the MPL, the GPL or the LGPL.
diff --git a/chromium/third_party/openscreen/src/third_party/mozilla/README.md b/chromium/third_party/openscreen/src/third_party/mozilla/README.md
new file mode 100644
index 00000000000..ed4c24d8c06
--- /dev/null
+++ b/chromium/third_party/openscreen/src/third_party/mozilla/README.md
@@ -0,0 +1,7 @@
+# url_parse
+
+`url_parse.{h,cc}` are based on the same files in Chromium under
+`//url/third_party/mozilla` but have been slightly modified for our use case.
+`url_parse_internal.{h,cc}` contains additional functions needed by the former
+files but aren't provided directly. These are also ported from Chromium's
+version.
diff --git a/chromium/third_party/openscreen/src/third_party/mozilla/url_parse.cc b/chromium/third_party/openscreen/src/third_party/mozilla/url_parse.cc
new file mode 100644
index 00000000000..e6efd9e7a34
--- /dev/null
+++ b/chromium/third_party/openscreen/src/third_party/mozilla/url_parse.cc
@@ -0,0 +1,858 @@
+/* Based on nsURLParsers.cc from Mozilla
+ * -------------------------------------
+ * The contents of this file are subject to the Mozilla Public License Version
+ * 1.1 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.mozilla.org/MPL/
+ *
+ * Software distributed under the License is distributed on an "AS IS" basis,
+ * WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
+ * for the specific language governing rights and limitations under the
+ * License.
+ *
+ * The Original Code is mozilla.org code.
+ *
+ * The Initial Developer of the Original Code is
+ * Netscape Communications Corporation.
+ * Portions created by the Initial Developer are Copyright (C) 1998
+ * the Initial Developer. All Rights Reserved.
+ *
+ * Contributor(s):
+ * Darin Fisher (original author)
+ *
+ * Alternatively, the contents of this file may be used under the terms of
+ * either the GNU General Public License Version 2 or later (the "GPL"), or
+ * the GNU Lesser General Public License Version 2.1 or later (the "LGPL"),
+ * in which case the provisions of the GPL or the LGPL are applicable instead
+ * of those above. If you wish to allow use of your version of this file only
+ * under the terms of either the GPL or the LGPL, and not to allow others to
+ * use your version of this file under the terms of the MPL, indicate your
+ * decision by deleting the provisions above and replace them with the notice
+ * and other provisions required by the GPL or the LGPL. If you do not delete
+ * the provisions above, a recipient may use your version of this file under
+ * the terms of any one of the MPL, the GPL or the LGPL.
+ *
+ * ***** END LICENSE BLOCK ***** */
+
+#include "third_party/mozilla/url_parse.h"
+
+#include <assert.h>
+#include <ctype.h>
+#include <stdlib.h>
+
+#include "third_party/mozilla/url_parse_internal.h"
+
+namespace openscreen {
+namespace {
+
+// Returns true if the given character is a valid digit to use in a port.
+bool IsPortDigit(char ch) {
+ return ch >= '0' && ch <= '9';
+}
+
+// Returns the offset of the next authority terminator in the input starting
+// from start_offset. If no terminator is found, the return value will be equal
+// to spec_len.
+int FindNextAuthorityTerminator(const char* spec,
+ int start_offset,
+ int spec_len) {
+ for (int i = start_offset; i < spec_len; i++) {
+ if (IsAuthorityTerminator(spec[i]))
+ return i;
+ }
+ return spec_len; // Not found.
+}
+
+void ParseUserInfo(const char* spec,
+ const Component& user,
+ Component* username,
+ Component* password) {
+ // Find the first colon in the user section, which separates the username and
+ // password.
+ int colon_offset = 0;
+ while (colon_offset < user.len && spec[user.begin + colon_offset] != ':')
+ colon_offset++;
+
+ if (colon_offset < user.len) {
+ // Found separator: <username>:<password>
+ *username = Component(user.begin, colon_offset);
+ *password = MakeRange(user.begin + colon_offset + 1, user.begin + user.len);
+ } else {
+ // No separator, treat everything as the username
+ *username = user;
+ *password = Component();
+ }
+}
+
+void ParseServerInfo(const char* spec,
+ const Component& serverinfo,
+ Component* hostname,
+ Component* port_num) {
+ if (serverinfo.len == 0) {
+ // No server info, host name is empty.
+ hostname->reset();
+ port_num->reset();
+ return;
+ }
+
+ // If the host starts with a left-bracket, assume the entire host is an
+ // IPv6 literal. Otherwise, assume none of the host is an IPv6 literal.
+ // This assumption will be overridden if we find a right-bracket.
+ //
+ // Our IPv6 address canonicalization code requires both brackets to exist,
+ // but the ability to locate an incomplete address can still be useful.
+ int ipv6_terminator = spec[serverinfo.begin] == '[' ? serverinfo.end() : -1;
+ int colon = -1;
+
+ // Find the last right-bracket, and the last colon.
+ for (int i = serverinfo.begin; i < serverinfo.end(); i++) {
+ switch (spec[i]) {
+ case ']':
+ ipv6_terminator = i;
+ break;
+ case ':':
+ colon = i;
+ break;
+ }
+ }
+
+ if (colon > ipv6_terminator) {
+ // Found a port number: <hostname>:<port>
+ *hostname = MakeRange(serverinfo.begin, colon);
+ if (hostname->len == 0)
+ hostname->reset();
+ *port_num = MakeRange(colon + 1, serverinfo.end());
+ } else {
+ // No port: <hostname>
+ *hostname = serverinfo;
+ port_num->reset();
+ }
+}
+
+// Given an already-identified auth section, breaks it into its consituent
+// parts. The port number will be parsed and the resulting integer will be
+// filled into the given *port variable, or -1 if there is no port number or it
+// is invalid.
+void DoParseAuthority(const char* spec,
+ const Component& auth,
+ Component* username,
+ Component* password,
+ Component* hostname,
+ Component* port_num) {
+ assert(auth.is_valid());
+ if (auth.len == 0) {
+ username->reset();
+ password->reset();
+ hostname->reset();
+ port_num->reset();
+ return;
+ }
+
+ // Search backwards for @, which is the separator between the user info and
+ // the server info.
+ int i = auth.begin + auth.len - 1;
+ while (i > auth.begin && spec[i] != '@')
+ i--;
+
+ if (spec[i] == '@') {
+ // Found user info: <user-info>@<server-info>
+ ParseUserInfo(spec, Component(auth.begin, i - auth.begin), username,
+ password);
+ ParseServerInfo(spec, MakeRange(i + 1, auth.begin + auth.len), hostname,
+ port_num);
+ } else {
+ // No user info, everything is server info.
+ username->reset();
+ password->reset();
+ ParseServerInfo(spec, auth, hostname, port_num);
+ }
+}
+
+inline void FindQueryAndRefParts(const char* spec,
+ const Component& path,
+ int* query_separator,
+ int* ref_separator) {
+ int path_end = path.begin + path.len;
+ for (int i = path.begin; i < path_end; i++) {
+ switch (spec[i]) {
+ case '?':
+ // Only match the query string if it precedes the reference fragment
+ // and when we haven't found one already.
+ if (*query_separator < 0)
+ *query_separator = i;
+ break;
+ case '#':
+ // Record the first # sign only.
+ if (*ref_separator < 0) {
+ *ref_separator = i;
+ return;
+ }
+ break;
+ }
+ }
+}
+
+void ParsePath(const char* spec,
+ const Component& path,
+ Component* filepath,
+ Component* query,
+ Component* ref) {
+ // path = [/]<segment1>/<segment2>/<...>/<segmentN>;<param>?<query>#<ref>
+
+ // Special case when there is no path.
+ if (path.len == -1) {
+ filepath->reset();
+ query->reset();
+ ref->reset();
+ return;
+ }
+ assert(path.len > 0);
+
+ // Search for first occurrence of either ? or #.
+ int query_separator = -1; // Index of the '?'
+ int ref_separator = -1; // Index of the '#'
+ FindQueryAndRefParts(spec, path, &query_separator, &ref_separator);
+
+ // Markers pointing to the character after each of these corresponding
+ // components. The code below words from the end back to the beginning,
+ // and will update these indices as it finds components that exist.
+ int file_end, query_end;
+
+ // Ref fragment: from the # to the end of the path.
+ int path_end = path.begin + path.len;
+ if (ref_separator >= 0) {
+ file_end = query_end = ref_separator;
+ *ref = MakeRange(ref_separator + 1, path_end);
+ } else {
+ file_end = query_end = path_end;
+ ref->reset();
+ }
+
+ // Query fragment: everything from the ? to the next boundary (either the end
+ // of the path or the ref fragment).
+ if (query_separator >= 0) {
+ file_end = query_separator;
+ *query = MakeRange(query_separator + 1, query_end);
+ } else {
+ query->reset();
+ }
+
+ // File path: treat an empty file path as no file path.
+ if (file_end != path.begin)
+ *filepath = MakeRange(path.begin, file_end);
+ else
+ filepath->reset();
+}
+
+bool DoExtractScheme(const char* url, int url_len, Component* scheme) {
+ // Skip leading whitespace and control characters.
+ int begin = 0;
+ while (begin < url_len && ShouldTrimFromURL(url[begin]))
+ begin++;
+ if (begin == url_len)
+ return false; // Input is empty or all whitespace.
+
+ // Find the first colon character.
+ for (int i = begin; i < url_len; i++) {
+ if (url[i] == ':') {
+ *scheme = MakeRange(begin, i);
+ return true;
+ }
+ }
+ return false; // No colon found: no scheme
+}
+
+// Fills in all members of the Parsed structure except for the scheme.
+//
+// |spec| is the full spec being parsed, of length |spec_len|.
+// |after_scheme| is the character immediately following the scheme (after the
+// colon) where we'll begin parsing.
+//
+// Compatability data points. I list "host", "path" extracted:
+// Input IE6 Firefox Us
+// ----- -------------- -------------- --------------
+// http://foo.com/ "foo.com", "/" "foo.com", "/" "foo.com", "/"
+// http:foo.com/ "foo.com", "/" "foo.com", "/" "foo.com", "/"
+// http:/foo.com/ fail(*) "foo.com", "/" "foo.com", "/"
+// http:\foo.com/ fail(*) "\foo.com", "/"(fail) "foo.com", "/"
+// http:////foo.com/ "foo.com", "/" "foo.com", "/" "foo.com", "/"
+//
+// (*) Interestingly, although IE fails to load these URLs, its history
+// canonicalizer handles them, meaning if you've been to the corresponding
+// "http://foo.com/" link, it will be colored.
+void DoParseAfterScheme(const char* spec,
+ int spec_len,
+ int after_scheme,
+ Parsed* parsed) {
+ int num_slashes = CountConsecutiveSlashes(spec, after_scheme, spec_len);
+ int after_slashes = after_scheme + num_slashes;
+
+ // First split into two main parts, the authority (username, password, host,
+ // and port) and the full path (path, query, and reference).
+ Component authority;
+ Component full_path;
+
+ // Found "//<some data>", looks like an authority section. Treat everything
+ // from there to the next slash (or end of spec) to be the authority. Note
+ // that we ignore the number of slashes and treat it as the authority.
+ int end_auth = FindNextAuthorityTerminator(spec, after_slashes, spec_len);
+ authority = Component(after_slashes, end_auth - after_slashes);
+
+ if (end_auth == spec_len) // No beginning of path found.
+ full_path = Component();
+ else // Everything starting from the slash to the end is the path.
+ full_path = Component(end_auth, spec_len - end_auth);
+
+ // Now parse those two sub-parts.
+ DoParseAuthority(spec, authority, &parsed->username, &parsed->password,
+ &parsed->host, &parsed->port);
+ ParsePath(spec, full_path, &parsed->path, &parsed->query, &parsed->ref);
+}
+
+// The main parsing function for standard URLs. Standard URLs have a scheme,
+// host, path, etc.
+void DoParseStandardURL(const char* spec, int spec_len, Parsed* parsed) {
+ assert(spec_len >= 0);
+
+ // Strip leading & trailing spaces and control characters.
+ int begin = 0;
+ TrimURL(spec, &begin, &spec_len);
+
+ int after_scheme;
+ if (DoExtractScheme(spec, spec_len, &parsed->scheme)) {
+ after_scheme = parsed->scheme.end() + 1; // Skip past the colon.
+ } else {
+ // Say there's no scheme when there is no colon. We could also say that
+ // everything is the scheme. Both would produce an invalid URL, but this way
+ // seems less wrong in more cases.
+ parsed->scheme.reset();
+ after_scheme = begin;
+ }
+ DoParseAfterScheme(spec, spec_len, after_scheme, parsed);
+}
+
+void DoParseFileSystemURL(const char* spec, int spec_len, Parsed* parsed) {
+ assert(spec_len >= 0);
+
+ // Get the unused parts of the URL out of the way.
+ parsed->username.reset();
+ parsed->password.reset();
+ parsed->host.reset();
+ parsed->port.reset();
+ parsed->path.reset(); // May use this; reset for convenience.
+ parsed->ref.reset(); // May use this; reset for convenience.
+ parsed->query.reset(); // May use this; reset for convenience.
+ parsed->clear_inner_parsed(); // May use this; reset for convenience.
+
+ // Strip leading & trailing spaces and control characters.
+ int begin = 0;
+ TrimURL(spec, &begin, &spec_len);
+
+ // Handle empty specs or ones that contain only whitespace or control chars.
+ if (begin == spec_len) {
+ parsed->scheme.reset();
+ return;
+ }
+
+ int inner_start = -1;
+
+ // Extract the scheme. We also handle the case where there is no scheme.
+ if (DoExtractScheme(&spec[begin], spec_len - begin, &parsed->scheme)) {
+ // Offset the results since we gave ExtractScheme a substring.
+ parsed->scheme.begin += begin;
+
+ if (parsed->scheme.end() == spec_len - 1)
+ return;
+
+ inner_start = parsed->scheme.end() + 1;
+ } else {
+ // No scheme found; that's not valid for filesystem URLs.
+ parsed->scheme.reset();
+ return;
+ }
+
+ Component inner_scheme;
+ const char* inner_spec = &spec[inner_start];
+ int inner_spec_len = spec_len - inner_start;
+
+ if (DoExtractScheme(inner_spec, inner_spec_len, &inner_scheme)) {
+ // Offset the results since we gave ExtractScheme a substring.
+ inner_scheme.begin += inner_start;
+
+ if (inner_scheme.end() == spec_len - 1)
+ return;
+ } else {
+ // No scheme found; that's not valid for filesystem URLs.
+ // The best we can do is return "filesystem://".
+ return;
+ }
+
+ Parsed inner_parsed;
+
+ if (CompareSchemeComponent(spec, inner_scheme, kFileScheme)) {
+ // File URLs are special.
+ ParseFileURL(inner_spec, inner_spec_len, &inner_parsed);
+ } else if (CompareSchemeComponent(spec, inner_scheme, kFileSystemScheme)) {
+ // Filesystem URLs don't nest.
+ return;
+ } else if (IsStandard(spec, inner_scheme)) {
+ // All "normal" URLs.
+ DoParseStandardURL(inner_spec, inner_spec_len, &inner_parsed);
+ } else {
+ return;
+ }
+
+ // All members of inner_parsed need to be offset by inner_start.
+ // If we had any scheme that supported nesting more than one level deep,
+ // we'd have to recurse into the inner_parsed's inner_parsed when
+ // adjusting by inner_start.
+ inner_parsed.scheme.begin += inner_start;
+ inner_parsed.username.begin += inner_start;
+ inner_parsed.password.begin += inner_start;
+ inner_parsed.host.begin += inner_start;
+ inner_parsed.port.begin += inner_start;
+ inner_parsed.query.begin += inner_start;
+ inner_parsed.ref.begin += inner_start;
+ inner_parsed.path.begin += inner_start;
+
+ // Query and ref move from inner_parsed to parsed.
+ parsed->query = inner_parsed.query;
+ inner_parsed.query.reset();
+ parsed->ref = inner_parsed.ref;
+ inner_parsed.ref.reset();
+
+ parsed->set_inner_parsed(inner_parsed);
+ if (!inner_parsed.scheme.is_valid() || !inner_parsed.path.is_valid() ||
+ inner_parsed.inner_parsed()) {
+ return;
+ }
+
+ // The path in inner_parsed should start with a slash, then have a filesystem
+ // type followed by a slash. From the first slash up to but excluding the
+ // second should be what it keeps; the rest goes to parsed. If the path ends
+ // before the second slash, it's still pretty clear what the user meant, so
+ // we'll let that through.
+ if (!IsURLSlash(spec[inner_parsed.path.begin])) {
+ return;
+ }
+ int inner_path_end = inner_parsed.path.begin + 1; // skip the leading slash
+ while (inner_path_end < spec_len && !IsURLSlash(spec[inner_path_end]))
+ ++inner_path_end;
+ parsed->path.begin = inner_path_end;
+ int new_inner_path_length = inner_path_end - inner_parsed.path.begin;
+ parsed->path.len = inner_parsed.path.len - new_inner_path_length;
+ parsed->inner_parsed()->path.len = new_inner_path_length;
+}
+
+// Initializes a path URL which is merely a scheme followed by a path. Examples
+// include "about:foo" and "javascript:alert('bar');"
+void DoParsePathURL(const char* spec,
+ int spec_len,
+ bool trim_path_end,
+ Parsed* parsed) {
+ // Get the non-path and non-scheme parts of the URL out of the way, we never
+ // use them.
+ parsed->username.reset();
+ parsed->password.reset();
+ parsed->host.reset();
+ parsed->port.reset();
+ parsed->path.reset();
+ parsed->query.reset();
+ parsed->ref.reset();
+
+ // Strip leading & trailing spaces and control characters.
+ int scheme_begin = 0;
+ TrimURL(spec, &scheme_begin, &spec_len, trim_path_end);
+
+ // Handle empty specs or ones that contain only whitespace or control chars.
+ if (scheme_begin == spec_len) {
+ parsed->scheme.reset();
+ parsed->path.reset();
+ return;
+ }
+
+ int path_begin;
+ // Extract the scheme, with the path being everything following. We also
+ // handle the case where there is no scheme.
+ if (ExtractScheme(&spec[scheme_begin], spec_len - scheme_begin,
+ &parsed->scheme)) {
+ // Offset the results since we gave ExtractScheme a substring.
+ parsed->scheme.begin += scheme_begin;
+ path_begin = parsed->scheme.end() + 1;
+ } else {
+ // No scheme case.
+ parsed->scheme.reset();
+ path_begin = scheme_begin;
+ }
+
+ if (path_begin == spec_len)
+ return;
+ assert(path_begin < spec_len);
+
+ ParsePath(spec, MakeRange(path_begin, spec_len), &parsed->path,
+ &parsed->query, &parsed->ref);
+}
+
+void DoParseMailtoURL(const char* spec, int spec_len, Parsed* parsed) {
+ assert(spec_len >= 0);
+
+ // Get the non-path and non-scheme parts of the URL out of the way, we never
+ // use them.
+ parsed->username.reset();
+ parsed->password.reset();
+ parsed->host.reset();
+ parsed->port.reset();
+ parsed->ref.reset();
+ parsed->query.reset(); // May use this; reset for convenience.
+
+ // Strip leading & trailing spaces and control characters.
+ int begin = 0;
+ TrimURL(spec, &begin, &spec_len);
+
+ // Handle empty specs or ones that contain only whitespace or control chars.
+ if (begin == spec_len) {
+ parsed->scheme.reset();
+ parsed->path.reset();
+ return;
+ }
+
+ int path_begin = -1;
+ int path_end = -1;
+
+ // Extract the scheme, with the path being everything following. We also
+ // handle the case where there is no scheme.
+ if (ExtractScheme(&spec[begin], spec_len - begin, &parsed->scheme)) {
+ // Offset the results since we gave ExtractScheme a substring.
+ parsed->scheme.begin += begin;
+
+ if (parsed->scheme.end() != spec_len - 1) {
+ path_begin = parsed->scheme.end() + 1;
+ path_end = spec_len;
+ }
+ } else {
+ // No scheme found, just path.
+ parsed->scheme.reset();
+ path_begin = begin;
+ path_end = spec_len;
+ }
+
+ // Split [path_begin, path_end) into a path + query.
+ for (int i = path_begin; i < path_end; ++i) {
+ if (spec[i] == '?') {
+ parsed->query = MakeRange(i + 1, path_end);
+ path_end = i;
+ break;
+ }
+ }
+
+ // For compatability with the standard URL parser, treat no path as
+ // -1, rather than having a length of 0
+ if (path_begin == path_end) {
+ parsed->path.reset();
+ } else {
+ parsed->path = MakeRange(path_begin, path_end);
+ }
+}
+
+// Converts a port number in a string to an integer. We'd like to just call
+// sscanf but our input is not NULL-terminated, which sscanf requires. Instead,
+// we copy the digits to a small stack buffer (since we know the maximum number
+// of digits in a valid port number) that we can NULL terminate.
+int DoParsePort(const char* spec, const Component& component) {
+ // Easy success case when there is no port.
+ const int kMaxDigits = 5;
+ if (!component.is_nonempty())
+ return PORT_UNSPECIFIED;
+
+ // Skip over any leading 0s.
+ Component digits_comp(component.end(), 0);
+ for (int i = 0; i < component.len; i++) {
+ if (spec[component.begin + i] != '0') {
+ digits_comp = MakeRange(component.begin + i, component.end());
+ break;
+ }
+ }
+ if (digits_comp.len == 0)
+ return 0; // All digits were 0.
+
+ // Verify we don't have too many digits (we'll be copying to our buffer so
+ // we need to double-check).
+ if (digits_comp.len > kMaxDigits)
+ return PORT_INVALID;
+
+ // Copy valid digits to the buffer.
+ char digits[kMaxDigits + 1]; // +1 for null terminator
+ for (int i = 0; i < digits_comp.len; i++) {
+ char ch = spec[digits_comp.begin + i];
+ if (!IsPortDigit(ch)) {
+ // Invalid port digit, fail.
+ return PORT_INVALID;
+ }
+ digits[i] = static_cast<char>(ch);
+ }
+
+ // Null-terminate the string and convert to integer. Since we guarantee
+ // only digits, atoi's lack of error handling is OK.
+ digits[digits_comp.len] = 0;
+ int port = atoi(digits);
+ if (port > 65535)
+ return PORT_INVALID; // Out of range.
+ return port;
+}
+
+void DoExtractFileName(const char* spec,
+ const Component& path,
+ Component* file_name) {
+ // Handle empty paths: they have no file names.
+ if (!path.is_nonempty()) {
+ file_name->reset();
+ return;
+ }
+
+ // Extract the filename range from the path which is between
+ // the last slash and the following semicolon.
+ int file_end = path.end();
+ for (int i = path.end() - 1; i >= path.begin; i--) {
+ if (spec[i] == ';') {
+ file_end = i;
+ } else if (IsURLSlash(spec[i])) {
+ // File name is everything following this character to the end
+ *file_name = MakeRange(i + 1, file_end);
+ return;
+ }
+ }
+
+ // No slash found, this means the input was degenerate (generally paths
+ // will start with a slash). Let's call everything the file name.
+ *file_name = MakeRange(path.begin, file_end);
+ return;
+}
+
+bool DoExtractQueryKeyValue(const char* spec,
+ Component* query,
+ Component* key,
+ Component* value) {
+ if (!query->is_nonempty())
+ return false;
+
+ int start = query->begin;
+ int cur = start;
+ int end = query->end();
+
+ // We assume the beginning of the input is the beginning of the "key" and we
+ // skip to the end of it.
+ key->begin = cur;
+ while (cur < end && spec[cur] != '&' && spec[cur] != '=')
+ cur++;
+ key->len = cur - key->begin;
+
+ // Skip the separator after the key (if any).
+ if (cur < end && spec[cur] == '=')
+ cur++;
+
+ // Find the value part.
+ value->begin = cur;
+ while (cur < end && spec[cur] != '&')
+ cur++;
+ value->len = cur - value->begin;
+
+ // Finally skip the next separator if any
+ if (cur < end && spec[cur] == '&')
+ cur++;
+
+ // Save the new query
+ *query = MakeRange(cur, end);
+ return true;
+}
+
+} // namespace
+
+Parsed::Parsed() : potentially_dangling_markup(false), inner_parsed_(NULL) {}
+
+Parsed::Parsed(const Parsed& other)
+ : scheme(other.scheme),
+ username(other.username),
+ password(other.password),
+ host(other.host),
+ port(other.port),
+ path(other.path),
+ query(other.query),
+ ref(other.ref),
+ potentially_dangling_markup(other.potentially_dangling_markup),
+ inner_parsed_(NULL) {
+ if (other.inner_parsed_)
+ set_inner_parsed(*other.inner_parsed_);
+}
+
+Parsed& Parsed::operator=(const Parsed& other) {
+ if (this != &other) {
+ scheme = other.scheme;
+ username = other.username;
+ password = other.password;
+ host = other.host;
+ port = other.port;
+ path = other.path;
+ query = other.query;
+ ref = other.ref;
+ potentially_dangling_markup = other.potentially_dangling_markup;
+ if (other.inner_parsed_)
+ set_inner_parsed(*other.inner_parsed_);
+ else
+ clear_inner_parsed();
+ }
+ return *this;
+}
+
+Parsed::~Parsed() {
+ delete inner_parsed_;
+}
+
+int Parsed::Length() const {
+ if (ref.is_valid())
+ return ref.end();
+ return CountCharactersBefore(REF, false);
+}
+
+int Parsed::CountCharactersBefore(ComponentType type,
+ bool include_delimiter) const {
+ if (type == SCHEME)
+ return scheme.begin;
+
+ // There will be some characters after the scheme like "://" and we don't
+ // know how many. Search forwards for the next thing until we find one.
+ int cur = 0;
+ if (scheme.is_valid())
+ cur = scheme.end() + 1; // Advance over the ':' at the end of the scheme.
+
+ if (username.is_valid()) {
+ if (type <= USERNAME)
+ return username.begin;
+ cur = username.end() + 1; // Advance over the '@' or ':' at the end.
+ }
+
+ if (password.is_valid()) {
+ if (type <= PASSWORD)
+ return password.begin;
+ cur = password.end() + 1; // Advance over the '@' at the end.
+ }
+
+ if (host.is_valid()) {
+ if (type <= HOST)
+ return host.begin;
+ cur = host.end();
+ }
+
+ if (port.is_valid()) {
+ if (type < PORT || (type == PORT && include_delimiter))
+ return port.begin - 1; // Back over delimiter.
+ if (type == PORT)
+ return port.begin; // Don't want delimiter counted.
+ cur = port.end();
+ }
+
+ if (path.is_valid()) {
+ if (type <= PATH)
+ return path.begin;
+ cur = path.end();
+ }
+
+ if (query.is_valid()) {
+ if (type < QUERY || (type == QUERY && include_delimiter))
+ return query.begin - 1; // Back over delimiter.
+ if (type == QUERY)
+ return query.begin; // Don't want delimiter counted.
+ cur = query.end();
+ }
+
+ if (ref.is_valid()) {
+ if (type == REF && !include_delimiter)
+ return ref.begin; // Back over delimiter.
+
+ // When there is a ref and we get here, the component we wanted was before
+ // this and not found, so we always know the beginning of the ref is right.
+ return ref.begin - 1; // Don't want delimiter counted.
+ }
+
+ return cur;
+}
+
+Component Parsed::GetContent() const {
+ const int begin = CountCharactersBefore(USERNAME, false);
+ const int len = Length() - begin;
+ // For compatability with the standard URL parser, we treat no content as
+ // -1, rather than having a length of 0 (we normally wouldn't care so
+ // much for these non-standard URLs).
+ return len ? Component(begin, len) : Component();
+}
+
+bool ExtractScheme(const char* url, int url_len, Component* scheme) {
+ return DoExtractScheme(url, url_len, scheme);
+}
+
+// This handles everything that may be an authority terminator, including
+// backslash. For special backslash handling see DoParseAfterScheme.
+bool IsAuthorityTerminator(char ch) {
+ return IsURLSlash(ch) || ch == '?' || ch == '#';
+}
+
+void ExtractFileName(const char* url,
+ const Component& path,
+ Component* file_name) {
+ DoExtractFileName(url, path, file_name);
+}
+
+bool ExtractQueryKeyValue(const char* url,
+ Component* query,
+ Component* key,
+ Component* value) {
+ return DoExtractQueryKeyValue(url, query, key, value);
+}
+
+void ParseAuthority(const char* spec,
+ const Component& auth,
+ Component* username,
+ Component* password,
+ Component* hostname,
+ Component* port_num) {
+ DoParseAuthority(spec, auth, username, password, hostname, port_num);
+}
+
+int ParsePort(const char* url, const Component& port) {
+ return DoParsePort(url, port);
+}
+
+void ParseStandardURL(const char* url, int url_len, Parsed* parsed) {
+ DoParseStandardURL(url, url_len, parsed);
+}
+
+void ParsePathURL(const char* url,
+ int url_len,
+ bool trim_path_end,
+ Parsed* parsed) {
+ DoParsePathURL(url, url_len, trim_path_end, parsed);
+}
+
+void ParseFileSystemURL(const char* url, int url_len, Parsed* parsed) {
+ DoParseFileSystemURL(url, url_len, parsed);
+}
+
+void ParseMailtoURL(const char* url, int url_len, Parsed* parsed) {
+ DoParseMailtoURL(url, url_len, parsed);
+}
+
+void ParsePathInternal(const char* spec,
+ const Component& path,
+ Component* filepath,
+ Component* query,
+ Component* ref) {
+ ParsePath(spec, path, filepath, query, ref);
+}
+
+void ParseAfterScheme(const char* spec,
+ int spec_len,
+ int after_scheme,
+ Parsed* parsed) {
+ DoParseAfterScheme(spec, spec_len, after_scheme, parsed);
+}
+
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/third_party/mozilla/url_parse.h b/chromium/third_party/openscreen/src/third_party/mozilla/url_parse.h
new file mode 100644
index 00000000000..70e97adf9b8
--- /dev/null
+++ b/chromium/third_party/openscreen/src/third_party/mozilla/url_parse.h
@@ -0,0 +1,322 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef THIRD_PARTY_MOZILLA_URL_PARSE_H_
+#define THIRD_PARTY_MOZILLA_URL_PARSE_H_
+
+namespace openscreen {
+
+// Component ------------------------------------------------------------------
+
+// Represents a substring for URL parsing.
+struct Component {
+ Component() : begin(0), len(-1) {}
+
+ // Normal constructor: takes an offset and a length.
+ Component(int b, int l) : begin(b), len(l) {}
+
+ int end() const { return begin + len; }
+
+ // Returns true if this component is valid, meaning the length is given. Even
+ // valid components may be empty to record the fact that they exist.
+ bool is_valid() const { return (len != -1); }
+
+ // Returns true if the given component is specified on false, the component
+ // is either empty or invalid.
+ bool is_nonempty() const { return (len > 0); }
+
+ void reset() {
+ begin = 0;
+ len = -1;
+ }
+
+ bool operator==(const Component& other) const {
+ return begin == other.begin && len == other.len;
+ }
+
+ int begin; // Byte offset in the string of this component.
+ int len; // Will be -1 if the component is unspecified.
+};
+
+// Helper that returns a component created with the given begin and ending
+// points. The ending point is non-inclusive.
+inline Component MakeRange(int begin, int end) {
+ return Component(begin, end - begin);
+}
+
+// Parsed ---------------------------------------------------------------------
+
+// A structure that holds the identified parts of an input URL. This structure
+// does NOT store the URL itself. The caller will have to store the URL text
+// and its corresponding Parsed structure separately.
+//
+// Typical usage would be:
+//
+// Parsed parsed;
+// Component scheme;
+// if (!ExtractScheme(url, url_len, &scheme))
+// return I_CAN_NOT_FIND_THE_SCHEME_DUDE;
+//
+// if (IsStandardScheme(url, scheme)) // Not provided by this component
+// ParseStandardURL(url, url_len, &parsed);
+// else if (IsFileURL(url, scheme)) // Not provided by this component
+// ParseFileURL(url, url_len, &parsed);
+// else
+// ParsePathURL(url, url_len, &parsed);
+//
+struct Parsed {
+ // Identifies different components.
+ enum ComponentType {
+ SCHEME,
+ USERNAME,
+ PASSWORD,
+ HOST,
+ PORT,
+ PATH,
+ QUERY,
+ REF,
+ };
+
+ // The default constructor is sufficient for the components, but inner_parsed_
+ // requires special handling.
+ Parsed();
+ Parsed(const Parsed&);
+ Parsed& operator=(const Parsed&);
+ ~Parsed();
+
+ // Returns the length of the URL (the end of the last component).
+ //
+ // Note that for some invalid, non-canonical URLs, this may not be the length
+ // of the string. For example "http://": the parsed structure will only
+ // contain an entry for the four-character scheme, and it doesn't know about
+ // the "://". For all other last-components, it will return the real length.
+ int Length() const;
+
+ // Returns the number of characters before the given component if it exists,
+ // or where the component would be if it did exist. This will return the
+ // string length if the component would be appended to the end.
+ //
+ // Note that this can get a little funny for the port, query, and ref
+ // components which have a delimiter that is not counted as part of the
+ // component. The |include_delimiter| flag controls if you want this counted
+ // as part of the component or not when the component exists.
+ //
+ // This example shows the difference between the two flags for two of these
+ // delimited components that is present (the port and query) and one that
+ // isn't (the reference). The components that this flag affects are marked
+ // with a *.
+ // 0 1 2
+ // 012345678901234567890
+ // Example input: http://foo:80/?query
+ // include_delim=true, ...=false ("<-" indicates different)
+ // SCHEME: 0 0
+ // USERNAME: 5 5
+ // PASSWORD: 5 5
+ // HOST: 7 7
+ // *PORT: 10 11 <-
+ // PATH: 13 13
+ // *QUERY: 14 15 <-
+ // *REF: 20 20
+ //
+ int CountCharactersBefore(ComponentType type, bool include_delimiter) const;
+
+ // Scheme without the colon: "http://foo"/ would have a scheme of "http".
+ // The length will be -1 if no scheme is specified ("foo.com"), or 0 if there
+ // is a colon but no scheme (":foo"). Note that the scheme is not guaranteed
+ // to start at the beginning of the string if there are preceeding whitespace
+ // or control characters.
+ Component scheme;
+
+ // Username. Specified in URLs with an @ sign before the host. See |password|
+ Component username;
+
+ // Password. The length will be -1 if unspecified, 0 if specified but empty.
+ // Not all URLs with a username have a password, as in "http://me@host/".
+ // The password is separated form the username with a colon, as in
+ // "http://me:secret@host/"
+ Component password;
+
+ // Host name.
+ Component host;
+
+ // Port number.
+ Component port;
+
+ // Path, this is everything following the host name, stopping at the query of
+ // ref delimiter (if any). Length will be -1 if unspecified. This includes
+ // the preceeding slash, so the path on http://www.google.com/asdf" is
+ // "/asdf". As a result, it is impossible to have a 0 length path, it will
+ // be -1 in cases like "http://host?foo".
+ // Note that we treat backslashes the same as slashes.
+ Component path;
+
+ // Stuff between the ? and the # after the path. This does not include the
+ // preceeding ? character. Length will be -1 if unspecified, 0 if there is
+ // a question mark but no query string.
+ Component query;
+
+ // Indicated by a #, this is everything following the hash sign (not
+ // including it). If there are multiple hash signs, we'll use the last one.
+ // Length will be -1 if there is no hash sign, or 0 if there is one but
+ // nothing follows it.
+ Component ref;
+
+ // The URL spec from the character after the scheme: until the end of the
+ // URL, regardless of the scheme. This is mostly useful for 'opaque' non-
+ // hierarchical schemes like data: and javascript: as a convient way to get
+ // the string with the scheme stripped off.
+ Component GetContent() const;
+
+ // True if the URL's source contained a raw `<` character, and whitespace was
+ // removed from the URL during parsing
+ //
+ // TODO(mkwst): Link this to something in a spec if
+ // https://github.com/whatwg/url/pull/284 lands.
+ bool potentially_dangling_markup;
+
+ // This is used for nested URL types, currently only filesystem. If you
+ // parse a filesystem URL, the resulting Parsed will have a nested
+ // inner_parsed_ to hold the parsed inner URL's component information.
+ // For all other url types [including the inner URL], it will be NULL.
+ Parsed* inner_parsed() const { return inner_parsed_; }
+
+ void set_inner_parsed(const Parsed& inner_parsed) {
+ if (!inner_parsed_)
+ inner_parsed_ = new Parsed(inner_parsed);
+ else
+ *inner_parsed_ = inner_parsed;
+ }
+
+ void clear_inner_parsed() {
+ if (inner_parsed_) {
+ delete inner_parsed_;
+ inner_parsed_ = nullptr;
+ }
+ }
+
+ private:
+ Parsed* inner_parsed_; // This object is owned and managed by this struct.
+};
+
+// Initialization functions ---------------------------------------------------
+//
+// These functions parse the given URL, filling in all of the structure's
+// components. These functions can not fail, they will always do their best
+// at interpreting the input given.
+//
+// The string length of the URL MUST be specified, we do not check for NULLs
+// at any point in the process, and will actually handle embedded NULLs.
+//
+// IMPORTANT: These functions do NOT hang on to the given pointer or copy it
+// in any way. See the comment above the struct.
+//
+// The 8-bit versions require UTF-8 encoding.
+
+// StandardURL is for when the scheme is known to be one that has an
+// authority (host) like "http". This function will not handle weird ones
+// like "about:" and "javascript:", or do the right thing for "file:" URLs.
+void ParseStandardURL(const char* url, int url_len, Parsed* parsed);
+
+// PathURL is for when the scheme is known not to have an authority (host)
+// section but that aren't file URLs either. The scheme is parsed, and
+// everything after the scheme is considered as the path. This is used for
+// things like "about:" and "javascript:"
+void ParsePathURL(const char* url,
+ int url_len,
+ bool trim_path_end,
+ Parsed* parsed);
+
+// FileURL is for file URLs. There are some special rules for interpreting
+// these.
+void ParseFileURL(const char* url, int url_len, Parsed* parsed);
+
+// Filesystem URLs are structured differently than other URLs.
+void ParseFileSystemURL(const char* url, int url_len, Parsed* parsed);
+
+// MailtoURL is for mailto: urls. They are made up scheme,path,query
+void ParseMailtoURL(const char* url, int url_len, Parsed* parsed);
+
+// Helper functions -----------------------------------------------------------
+
+// Locates the scheme according to the URL parser's rules. This function is
+// designed so the caller can find the scheme and call the correct Init*
+// function according to their known scheme types.
+//
+// It also does not perform any validation on the scheme.
+//
+// This function will return true if the scheme is found and will put the
+// scheme's range into *scheme. False means no scheme could be found. Note
+// that a URL beginning with a colon has a scheme, but it is empty, so this
+// function will return true but *scheme will = (0,0).
+//
+// The scheme is found by skipping spaces and control characters at the
+// beginning, and taking everything from there to the first colon to be the
+// scheme. The character at scheme.end() will be the colon (we may enhance
+// this to handle full width colons or something, so don't count on the
+// actual character value). The character at scheme.end()+1 will be the
+// beginning of the rest of the URL, be it the authority or the path (or the
+// end of the string).
+//
+// The 8-bit version requires UTF-8 encoding.
+bool ExtractScheme(const char* url, int url_len, Component* scheme);
+
+// Returns true if ch is a character that terminates the authority segment
+// of a URL.
+bool IsAuthorityTerminator(char ch);
+
+// Does a best effort parse of input |spec|, in range |auth|. If a particular
+// component is not found, it will be set to invalid.
+void ParseAuthority(const char* spec,
+ const Component& auth,
+ Component* username,
+ Component* password,
+ Component* hostname,
+ Component* port_num);
+
+// Computes the integer port value from the given port component. The port
+// component should have been identified by one of the init functions on
+// |Parsed| for the given input url.
+//
+// The return value will be a positive integer between 0 and 64K, or one of
+// the two special values below.
+enum SpecialPort { PORT_UNSPECIFIED = -1, PORT_INVALID = -2 };
+int ParsePort(const char* url, const Component& port);
+
+// Extracts the range of the file name in the given url. The path must
+// already have been computed by the parse function, and the matching URL
+// and extracted path are provided to this function. The filename is
+// defined as being everything from the last slash/backslash of the path
+// to the end of the path.
+//
+// The file name will be empty if the path is empty or there is nothing
+// following the last slash.
+//
+// The 8-bit version requires UTF-8 encoding.
+void ExtractFileName(const char* url,
+ const Component& path,
+ Component* file_name);
+
+// Extract the first key/value from the range defined by |*query|. Updates
+// |*query| to start at the end of the extracted key/value pair. This is
+// designed for use in a loop: you can keep calling it with the same query
+// object and it will iterate over all items in the query.
+//
+// Some key/value pairs may have the key, the value, or both be empty (for
+// example, the query string "?&"). These will be returned. Note that an empty
+// last parameter "foo.com?" or foo.com?a&" will not be returned, this case
+// is the same as "done."
+//
+// The initial query component should not include the '?' (this is the default
+// for parsed URLs).
+//
+// If no key/value are found |*key| and |*value| will be unchanged and it will
+// return false.
+bool ExtractQueryKeyValue(const char* url,
+ Component* query,
+ Component* key,
+ Component* value);
+
+} // namespace openscreen
+
+#endif // THIRD_PARTY_MOZILLA_URL_PARSE_H_
diff --git a/chromium/third_party/openscreen/src/third_party/mozilla/url_parse_internal.cc b/chromium/third_party/openscreen/src/third_party/mozilla/url_parse_internal.cc
new file mode 100644
index 00000000000..136bc62d34b
--- /dev/null
+++ b/chromium/third_party/openscreen/src/third_party/mozilla/url_parse_internal.cc
@@ -0,0 +1,87 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "third_party/mozilla//url_parse_internal.h"
+
+#include <ctype.h>
+
+#include "third_party/mozilla/url_parse.h"
+
+namespace openscreen {
+
+namespace {
+
+static const char* g_standard_schemes[] = {
+ kHttpsScheme, kHttpScheme, kFileScheme, kFtpScheme,
+ kWssScheme, kWsScheme, kFileSystemScheme,
+};
+
+} // namespace
+
+bool IsURLSlash(char ch) {
+ return ch == '/' || ch == '\\';
+}
+
+bool ShouldTrimFromURL(char ch) {
+ return ch <= ' ';
+}
+
+void TrimURL(const char* spec, int* begin, int* len, bool trim_path_end) {
+ // Strip leading whitespace and control characters.
+ while (*begin < *len && ShouldTrimFromURL(spec[*begin])) {
+ (*begin)++;
+ }
+
+ if (trim_path_end) {
+ // Strip trailing whitespace and control characters. We need the >i test
+ // for when the input string is all blanks; we don't want to back past the
+ // input.
+ while (*len > *begin && ShouldTrimFromURL(spec[*len - 1])) {
+ (*len)--;
+ }
+ }
+}
+
+int CountConsecutiveSlashes(const char* str, int begin_offset, int str_len) {
+ int count = 0;
+ while ((begin_offset + count) < str_len &&
+ IsURLSlash(str[begin_offset + count])) {
+ ++count;
+ }
+ return count;
+}
+
+bool CompareSchemeComponent(const char* spec,
+ const Component& component,
+ const char* compare_to) {
+ if (!component.is_nonempty()) {
+ return compare_to[0] == 0; // When component is empty, match empty scheme.
+ }
+ for (int i = 0; i < component.len; ++i) {
+ if (tolower(spec[i]) != compare_to[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool IsStandard(const char* spec, const Component& component) {
+ if (!component.is_nonempty()) {
+ return false;
+ }
+
+ constexpr int scheme_count =
+ sizeof(g_standard_schemes) / sizeof(g_standard_schemes[0]);
+ for (int i = 0; i < scheme_count; ++i) {
+ if (CompareSchemeComponent(spec, component, g_standard_schemes[i])) {
+ return true;
+ }
+ }
+ return false;
+}
+
+// NOTE: Not implemented because file URLs are currently unsupported.
+void ParseFileURL(const char* url, int url_len, Parsed* parsed) {}
+
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/third_party/mozilla/url_parse_internal.h b/chromium/third_party/openscreen/src/third_party/mozilla/url_parse_internal.h
new file mode 100644
index 00000000000..58f9f75bc74
--- /dev/null
+++ b/chromium/third_party/openscreen/src/third_party/mozilla/url_parse_internal.h
@@ -0,0 +1,50 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef THIRD_PARTY_MOZILLA_URL_PARSE_INTERNAL_H_
+#define THIRD_PARTY_MOZILLA_URL_PARSE_INTERNAL_H_
+
+namespace openscreen {
+
+struct Component;
+
+static constexpr char kHttpsScheme[] = "https";
+static constexpr char kHttpScheme[] = "http";
+static constexpr char kFileScheme[] = "file";
+static constexpr char kFtpScheme[] = "ftp";
+static constexpr char kWssScheme[] = "wss";
+static constexpr char kWsScheme[] = "ws";
+static constexpr char kFileSystemScheme[] = "filesystem";
+static constexpr char kMailtoScheme[] = "mailto";
+
+// Returns whether the character |ch| should be treated as a slash.
+bool IsURLSlash(char ch);
+
+// Returns whether the character |ch| can be safely removed for the URL.
+bool ShouldTrimFromURL(char ch);
+
+// Given an already-initialized begin index and length, this shrinks the range
+// to eliminate "should-be-trimmed" characters. Note that the length does *not*
+// indicate the length of untrimmed data from |*begin|, but rather the position
+// in the input string (so the string starts at character |*begin| in the spec,
+// and goes until |*len|).
+void TrimURL(const char* spec, int* begin, int* len, bool trim_path_end = true);
+
+// Returns the number of consecutive slashes in |str| starting from offset
+// |begin_offset|.
+int CountConsecutiveSlashes(const char* str, int begin_offset, int str_len);
+
+// Given a string and a range inside the string, compares it to the given
+// lower-case |compare_to| buffer.
+bool CompareSchemeComponent(const char* spec,
+ const Component& component,
+ const char* compare_to);
+
+// Returns whether the scheme given by (spec, component) is a standard scheme
+// (i.e. https://url.spec.whatwg.org/#special-scheme).
+bool IsStandard(const char* spec, const Component& component);
+
+} // namespace openscreen
+
+#endif // THIRD_PARTY_MOZILLA_URL_PARSE_INTERNAL_H_
diff --git a/chromium/third_party/openscreen/src/third_party/protobuf/BUILD.gn b/chromium/third_party/openscreen/src/third_party/protobuf/BUILD.gn
index d11dd5277af..787d47ce0a0 100644
--- a/chromium/third_party/openscreen/src/third_party/protobuf/BUILD.gn
+++ b/chromium/third_party/openscreen/src/third_party/protobuf/BUILD.gn
@@ -2,6 +2,8 @@
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
+import("//build_overrides/build.gni")
+
config("protobuf_config") {
include_dirs = [ "src/src" ]
defines = [
diff --git a/chromium/third_party/openscreen/src/third_party/tinycbor/BUILD.gn b/chromium/third_party/openscreen/src/third_party/tinycbor/BUILD.gn
index ff8480fcba8..4521580544c 100644
--- a/chromium/third_party/openscreen/src/third_party/tinycbor/BUILD.gn
+++ b/chromium/third_party/openscreen/src/third_party/tinycbor/BUILD.gn
@@ -4,6 +4,11 @@
import("//build_overrides/build.gni")
+config("tinycbor_internal_config") {
+ defines = [ "WITHOUT_OPEN_MEMSTREAM" ]
+ cflags = [ "-w" ] # Disable all warnings.
+}
+
source_set("tinycbor") {
sources = [
"src/src/cbor.h",
@@ -16,5 +21,5 @@ source_set("tinycbor") {
"src/src/utf8_p.h",
]
- defines = [ "WITHOUT_OPEN_MEMSTREAM" ]
+ configs += [ ":tinycbor_internal_config" ]
}
diff --git a/chromium/third_party/openscreen/src/third_party/zlib/BUILD.gn b/chromium/third_party/openscreen/src/third_party/zlib/BUILD.gn
index 29e6004b084..457fc9a8948 100644
--- a/chromium/third_party/openscreen/src/third_party/zlib/BUILD.gn
+++ b/chromium/third_party/openscreen/src/third_party/zlib/BUILD.gn
@@ -2,7 +2,19 @@
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
-source_set("zlib") {
+config("zlib_config") {
+ include_dirs = [ "src" ]
+}
+
+config("zlib_internal_config") {
+ defines = [ "ZLIB_IMPLEMENTATION" ]
+ cflags = [ "-w" ] # Disable all warnings.
+}
+
+static_library("zlib") {
+ # Don't stomp on "libzlib"
+ output_name = "chrome_zlib"
+
sources = [
"src/adler32.c",
"src/compress.c",
@@ -26,30 +38,19 @@ source_set("zlib") {
"src/trees.c",
"src/trees.h",
"src/uncompr.c",
+ "src/zconf.h",
"src/zlib.h",
"src/zutil.c",
"src/zutil.h",
]
- include_dirs = [
- "src",
- ".",
- ]
+ defines = []
+ deps = []
- defines = [
- "HAVE_SYS_TYPES_H",
- "HAVE_STDINT_H",
- "HAVE_STDDEF_H",
- "_LARGEFILE64_SOURCE",
- "_FILE_OFFSET_BITS=64",
- ]
- if (is_mac) {
- # NOTE: zlib wants to use fseeko() if available, but on Mac it isn't
- # available.
- defines += [ "NO_FSEEKO" ]
- }
-}
+ include_dirs = [ "." ]
+ configs += [ ":zlib_internal_config" ]
-config("zlib_config") {
- include_dirs = [ "src" ]
+ public_configs = [ ":zlib_config" ]
+
+ allow_circular_includes_from = deps
}
diff --git a/chromium/third_party/openscreen/src/tools/cl-format.sh b/chromium/third_party/openscreen/src/tools/cl-format.sh
deleted file mode 100755
index 6bc159e6953..00000000000
--- a/chromium/third_party/openscreen/src/tools/cl-format.sh
+++ /dev/null
@@ -1,32 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2018 The Chromium Authors. All rights reserved.
-# Use of this source code is governed by a BSD-style license that can be
-# found in the LICENSE file.
-
-for f in $(git diff --name-only @{u}); do
- # Skip third party files, except our custom BUILD.gns
- if [[ $f =~ third_party/[^\/]*/src ]]; then
- continue;
- fi
-
- # Skip statically copied Chromium QUIC build files.
- if [[ $f =~ third_party/chromium_quic/build ]]; then
- continue;
- fi
-
- # Skip files deleted in this patch
- if ! [[ -f $f ]]; then
- continue;
- fi
-
- # Format cpp files
- if [[ $f =~ \.(cc|h)$ ]]; then
- clang-format -style=file -i "$f"
- fi
-
- # Format gn files
- if [[ $f =~ \.gn$ ]]; then
- gn format $f
- fi
-
-done
diff --git a/chromium/third_party/openscreen/src/tools/clang/scripts/update.py b/chromium/third_party/openscreen/src/tools/clang/scripts/update.py
deleted file mode 100755
index 3774d78deac..00000000000
--- a/chromium/third_party/openscreen/src/tools/clang/scripts/update.py
+++ /dev/null
@@ -1,1025 +0,0 @@
-#!/usr/bin/env python
-# Copyright (c) 2012 The Chromium Authors. All rights reserved.
-# Use of this source code is governed by a BSD-style license that can be
-# found in the LICENSE file.
-
-# NOTE: this file is taken from the chromium repo, viewable here:
-# https://cs.chromium.org/codesearch/f/chromium/src/tools/clang/scripts/update.py?cl=e485df98abd46b3f774377624ae7850e795fd4d8
-
-# Some minor alterations have been made to suit our use case here.
-
-"""This script is used to download prebuilt clang binaries.
-
-It is also used by package.py to build the prebuilt clang binaries."""
-
-import argparse
-import distutils.spawn
-import glob
-import os
-import pipes
-import re
-import shutil
-import subprocess
-import stat
-import sys
-import tarfile
-import tempfile
-import time
-import urllib2
-import zipfile
-
-
-# Do NOT CHANGE this if you don't know what you're doing -- see
-# https://chromium.googlesource.com/chromium/src/+/master/docs/updating_clang.md
-# Reverting problematic clang rolls is safe, though.
-CLANG_REVISION = '352138'
-
-use_head_revision = bool(os.environ.get('LLVM_FORCE_HEAD_REVISION', '0')
- in ('1', 'YES'))
-if use_head_revision:
- CLANG_REVISION = 'HEAD'
-
-# This is incremented when pushing a new build of Clang at the same revision.
-CLANG_SUB_REVISION=2
-
-PACKAGE_VERSION = "%s-%s" % (CLANG_REVISION, CLANG_SUB_REVISION)
-
-# Path constants. (All of these should be absolute paths.)
-THIS_DIR = os.path.abspath(os.path.dirname(__file__))
-CHROMIUM_DIR = os.path.abspath(os.path.join(THIS_DIR, '..', '..', '..'))
-GCLIENT_CONFIG = os.path.join(os.path.dirname(CHROMIUM_DIR), '.gclient')
-THIRD_PARTY_DIR = os.path.join(CHROMIUM_DIR, 'third_party')
-LLVM_DIR = os.path.join(THIRD_PARTY_DIR, 'llvm')
-LLVM_BOOTSTRAP_DIR = os.path.join(THIRD_PARTY_DIR, 'llvm-bootstrap')
-LLVM_BOOTSTRAP_INSTALL_DIR = os.path.join(THIRD_PARTY_DIR,
- 'llvm-bootstrap-install')
-CHROME_TOOLS_SHIM_DIR = os.path.join(LLVM_DIR, 'tools', 'chrometools')
-LLVM_BUILD_DIR = os.path.join(CHROMIUM_DIR, 'third_party', 'llvm-build',
- 'Release+Asserts')
-THREADS_ENABLED_BUILD_DIR = os.path.join(LLVM_BUILD_DIR, 'threads_enabled')
-COMPILER_RT_BUILD_DIR = os.path.join(LLVM_BUILD_DIR, 'compiler-rt')
-CLANG_DIR = os.path.join(LLVM_DIR, 'tools', 'clang')
-LLD_DIR = os.path.join(LLVM_DIR, 'tools', 'lld')
-# compiler-rt is built as part of the regular LLVM build on Windows to get
-# the 64-bit runtime, and out-of-tree elsewhere.
-# TODO(thakis): Try to unify this.
-if sys.platform == 'win32':
- COMPILER_RT_DIR = os.path.join(LLVM_DIR, 'projects', 'compiler-rt')
-else:
- COMPILER_RT_DIR = os.path.join(LLVM_DIR, 'compiler-rt')
-LIBCXX_DIR = os.path.join(LLVM_DIR, 'projects', 'libcxx')
-LIBCXXABI_DIR = os.path.join(LLVM_DIR, 'projects', 'libcxxabi')
-LLVM_BUILD_TOOLS_DIR = os.path.abspath(
- os.path.join(LLVM_DIR, '..', 'llvm-build-tools'))
-STAMP_FILE = os.path.normpath(
- os.path.join(LLVM_DIR, '..', 'llvm-build', 'cr_build_revision'))
-VERSION = '9.0.0'
-ANDROID_NDK_DIR = os.path.join(
- CHROMIUM_DIR, 'third_party', 'android_ndk')
-FUCHSIA_SDK_DIR = os.path.join(CHROMIUM_DIR, 'third_party', 'fuchsia-sdk',
- 'sdk')
-
-# URL for pre-built binaries.
-CDS_URL = os.environ.get('CDS_CLANG_BUCKET_OVERRIDE',
- 'https://commondatastorage.googleapis.com/chromium-browser-clang')
-
-LLVM_REPO_URL='https://llvm.org/svn/llvm-project'
-if 'LLVM_REPO_URL' in os.environ:
- LLVM_REPO_URL = os.environ['LLVM_REPO_URL']
-
-
-
-def DownloadUrl(url, output_file):
- """Download url into output_file."""
- CHUNK_SIZE = 4096
- TOTAL_DOTS = 10
- num_retries = 3
- retry_wait_s = 5 # Doubled at each retry.
-
- while True:
- try:
- sys.stdout.write('Downloading %s ' % url)
- sys.stdout.write('to %s' % output_file)
- sys.stdout.flush()
- response = urllib2.urlopen(url)
- total_size = int(response.info().getheader('Content-Length').strip())
- bytes_done = 0
- dots_printed = 0
- while True:
- chunk = response.read(CHUNK_SIZE)
- if not chunk:
- break
- output_file.write(chunk)
- bytes_done += len(chunk)
- num_dots = TOTAL_DOTS * bytes_done / total_size
- sys.stdout.write('.' * (num_dots - dots_printed))
- sys.stdout.flush()
- dots_printed = num_dots
- if bytes_done != total_size:
- raise urllib2.URLError("only got %d of %d bytes" %
- (bytes_done, total_size))
- print ' Done.'
- return
- except urllib2.URLError as e:
- sys.stdout.write('\n')
- print e
- if num_retries == 0 or isinstance(e, urllib2.HTTPError) and e.code == 404:
- raise e
- num_retries -= 1
- print 'Retrying in %d s ...' % retry_wait_s
- time.sleep(retry_wait_s)
- retry_wait_s *= 2
-
-
-def EnsureDirExists(path):
- if not os.path.exists(path):
- os.makedirs(path)
-
-
-def DownloadAndUnpack(url, output_dir, path_prefix=None):
- """Download an archive from url and extract into output_dir. If path_prefix is
- not None, only extract files whose paths within the archive start with
- path_prefix."""
- with tempfile.TemporaryFile() as f:
- DownloadUrl(url, f)
- f.seek(0)
- EnsureDirExists(output_dir)
- if url.endswith('.zip'):
- assert path_prefix is None
- zipfile.ZipFile(f).extractall(path=output_dir)
- else:
- t = tarfile.open(mode='r:gz', fileobj=f)
- members = None
- if path_prefix is not None:
- members = [m for m in t.getmembers() if m.name.startswith(path_prefix)]
- print("Extracting to output_dir: {}".format(output_dir))
- t.extractall(path=output_dir, members=members)
-
-
-def ReadStampFile(path=STAMP_FILE):
- """Return the contents of the stamp file, or '' if it doesn't exist."""
- try:
- with open(path, 'r') as f:
- return f.read().rstrip()
- except IOError:
- return ''
-
-
-def WriteStampFile(s, path=STAMP_FILE):
- """Write s to the stamp file."""
- EnsureDirExists(os.path.dirname(path))
- with open(path, 'w') as f:
- f.write(s)
- f.write('\n')
-
-
-def GetSvnRevision(svn_repo):
- """Returns current revision of the svn repo at svn_repo."""
- svn_info = subprocess.check_output('svn info ' + svn_repo, shell=True)
- m = re.search(r'Revision: (\d+)', svn_info)
- return m.group(1)
-
-
-def RmTree(dir):
- """Delete dir."""
- def ChmodAndRetry(func, path, _):
- # Subversion can leave read-only files around.
- if not os.access(path, os.W_OK):
- os.chmod(path, stat.S_IWUSR)
- return func(path)
- raise
-
- shutil.rmtree(dir, onerror=ChmodAndRetry)
-
-
-def RmCmakeCache(dir):
- """Delete CMake cache related files from dir."""
- for dirpath, dirs, files in os.walk(dir):
- if 'CMakeCache.txt' in files:
- os.remove(os.path.join(dirpath, 'CMakeCache.txt'))
- if 'CMakeFiles' in dirs:
- RmTree(os.path.join(dirpath, 'CMakeFiles'))
-
-
-def RunCommand(command, msvc_arch=None, env=None, fail_hard=True):
- """Run command and return success (True) or failure; or if fail_hard is
- True, exit on failure. If msvc_arch is set, runs the command in a
- shell with the msvc tools for that architecture."""
-
- if msvc_arch and sys.platform == 'win32':
- command = [os.path.join(GetWinSDKDir(), 'bin', 'SetEnv.cmd'),
- "/" + msvc_arch, '&&'] + command
-
- # https://docs.python.org/2/library/subprocess.html:
- # "On Unix with shell=True [...] if args is a sequence, the first item
- # specifies the command string, and any additional items will be treated as
- # additional arguments to the shell itself. That is to say, Popen does the
- # equivalent of:
- # Popen(['/bin/sh', '-c', args[0], args[1], ...])"
- #
- # We want to pass additional arguments to command[0], not to the shell,
- # so manually join everything into a single string.
- # Annoyingly, for "svn co url c:\path", pipes.quote() thinks that it should
- # quote c:\path but svn can't handle quoted paths on Windows. Since on
- # Windows follow-on args are passed to args[0] instead of the shell, don't
- # do the single-string transformation there.
- if sys.platform != 'win32':
- command = ' '.join([pipes.quote(c) for c in command])
- print 'Running', command
- if subprocess.call(command, env=env, shell=True) == 0:
- return True
- print 'Failed.'
- if fail_hard:
- sys.exit(1)
- return False
-
-
-def CopyFile(src, dst):
- """Copy a file from src to dst."""
- print "Copying %s to %s" % (src, dst)
- shutil.copy(src, dst)
-
-
-def CopyDirectoryContents(src, dst):
- """Copy the files from directory src to dst."""
- dst = os.path.realpath(dst) # realpath() in case dst ends in /..
- EnsureDirExists(dst)
- for f in os.listdir(src):
- CopyFile(os.path.join(src, f), dst)
-
-
-def Checkout(name, url, dir):
- """Checkout the SVN module at url into dir. Use name for the log message."""
- print "Checking out %s r%s into '%s'" % (name, CLANG_REVISION, dir)
-
- command = ['svn', 'checkout', '--force', url + '@' + CLANG_REVISION, dir]
- if RunCommand(command, fail_hard=False):
- return
-
- if os.path.isdir(dir):
- print "Removing %s." % (dir)
- RmTree(dir)
-
- print "Retrying."
- RunCommand(command)
-
-
-def CheckoutRepos(args):
- if args.skip_checkout:
- return
-
- Checkout('LLVM', LLVM_REPO_URL + '/llvm/trunk', LLVM_DIR)
- Checkout('Clang', LLVM_REPO_URL + '/cfe/trunk', CLANG_DIR)
- if True:
- Checkout('LLD', LLVM_REPO_URL + '/lld/trunk', LLD_DIR)
- elif os.path.exists(LLD_DIR):
- # In case someone sends a tryjob that temporary adds lld to the checkout,
- # make sure it's not around on future builds.
- RmTree(LLD_DIR)
- Checkout('compiler-rt', LLVM_REPO_URL + '/compiler-rt/trunk', COMPILER_RT_DIR)
- if sys.platform == 'darwin':
- # clang needs a libc++ checkout, else -stdlib=libc++ won't find includes
- # (i.e. this is needed for bootstrap builds).
- Checkout('libcxx', LLVM_REPO_URL + '/libcxx/trunk', LIBCXX_DIR)
- # We used to check out libcxxabi on OS X; we no longer need that.
- if os.path.exists(LIBCXXABI_DIR):
- RmTree(LIBCXXABI_DIR)
-
-
-def DeleteChromeToolsShim():
- OLD_SHIM_DIR = os.path.join(LLVM_DIR, 'tools', 'zzz-chrometools')
- shutil.rmtree(OLD_SHIM_DIR, ignore_errors=True)
- shutil.rmtree(CHROME_TOOLS_SHIM_DIR, ignore_errors=True)
-
-
-def CreateChromeToolsShim():
- """Hooks the Chrome tools into the LLVM build.
-
- Several Chrome tools have dependencies on LLVM/Clang libraries. The LLVM build
- detects implicit tools in the tools subdirectory, so this helper install a
- shim CMakeLists.txt that forwards to the real directory for the Chrome tools.
-
- Note that the shim directory name intentionally has no - or _. The implicit
- tool detection logic munges them in a weird way."""
- assert not any(i in os.path.basename(CHROME_TOOLS_SHIM_DIR) for i in '-_')
- os.mkdir(CHROME_TOOLS_SHIM_DIR)
- with file(os.path.join(CHROME_TOOLS_SHIM_DIR, 'CMakeLists.txt'), 'w') as f:
- f.write('# Automatically generated by tools/clang/scripts/update.py. ' +
- 'Do not edit.\n')
- f.write('# Since tools/clang is located in another directory, use the \n')
- f.write('# two arg version to specify where build artifacts go. CMake\n')
- f.write('# disallows reuse of the same binary dir for multiple source\n')
- f.write('# dirs, so the build artifacts need to go into a subdirectory.\n')
- f.write('if (CHROMIUM_TOOLS_SRC)\n')
- f.write(' add_subdirectory(${CHROMIUM_TOOLS_SRC} ' +
- '${CMAKE_CURRENT_BINARY_DIR}/a)\n')
- f.write('endif (CHROMIUM_TOOLS_SRC)\n')
-
-
-def AddSvnToPathOnWin():
- """Download svn.exe and add it to PATH."""
- if sys.platform != 'win32':
- return
- svn_ver = 'svn-1.6.6-win'
- svn_dir = os.path.join(LLVM_BUILD_TOOLS_DIR, svn_ver)
- if not os.path.exists(svn_dir):
- DownloadAndUnpack(CDS_URL + '/tools/%s.zip' % svn_ver, LLVM_BUILD_TOOLS_DIR)
- os.environ['PATH'] = svn_dir + os.pathsep + os.environ.get('PATH', '')
-
-
-def AddCMakeToPath(args):
- """Download CMake and add it to PATH."""
- if args.use_system_cmake:
- return
-
- if sys.platform == 'win32':
- zip_name = 'cmake-3.12.1-win32-x86.zip'
- dir_name = ['cmake-3.12.1-win32-x86', 'bin']
- elif sys.platform == 'darwin':
- zip_name = 'cmake-3.12.1-Darwin-x86_64.tar.gz'
- dir_name = ['cmake-3.12.1-Darwin-x86_64', 'CMake.app', 'Contents', 'bin']
- else:
- zip_name = 'cmake-3.12.1-Linux-x86_64.tar.gz'
- dir_name = ['cmake-3.12.1-Linux-x86_64', 'bin']
-
- cmake_dir = os.path.join(LLVM_BUILD_TOOLS_DIR, *dir_name)
- if not os.path.exists(cmake_dir):
- DownloadAndUnpack(CDS_URL + '/tools/' + zip_name, LLVM_BUILD_TOOLS_DIR)
- os.environ['PATH'] = cmake_dir + os.pathsep + os.environ.get('PATH', '')
-
-
-def AddGnuWinToPath():
- """Download some GNU win tools and add them to PATH."""
- if sys.platform != 'win32':
- return
-
- gnuwin_dir = os.path.join(LLVM_BUILD_TOOLS_DIR, 'gnuwin')
- GNUWIN_VERSION = '9'
- GNUWIN_STAMP = os.path.join(gnuwin_dir, 'stamp')
- if ReadStampFile(GNUWIN_STAMP) == GNUWIN_VERSION:
- print 'GNU Win tools already up to date.'
- else:
- zip_name = 'gnuwin-%s.zip' % GNUWIN_VERSION
- DownloadAndUnpack(CDS_URL + '/tools/' + zip_name, LLVM_BUILD_TOOLS_DIR)
- WriteStampFile(GNUWIN_VERSION, GNUWIN_STAMP)
-
- os.environ['PATH'] = gnuwin_dir + os.pathsep + os.environ.get('PATH', '')
-
- # find.exe, mv.exe and rm.exe are from MSYS (see crrev.com/389632). MSYS uses
- # Cygwin under the hood, and initializing Cygwin has a race-condition when
- # getting group and user data from the Active Directory is slow. To work
- # around this, use a horrible hack telling it not to do that.
- # See https://crbug.com/905289
- etc = os.path.join(gnuwin_dir, '..', '..', 'etc')
- EnsureDirExists(etc)
- with open(os.path.join(etc, 'nsswitch.conf'), 'w') as f:
- f.write('passwd: files\n')
- f.write('group: files\n')
-
-
-win_sdk_dir = None
-dia_dll = None
-def GetWinSDKDir():
- """Get the location of the current SDK. Sets dia_dll as a side-effect."""
- global win_sdk_dir
- global dia_dll
- if win_sdk_dir:
- return win_sdk_dir
-
- # Bump after VC updates.
- DIA_DLL = {
- '2013': 'msdia120.dll',
- '2015': 'msdia140.dll',
- '2017': 'msdia140.dll',
- '2019': 'msdia140.dll',
- }
-
- # Don't let vs_toolchain overwrite our environment.
- environ_bak = os.environ
-
- sys.path.append(os.path.join(CHROMIUM_DIR, 'build'))
- import vs_toolchain
- win_sdk_dir = vs_toolchain.SetEnvironmentAndGetSDKDir()
- msvs_version = vs_toolchain.GetVisualStudioVersion()
-
- if bool(int(os.environ.get('DEPOT_TOOLS_WIN_TOOLCHAIN', '1'))):
- dia_path = os.path.join(win_sdk_dir, '..', 'DIA SDK', 'bin', 'amd64')
- else:
- if 'GYP_MSVS_OVERRIDE_PATH' not in os.environ:
- vs_path = vs_toolchain.DetectVisualStudioPath()
- else:
- vs_path = os.environ['GYP_MSVS_OVERRIDE_PATH']
- dia_path = os.path.join(vs_path, 'DIA SDK', 'bin', 'amd64')
-
- dia_dll = os.path.join(dia_path, DIA_DLL[msvs_version])
-
- os.environ = environ_bak
- return win_sdk_dir
-
-
-def CopyDiaDllTo(target_dir):
- # This script always wants to use the 64-bit msdia*.dll.
- GetWinSDKDir()
- CopyFile(dia_dll, target_dir)
-
-
-def VeryifyVersionOfBuiltClangMatchesVERSION():
- """Checks that `clang --version` outputs VERSION. If this fails, VERSION
- in this file is out-of-date and needs to be updated (possibly in an
- `if use_head_revision:` block in main() first)."""
- clang = os.path.join(LLVM_BUILD_DIR, 'bin', 'clang')
- if sys.platform == 'win32':
- clang += '.exe'
- version_out = subprocess.check_output([clang, '--version'])
- version_out = re.match(r'clang version ([0-9.]+)', version_out).group(1)
- if version_out != VERSION:
- print ('unexpected clang version %s (not %s), update VERSION in update.py'
- % (version_out, VERSION))
- sys.exit(1)
-
-
-def GetPlatformUrlPrefix(platform):
- if platform == 'win32' or platform == 'cygwin':
- return CDS_URL + '/Win/'
- if platform == 'darwin':
- return CDS_URL + '/Mac/'
- assert platform.startswith('linux')
- return CDS_URL + '/Linux_x64/'
-
-
-def DownloadAndUnpackClangPackage(platform, runtimes_only=False):
- cds_file = "clang-%s.tgz" % PACKAGE_VERSION
- cds_full_url = GetPlatformUrlPrefix(platform) + cds_file
- try:
- path_prefix = None
- if runtimes_only:
- path_prefix = 'lib/clang/' + VERSION + '/lib/'
- DownloadAndUnpack(cds_full_url, LLVM_BUILD_DIR, path_prefix)
- except urllib2.URLError:
- print 'Failed to download prebuilt clang %s' % cds_file
- print 'Use --force-local-build if you want to build locally.'
- print 'Exiting.'
- sys.exit(1)
-
-
-def UpdateClang(args):
- # Read target_os from .gclient so we know which non-native runtimes we need.
- # TODO(pcc): See if we can download just the runtimes instead of the entire
- # clang package, and do that from DEPS instead of here.
- target_os = []
- try:
- env = {}
- execfile(GCLIENT_CONFIG, env, env)
- target_os = env.get('target_os', target_os)
- except:
- pass
-
- expected_stamp = ','.join([PACKAGE_VERSION] + target_os)
- if ReadStampFile() == expected_stamp and not args.force_local_build:
- return 0
-
- # Reset the stamp file in case the build is unsuccessful.
- WriteStampFile('')
-
- if not args.force_local_build:
- if os.path.exists(LLVM_BUILD_DIR):
- RmTree(LLVM_BUILD_DIR)
-
- DownloadAndUnpackClangPackage(sys.platform)
- if 'win' in target_os:
- DownloadAndUnpackClangPackage('win32', runtimes_only=True)
- if sys.platform == 'win32':
- CopyDiaDllTo(os.path.join(LLVM_BUILD_DIR, 'bin'))
- WriteStampFile(expected_stamp)
- return 0
-
- if args.with_android and not os.path.exists(ANDROID_NDK_DIR):
- print 'Android NDK not found at ' + ANDROID_NDK_DIR
- print 'The Android NDK is needed to build a Clang whose -fsanitize=address'
- print 'works on Android. See '
- print 'https://www.chromium.org/developers/how-tos/android-build-instructions'
- print 'for how to install the NDK, or pass --without-android.'
- return 1
-
- if args.with_fuchsia and not os.path.exists(FUCHSIA_SDK_DIR):
- print 'Fuchsia SDK not found at ' + FUCHSIA_SDK_DIR
- print 'The Fuchsia SDK is needed to build libclang_rt for Fuchsia.'
- print 'Install the Fuchsia SDK by adding fuchsia to the '
- print 'target_os section in your .gclient and running hooks, '
- print 'or pass --without-fuchsia.'
- print 'https://chromium.googlesource.com/chromium/src/+/master/docs/fuchsia_build_instructions.md'
- print 'for general Fuchsia build instructions.'
- return 1
-
- print 'Locally building Clang %s...' % PACKAGE_VERSION
-
- AddCMakeToPath(args)
- AddGnuWinToPath()
-
- DeleteChromeToolsShim()
-
- CheckoutRepos(args)
-
- if args.skip_build:
- return
-
- cc, cxx = None, None
- libstdcpp = None
-
- cflags = []
- cxxflags = []
- ldflags = []
-
- targets = 'AArch64;ARM;Mips;PowerPC;SystemZ;WebAssembly;X86'
- base_cmake_args = ['-GNinja',
- '-DCMAKE_BUILD_TYPE=Release',
- '-DLLVM_ENABLE_ASSERTIONS=ON',
- '-DLLVM_ENABLE_TERMINFO=OFF',
- '-DLLVM_TARGETS_TO_BUILD=' + targets,
- # Statically link MSVCRT to avoid DLL dependencies.
- '-DLLVM_USE_CRT_RELEASE=MT',
- '-DCLANG_PLUGIN_SUPPORT=OFF',
- '-DCLANG_ENABLE_STATIC_ANALYZER=OFF',
- '-DCLANG_ENABLE_ARCMT=OFF',
- ]
-
- if sys.platform != 'win32':
- # libxml2 is required by the Win manifest merging tool used in cross-builds.
- base_cmake_args.append('-DLLVM_ENABLE_LIBXML2=FORCE_ON')
-
- if args.bootstrap:
- print 'Building bootstrap compiler'
- EnsureDirExists(LLVM_BOOTSTRAP_DIR)
- os.chdir(LLVM_BOOTSTRAP_DIR)
- bootstrap_args = base_cmake_args + [
- '-DLLVM_TARGETS_TO_BUILD=X86;ARM;AArch64',
- '-DCMAKE_INSTALL_PREFIX=' + LLVM_BOOTSTRAP_INSTALL_DIR,
- '-DCMAKE_C_FLAGS=' + ' '.join(cflags),
- '-DCMAKE_CXX_FLAGS=' + ' '.join(cxxflags),
- ]
- if cc is not None: bootstrap_args.append('-DCMAKE_C_COMPILER=' + cc)
- if cxx is not None: bootstrap_args.append('-DCMAKE_CXX_COMPILER=' + cxx)
- RmCmakeCache('.')
- RunCommand(['cmake'] + bootstrap_args + [LLVM_DIR], msvc_arch='x64')
- RunCommand(['ninja'], msvc_arch='x64')
- if args.run_tests:
- if sys.platform == 'win32':
- CopyDiaDllTo(os.path.join(LLVM_BOOTSTRAP_DIR, 'bin'))
- RunCommand(['ninja', 'check-all'], msvc_arch='x64')
- RunCommand(['ninja', 'install'], msvc_arch='x64')
-
- if sys.platform == 'win32':
- cc = os.path.join(LLVM_BOOTSTRAP_INSTALL_DIR, 'bin', 'clang-cl.exe')
- cxx = os.path.join(LLVM_BOOTSTRAP_INSTALL_DIR, 'bin', 'clang-cl.exe')
- # CMake has a hard time with backslashes in compiler paths:
- # https://stackoverflow.com/questions/13050827
- cc = cc.replace('\\', '/')
- cxx = cxx.replace('\\', '/')
- else:
- cc = os.path.join(LLVM_BOOTSTRAP_INSTALL_DIR, 'bin', 'clang')
- cxx = os.path.join(LLVM_BOOTSTRAP_INSTALL_DIR, 'bin', 'clang++')
-
- print 'Building final compiler'
-
- # LLVM uses C++11 starting in llvm 3.5. On Linux, this means libstdc++4.7+ is
- # needed, on OS X it requires libc++. clang only automatically links to libc++
- # when targeting OS X 10.9+, so add stdlib=libc++ explicitly so clang can run
- # on OS X versions as old as 10.7.
- deployment_target = ''
-
- if sys.platform == 'darwin' and args.bootstrap:
- # When building on 10.9, /usr/include usually doesn't exist, and while
- # Xcode's clang automatically sets a sysroot, self-built clangs don't.
- cflags = ['-isysroot', subprocess.check_output(
- ['xcrun', '--show-sdk-path']).rstrip()]
- cxxflags = ['-stdlib=libc++'] + cflags
- ldflags += ['-stdlib=libc++']
- deployment_target = '10.7'
- # Running libc++ tests takes a long time. Since it was only needed for
- # the install step above, don't build it as part of the main build.
- # This makes running package.py over 10% faster (30 min instead of 34 min)
- RmTree(LIBCXX_DIR)
-
-
- # If building at head, define a macro that plugins can use for #ifdefing
- # out code that builds at head, but not at CLANG_REVISION or vice versa.
- if use_head_revision:
- cflags += ['-DLLVM_FORCE_HEAD_REVISION']
- cxxflags += ['-DLLVM_FORCE_HEAD_REVISION']
-
- # Build PDBs for archival on Windows. Don't use RelWithDebInfo since it
- # has different optimization defaults than Release.
- # Also disable stack cookies (/GS-) for performance.
- if sys.platform == 'win32':
- cflags += ['/Zi', '/GS-']
- cxxflags += ['/Zi', '/GS-']
- ldflags += ['/DEBUG', '/OPT:REF', '/OPT:ICF']
-
- deployment_env = None
- if deployment_target:
- deployment_env = os.environ.copy()
- deployment_env['MACOSX_DEPLOYMENT_TARGET'] = deployment_target
-
- # Build lld and code coverage tools. This is done separately from the rest of
- # the build because these tools require threading support.
- tools_with_threading = [ 'lld', 'llvm-cov', 'llvm-profdata' ]
- print 'Building the following tools with threading support: %s' % (
- str(tools_with_threading))
-
- if os.path.exists(THREADS_ENABLED_BUILD_DIR):
- RmTree(THREADS_ENABLED_BUILD_DIR)
- EnsureDirExists(THREADS_ENABLED_BUILD_DIR)
- os.chdir(THREADS_ENABLED_BUILD_DIR)
-
- threads_enabled_cmake_args = base_cmake_args + [
- '-DCMAKE_C_FLAGS=' + ' '.join(cflags),
- '-DCMAKE_CXX_FLAGS=' + ' '.join(cxxflags),
- '-DCMAKE_EXE_LINKER_FLAGS=' + ' '.join(ldflags),
- '-DCMAKE_SHARED_LINKER_FLAGS=' + ' '.join(ldflags),
- '-DCMAKE_MODULE_LINKER_FLAGS=' + ' '.join(ldflags)]
- if cc is not None:
- threads_enabled_cmake_args.append('-DCMAKE_C_COMPILER=' + cc)
- if cxx is not None:
- threads_enabled_cmake_args.append('-DCMAKE_CXX_COMPILER=' + cxx)
-
- if args.lto_lld:
- # Build lld with LTO. That speeds up the linker by ~10%.
- # We only use LTO for Linux now.
- #
- # The linker expects all archive members to have symbol tables, so the
- # archiver needs to be able to create symbol tables for bitcode files.
- # GNU ar and ranlib don't understand bitcode files, but llvm-ar and
- # llvm-ranlib do, so use them.
- ar = os.path.join(LLVM_BOOTSTRAP_INSTALL_DIR, 'bin', 'llvm-ar')
- ranlib = os.path.join(LLVM_BOOTSTRAP_INSTALL_DIR, 'bin', 'llvm-ranlib')
- threads_enabled_cmake_args += [
- '-DCMAKE_AR=' + ar,
- '-DCMAKE_RANLIB=' + ranlib,
- '-DLLVM_ENABLE_LTO=thin',
- '-DLLVM_USE_LINKER=lld']
-
- RmCmakeCache('.')
- RunCommand(['cmake'] + threads_enabled_cmake_args + [LLVM_DIR],
- msvc_arch='x64', env=deployment_env)
- RunCommand(['ninja'] + tools_with_threading, msvc_arch='x64')
-
- # Build clang and other tools.
- CreateChromeToolsShim()
-
- cmake_args = []
- # TODO(thakis): Unconditionally append this to base_cmake_args instead once
- # compiler-rt can build with clang-cl on Windows (http://llvm.org/PR23698)
- cc_args = base_cmake_args if sys.platform != 'win32' else cmake_args
- if cc is not None: cc_args.append('-DCMAKE_C_COMPILER=' + cc)
- if cxx is not None: cc_args.append('-DCMAKE_CXX_COMPILER=' + cxx)
- default_tools = ['plugins', 'blink_gc_plugin', 'translation_unit']
- chrome_tools = list(set(default_tools + args.extra_tools))
- cmake_args += base_cmake_args + [
- '-DLLVM_ENABLE_THREADS=OFF',
- '-DCMAKE_C_FLAGS=' + ' '.join(cflags),
- '-DCMAKE_CXX_FLAGS=' + ' '.join(cxxflags),
- '-DCMAKE_EXE_LINKER_FLAGS=' + ' '.join(ldflags),
- '-DCMAKE_SHARED_LINKER_FLAGS=' + ' '.join(ldflags),
- '-DCMAKE_MODULE_LINKER_FLAGS=' + ' '.join(ldflags),
- '-DCMAKE_INSTALL_PREFIX=' + LLVM_BUILD_DIR,
- '-DCHROMIUM_TOOLS_SRC=%s' % os.path.join(CHROMIUM_DIR, 'tools', 'clang'),
- '-DCHROMIUM_TOOLS=%s' % ';'.join(chrome_tools)]
-
- EnsureDirExists(LLVM_BUILD_DIR)
- os.chdir(LLVM_BUILD_DIR)
- RmCmakeCache('.')
- RunCommand(['cmake'] + cmake_args + [LLVM_DIR],
- msvc_arch='x64', env=deployment_env)
- RunCommand(['ninja'], msvc_arch='x64')
-
- # Copy in the threaded versions of lld and other tools.
- if sys.platform == 'win32':
- CopyFile(os.path.join(THREADS_ENABLED_BUILD_DIR, 'bin', 'lld-link.exe'),
- os.path.join(LLVM_BUILD_DIR, 'bin'))
- CopyFile(os.path.join(THREADS_ENABLED_BUILD_DIR, 'bin', 'lld.pdb'),
- os.path.join(LLVM_BUILD_DIR, 'bin'))
- else:
- for tool in tools_with_threading:
- CopyFile(os.path.join(THREADS_ENABLED_BUILD_DIR, 'bin', tool),
- os.path.join(LLVM_BUILD_DIR, 'bin'))
-
- if chrome_tools:
- # If any Chromium tools were built, install those now.
- RunCommand(['ninja', 'cr-install'], msvc_arch='x64')
-
- VeryifyVersionOfBuiltClangMatchesVERSION()
-
- # Do an out-of-tree build of compiler-rt.
- # On Windows, this is used to get the 32-bit ASan run-time.
- # TODO(hans): Remove once the regular build above produces this.
- # On Mac and Linux, this is used to get the regular 64-bit run-time.
- # Do a clobbered build due to cmake changes.
- if os.path.isdir(COMPILER_RT_BUILD_DIR):
- RmTree(COMPILER_RT_BUILD_DIR)
- os.makedirs(COMPILER_RT_BUILD_DIR)
- os.chdir(COMPILER_RT_BUILD_DIR)
- # TODO(thakis): Add this once compiler-rt can build with clang-cl (see
- # above).
- #if args.bootstrap and sys.platform == 'win32':
- # The bootstrap compiler produces 64-bit binaries by default.
- #cflags += ['-m32']
- #cxxflags += ['-m32']
- compiler_rt_args = base_cmake_args + [
- '-DLLVM_ENABLE_THREADS=OFF',
- '-DCMAKE_C_FLAGS=' + ' '.join(cflags),
- '-DCMAKE_CXX_FLAGS=' + ' '.join(cxxflags)]
- if sys.platform == 'darwin':
- compiler_rt_args += ['-DCOMPILER_RT_ENABLE_IOS=ON']
- if sys.platform != 'win32':
- compiler_rt_args += ['-DLLVM_CONFIG_PATH=' +
- os.path.join(LLVM_BUILD_DIR, 'bin', 'llvm-config'),
- '-DSANITIZER_MIN_OSX_VERSION="10.7"']
- # compiler-rt is part of the llvm checkout on Windows but a stand-alone
- # directory elsewhere, see the TODO above COMPILER_RT_DIR.
- RmCmakeCache('.')
- RunCommand(['cmake'] + compiler_rt_args +
- [LLVM_DIR if sys.platform == 'win32' else COMPILER_RT_DIR],
- msvc_arch='x86', env=deployment_env)
- RunCommand(['ninja', 'compiler-rt'], msvc_arch='x86')
- if sys.platform != 'win32':
- RunCommand(['ninja', 'fuzzer'])
-
- # Copy select output to the main tree.
- # TODO(hans): Make this (and the .gypi and .isolate files) version number
- # independent.
- if sys.platform == 'win32':
- platform = 'windows'
- elif sys.platform == 'darwin':
- platform = 'darwin'
- else:
- assert sys.platform.startswith('linux')
- platform = 'linux'
- rt_lib_src_dir = os.path.join(COMPILER_RT_BUILD_DIR, 'lib', platform)
- if sys.platform == 'win32':
- # TODO(thakis): This too is due to compiler-rt being part of the checkout
- # on Windows, see TODO above COMPILER_RT_DIR.
- rt_lib_src_dir = os.path.join(COMPILER_RT_BUILD_DIR, 'lib', 'clang',
- VERSION, 'lib', platform)
- rt_lib_dst_dir = os.path.join(LLVM_BUILD_DIR, 'lib', 'clang', VERSION, 'lib',
- platform)
- # Blacklists:
- CopyDirectoryContents(os.path.join(rt_lib_src_dir, '..', '..', 'share'),
- os.path.join(rt_lib_dst_dir, '..', '..', 'share'))
- # Headers:
- if sys.platform != 'win32':
- CopyDirectoryContents(
- os.path.join(COMPILER_RT_BUILD_DIR, 'include/sanitizer'),
- os.path.join(LLVM_BUILD_DIR, 'lib/clang', VERSION, 'include/sanitizer'))
- # Static and dynamic libraries:
- CopyDirectoryContents(rt_lib_src_dir, rt_lib_dst_dir)
- if sys.platform == 'darwin':
- for dylib in glob.glob(os.path.join(rt_lib_dst_dir, '*.dylib')):
- # Fix LC_ID_DYLIB for the ASan dynamic libraries to be relative to
- # @executable_path.
- # TODO(glider): this is transitional. We'll need to fix the dylib
- # name either in our build system, or in Clang. See also
- # http://crbug.com/344836.
- subprocess.call(['install_name_tool', '-id',
- '@executable_path/' + os.path.basename(dylib), dylib])
-
- if args.with_android:
- make_toolchain = os.path.join(
- ANDROID_NDK_DIR, 'build', 'tools', 'make_standalone_toolchain.py')
- for target_arch in ['aarch64', 'arm', 'i686']:
- # Make standalone Android toolchain for target_arch.
- toolchain_dir = os.path.join(
- LLVM_BUILD_DIR, 'android-toolchain-' + target_arch)
- api_level = '21' if target_arch == 'aarch64' else '19'
- RunCommand([
- make_toolchain,
- '--api=' + api_level,
- '--force',
- '--install-dir=%s' % toolchain_dir,
- '--stl=libc++',
- '--arch=' + {
- 'aarch64': 'arm64',
- 'arm': 'arm',
- 'i686': 'x86',
- }[target_arch]])
-
- # NDK r16 "helpfully" installs libc++ as libstdc++ "so the compiler will
- # pick it up by default". Only these days, the compiler tries to find
- # libc++ instead. See https://crbug.com/902270.
- shutil.copy(os.path.join(toolchain_dir, 'sysroot/usr/lib/libstdc++.a'),
- os.path.join(toolchain_dir, 'sysroot/usr/lib/libc++.a'))
- shutil.copy(os.path.join(toolchain_dir, 'sysroot/usr/lib/libstdc++.so'),
- os.path.join(toolchain_dir, 'sysroot/usr/lib/libc++.so'))
-
- # Build compiler-rt runtimes needed for Android in a separate build tree.
- build_dir = os.path.join(LLVM_BUILD_DIR, 'android-' + target_arch)
- if not os.path.exists(build_dir):
- os.mkdir(os.path.join(build_dir))
- os.chdir(build_dir)
- target_triple = target_arch
- abi_libs = 'c++abi'
- if target_arch == 'arm':
- target_triple = 'armv7'
- abi_libs += ';unwind'
- target_triple += '-linux-android' + api_level
- cflags = ['--target=%s' % target_triple,
- '--sysroot=%s/sysroot' % toolchain_dir,
- '-B%s' % toolchain_dir]
- android_args = base_cmake_args + [
- '-DLLVM_ENABLE_THREADS=OFF',
- '-DCMAKE_C_COMPILER=' + os.path.join(LLVM_BUILD_DIR, 'bin/clang'),
- '-DCMAKE_CXX_COMPILER=' + os.path.join(LLVM_BUILD_DIR, 'bin/clang++'),
- '-DLLVM_CONFIG_PATH=' + os.path.join(LLVM_BUILD_DIR, 'bin/llvm-config'),
- '-DCMAKE_C_FLAGS=' + ' '.join(cflags),
- '-DCMAKE_CXX_FLAGS=' + ' '.join(cflags),
- '-DCMAKE_ASM_FLAGS=' + ' '.join(cflags),
- '-DSANITIZER_CXX_ABI=none',
- '-DSANITIZER_CXX_ABI_LIBRARY=' + abi_libs,
- '-DCMAKE_SHARED_LINKER_FLAGS=-Wl,-u__cxa_demangle',
- '-DANDROID=1']
- RmCmakeCache('.')
- RunCommand(['cmake'] + android_args + [COMPILER_RT_DIR])
-
- # We use ASan i686 build for fuzzing.
- libs_want = ['lib/linux/libclang_rt.asan-{0}-android.so']
- if target_arch in ['aarch64', 'arm']:
- libs_want += [
- 'lib/linux/libclang_rt.ubsan_standalone-{0}-android.so',
- 'lib/linux/libclang_rt.profile-{0}-android.a',
- ]
- if target_arch == 'aarch64':
- libs_want += ['lib/linux/libclang_rt.hwasan-{0}-android.so']
- libs_want = [lib.format(target_arch) for lib in libs_want]
- RunCommand(['ninja'] + libs_want)
-
- # And copy them into the main build tree.
- for p in libs_want:
- shutil.copy(p, rt_lib_dst_dir)
-
- if args.with_fuchsia:
- # Fuchsia links against libclang_rt.builtins-<arch>.a instead of libgcc.a.
- for target_arch in ['aarch64', 'x86_64']:
- fuchsia_arch_name = {'aarch64': 'arm64', 'x86_64': 'x64'}[target_arch]
- toolchain_dir = os.path.join(
- FUCHSIA_SDK_DIR, 'arch', fuchsia_arch_name, 'sysroot')
- # Build clang_rt runtime for Fuchsia in a separate build tree.
- build_dir = os.path.join(LLVM_BUILD_DIR, 'fuchsia-' + target_arch)
- if not os.path.exists(build_dir):
- os.mkdir(os.path.join(build_dir))
- os.chdir(build_dir)
- target_spec = target_arch + '-fuchsia'
- # TODO(thakis): Might have to pass -B here once sysroot contains
- # binaries (e.g. gas for arm64?)
- fuchsia_args = base_cmake_args + [
- '-DLLVM_ENABLE_THREADS=OFF',
- '-DCMAKE_C_COMPILER=' + os.path.join(LLVM_BUILD_DIR, 'bin/clang'),
- '-DCMAKE_CXX_COMPILER=' + os.path.join(LLVM_BUILD_DIR, 'bin/clang++'),
- '-DCMAKE_LINKER=' + os.path.join(LLVM_BUILD_DIR, 'bin/clang'),
- '-DCMAKE_AR=' + os.path.join(LLVM_BUILD_DIR, 'bin/llvm-ar'),
- '-DLLVM_CONFIG_PATH=' + os.path.join(LLVM_BUILD_DIR, 'bin/llvm-config'),
- '-DCMAKE_SYSTEM_NAME=Fuchsia',
- '-DCMAKE_C_COMPILER_TARGET=%s-fuchsia' % target_arch,
- '-DCMAKE_ASM_COMPILER_TARGET=%s-fuchsia' % target_arch,
- '-DCOMPILER_RT_DEFAULT_TARGET_ONLY=ON',
- '-DCMAKE_SYSROOT=%s' % toolchain_dir,
- # TODO(thakis|scottmg): Use PER_TARGET_RUNTIME_DIR for all platforms.
- # https://crbug.com/882485.
- '-DLLVM_ENABLE_PER_TARGET_RUNTIME_DIR=ON',
-
- # These are necessary because otherwise CMake tries to build an
- # executable to test to see if the compiler is working, but in doing so,
- # it links against the builtins.a that we're about to build.
- '-DCMAKE_C_COMPILER_WORKS=ON',
- '-DCMAKE_ASM_COMPILER_WORKS=ON',
- ]
- RmCmakeCache('.')
- RunCommand(['cmake'] +
- fuchsia_args +
- [os.path.join(COMPILER_RT_DIR, 'lib', 'builtins')])
- builtins_a = 'libclang_rt.builtins.a'
- RunCommand(['ninja', builtins_a])
-
- # And copy it into the main build tree.
- fuchsia_lib_dst_dir = os.path.join(LLVM_BUILD_DIR, 'lib', 'clang',
- VERSION, target_spec, 'lib')
- if not os.path.exists(fuchsia_lib_dst_dir):
- os.makedirs(fuchsia_lib_dst_dir)
- CopyFile(os.path.join(build_dir, target_spec, 'lib', builtins_a),
- fuchsia_lib_dst_dir)
-
- # Run tests.
- if args.run_tests or use_head_revision:
- os.chdir(LLVM_BUILD_DIR)
- RunCommand(['ninja', 'cr-check-all'], msvc_arch='x64')
- if args.run_tests:
- if sys.platform == 'win32':
- CopyDiaDllTo(os.path.join(LLVM_BUILD_DIR, 'bin'))
- os.chdir(LLVM_BUILD_DIR)
- RunCommand(['ninja', 'check-all'], msvc_arch='x64')
-
- WriteStampFile(PACKAGE_VERSION)
- print 'Clang update was successful.'
- return 0
-
-
-def gn_arg(v):
- if v == 'True':
- return True
- if v == 'False':
- return False
- raise argparse.ArgumentTypeError('Expected one of %r or %r' % (
- 'True', 'False'))
-
-
-def main():
- parser = argparse.ArgumentParser(description='Build Clang.')
- parser.add_argument('--bootstrap', action='store_true',
- help='first build clang with CC, then with itself.')
- parser.add_argument('--force-local-build', action='store_true',
- help="don't try to download prebuild binaries")
- parser.add_argument('--gcc-toolchain', help='set the version for which gcc '
- 'version be used for building; --gcc-toolchain=/opt/foo '
- 'picks /opt/foo/bin/gcc')
- parser.add_argument('--lto-lld', action='store_true',
- help='build lld with LTO')
- parser.add_argument('--llvm-force-head-revision', action='store_true',
- help=('use the revision in the repo when printing '
- 'the revision'))
- parser.add_argument('--print-revision', action='store_true',
- help='print current clang revision and exit.')
- parser.add_argument('--print-clang-version', action='store_true',
- help='print current clang version (e.g. x.y.z) and exit.')
- parser.add_argument('--run-tests', action='store_true',
- help='run tests after building; only for local builds')
- parser.add_argument('--skip-build', action='store_true',
- help='do not build anything')
- parser.add_argument('--skip-checkout', action='store_true',
- help='do not create or update any checkouts')
- parser.add_argument('--extra-tools', nargs='*', default=[],
- help='select additional chrome tools to build')
- parser.add_argument('--use-system-cmake', action='store_true',
- help='use the cmake from PATH instead of downloading '
- 'and using prebuilt cmake binaries')
- parser.add_argument('--verify-version',
- help='verify that clang has the passed-in version')
- parser.add_argument('--with-android', type=gn_arg, nargs='?', const=True,
- help='build the Android ASan runtime (linux only)',
- default=sys.platform.startswith('linux'))
- parser.add_argument('--without-android', action='store_false',
- help='don\'t build Android ASan runtime (linux only)',
- dest='with_android')
- parser.add_argument('--without-fuchsia', action='store_false',
- help='don\'t build Fuchsia clang_rt runtime (linux/mac)',
- dest='with_fuchsia',
- default=sys.platform in ('linux2', 'darwin'))
- args = parser.parse_args()
-
- if args.lto_lld and not args.bootstrap:
- print '--lto-lld requires --bootstrap'
- return 1
- if args.lto_lld and not sys.platform.startswith('linux'):
- print '--lto-lld is only effective on Linux. Ignoring the option.'
- args.lto_lld = False
-
- # Get svn if we're going to use it to check the revision or do a local build.
- if (use_head_revision or args.llvm_force_head_revision or
- args.force_local_build):
- AddSvnToPathOnWin()
-
- if args.verify_version and args.verify_version != VERSION:
- print 'VERSION is %s but --verify-version argument was %s, exiting.' % (
- VERSION, args.verify_version)
- print 'clang_version in build/toolchain/toolchain.gni is likely outdated.'
- return 1
-
- global CLANG_REVISION, PACKAGE_VERSION
- if args.print_revision:
- if use_head_revision or args.llvm_force_head_revision:
- print GetSvnRevision(LLVM_DIR)
- else:
- print PACKAGE_VERSION
- return 0
-
- if args.print_clang_version:
- sys.stdout.write(VERSION)
- return 0
-
- # Don't buffer stdout, so that print statements are immediately flushed.
- # Do this only after --print-revision has been handled, else we'll get
- # an error message when this script is run from gn for some reason.
- sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', 0)
-
- if use_head_revision:
- # Use a real revision number rather than HEAD to make sure that the stamp
- # file logic works.
- CLANG_REVISION = GetSvnRevision(LLVM_REPO_URL)
- PACKAGE_VERSION = CLANG_REVISION + '-0'
-
- args.force_local_build = True
- # Don't build fuchsia runtime on ToT bots at all.
- args.with_fuchsia = False
-
- return UpdateClang(args)
-
-
-if __name__ == '__main__':
- sys.exit(main())
diff --git a/chromium/third_party/openscreen/src/tools/download-clang-update-script.py b/chromium/third_party/openscreen/src/tools/download-clang-update-script.py
new file mode 100755
index 00000000000..0d707060920
--- /dev/null
+++ b/chromium/third_party/openscreen/src/tools/download-clang-update-script.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python
+# Copyright 2020 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""This script is used to download the clang update script. It runs as a
+gclient hook.
+
+It's equivalent to using curl to download the latest update script:
+
+ $ curl --silent --create-dirs -o tools/clang/scripts/update.py \
+ https://raw.githubusercontent.com/chromium/chromium/master/tools/clang/scripts/update.py
+
+The purpose of "reinventing the wheel" with this script is just so developers
+aren't required to have curl installed.
+"""
+
+import argparse
+import os
+import sys
+
+try:
+ from urllib2 import HTTPError, URLError, urlopen
+except ImportError: # For Py3 compatibility
+ from urllib.error import HTTPError, URLError
+ from urllib.request import urlopen
+
+SCRIPT_DOWNLOAD_URL = ('https://raw.githubusercontent.com/' +
+ 'chromium/chromium/master/tools/clang/scripts/update.py')
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Download clang update script from chromium master.')
+ parser.add_argument('--output',
+ help='Path to script file to create/overwrite.')
+ args = parser.parse_args()
+
+ if not args.output:
+ print('usage: download-clang-update-script.py ' +
+ '--output=tools/clang/scripts/update.py');
+ return 1
+
+ script_contents = ''
+ try:
+ response = urlopen(SCRIPT_DOWNLOAD_URL)
+ script_contents = response.read()
+ except HTTPError as e:
+ print e.code
+ print e.read()
+ return 1
+ except URLError as e:
+ print 'Download failed. Reason: ', e.reason
+ return 1
+
+ directory = os.path.dirname(args.output)
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+
+ script_file = open(args.output, 'w')
+ script_file.write(script_contents)
+ script_file.close()
+
+ return 0
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/chromium/third_party/openscreen/src/tools/install-build-tools.sh b/chromium/third_party/openscreen/src/tools/install-build-tools.sh
deleted file mode 100755
index 65fb4643969..00000000000
--- a/chromium/third_party/openscreen/src/tools/install-build-tools.sh
+++ /dev/null
@@ -1,36 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2018 The Chromium Authors. All rights reserved.
-# Use of this source code is governed by a BSD-style license that can be
-# found in the LICENSE file.
-
-uname="$(uname -s)"
-case "${uname}" in
- Linux*) env="linux64";;
- Darwin*) env="mac";;
-esac
-
-echo "Assuming we are running in $env..."
-
-ninja_version="v1.9.0"
-ninja_zipfile=""
-case "$env" in
- linux64) ninja_zipfile="ninja-linux.zip";;
- mac) ninja_zipfile="ninja-mac.zip";;
-esac
-
-GOOGLE_STORAGE_URL="https://storage.googleapis.com"
-BUILDTOOLS_ROOT=$(git rev-parse --show-toplevel)/buildtools/$env
-if [ ! -d $BUILDTOOLS_ROOT ]; then
- mkdir -p $BUILDTOOLS_ROOT
-fi
-
-pushd $BUILDTOOLS_ROOT
-set -x # echo on
-sha1=$(tail -c+1 $BUILDTOOLS_ROOT/clang-format.sha1)
-curl -Lo clang-format "$GOOGLE_STORAGE_URL/chromium-clang-format/$sha1"
-chmod +x clang-format
-curl -L "https://github.com/ninja-build/ninja/releases/download/${ninja_version}/${ninja_zipfile}" | funzip > ninja
-chmod +x ninja
-set +x # echo off
-popd
-
diff --git a/chromium/third_party/openscreen/src/util/BUILD.gn b/chromium/third_party/openscreen/src/util/BUILD.gn
index 8f1958e2a4f..12cc00a03bf 100644
--- a/chromium/third_party/openscreen/src/util/BUILD.gn
+++ b/chromium/third_party/openscreen/src/util/BUILD.gn
@@ -12,6 +12,8 @@ source_set("util") {
"big_endian.h",
"crypto/certificate_utils.cc",
"crypto/certificate_utils.h",
+ "crypto/digest_sign.cc",
+ "crypto/digest_sign.h",
"crypto/openssl_util.cc",
"crypto/openssl_util.h",
"crypto/rsa_private_key.cc",
@@ -20,30 +22,38 @@ source_set("util") {
"crypto/secure_hash.h",
"crypto/sha2.cc",
"crypto/sha2.h",
+ "hashing.h",
"integer_division.h",
- "json/json_reader.cc",
- "json/json_reader.h",
- "json/json_writer.cc",
- "json/json_writer.h",
+ "json/json_serialization.cc",
+ "json/json_serialization.h",
+ "json/json_value.cc",
+ "json/json_value.h",
"logging.h",
"operation_loop.cc",
"operation_loop.h",
"saturate_cast.h",
"serial_delete_ptr.h",
+ "simple_fraction.cc",
+ "simple_fraction.h",
"std_util.h",
+ "stringprintf.cc",
"stringprintf.h",
"trace_logging.h",
"trace_logging/macro_support.h",
"trace_logging/scoped_trace_operations.cc",
"trace_logging/scoped_trace_operations.h",
+ "weak_ptr.h",
"yet_another_bit_vector.cc",
"yet_another_bit_vector.h",
]
+ public_deps = [
+ "../third_party/jsoncpp",
+ ]
+
deps = [
"../third_party/abseil",
"../third_party/boringssl",
- "../third_party/jsoncpp",
]
public_configs = [ "../build:openscreen_include_dirs" ]
@@ -60,13 +70,16 @@ source_set("unittests") {
"crypto/secure_hash_unittest.cc",
"crypto/sha2_unittest.cc",
"integer_division_unittest.cc",
- "json/json_reader_unittest.cc",
- "json/json_writer_unittest.cc",
+ "json/json_serialization_unittest.cc",
+ "json/json_value_unittest.cc",
"operation_loop_unittest.cc",
"saturate_cast_unittest.cc",
"serial_delete_ptr_unittest.cc",
+ "simple_fraction_unittest.cc",
+ "stringprintf_unittest.cc",
"trace_logging/scoped_trace_operations_unittest.cc",
"trace_logging_unittest.cc",
+ "weak_ptr_unittest.cc",
"yet_another_bit_vector_unittest.cc",
]
diff --git a/chromium/third_party/openscreen/src/util/alarm.cc b/chromium/third_party/openscreen/src/util/alarm.cc
index c72815a19c5..a8427285a9b 100644
--- a/chromium/third_party/openscreen/src/util/alarm.cc
+++ b/chromium/third_party/openscreen/src/util/alarm.cc
@@ -58,8 +58,7 @@ class Alarm::CancelableFunctor {
Alarm* alarm_;
};
-Alarm::Alarm(platform::ClockNowFunctionPtr now_function,
- platform::TaskRunner* task_runner)
+Alarm::Alarm(ClockNowFunctionPtr now_function, TaskRunner* task_runner)
: now_function_(now_function), task_runner_(task_runner) {
OSP_DCHECK(now_function_);
OSP_DCHECK(task_runner_);
@@ -73,29 +72,30 @@ Alarm::~Alarm() {
}
void Alarm::Cancel() {
- scheduled_task_ = platform::TaskRunner::Task();
+ scheduled_task_ = TaskRunner::Task();
}
-void Alarm::ScheduleWithTask(platform::TaskRunner::Task task,
- platform::Clock::time_point alarm_time) {
+void Alarm::ScheduleWithTask(TaskRunner::Task task,
+ Clock::time_point desired_alarm_time) {
OSP_DCHECK(task.valid());
scheduled_task_ = std::move(task);
- alarm_time_ = alarm_time;
+
+ const Clock::time_point now = now_function_();
+ alarm_time_ = std::max(now, desired_alarm_time);
// Ensure that a later firing will occur, and not too late.
if (queued_fire_) {
- if (next_fire_time_ <= alarm_time) {
+ if (next_fire_time_ <= alarm_time_) {
return;
}
queued_fire_->Cancel();
OSP_DCHECK(!queued_fire_);
}
- InvokeLater(now_function_(), alarm_time);
+ InvokeLater(now, alarm_time_);
}
-void Alarm::InvokeLater(platform::Clock::time_point now,
- platform::Clock::time_point fire_time) {
+void Alarm::InvokeLater(Clock::time_point now, Clock::time_point fire_time) {
OSP_DCHECK(!queued_fire_);
next_fire_time_ = fire_time;
// Note: Instantiating the CancelableFunctor below sets |this->queued_fire_|.
@@ -109,7 +109,7 @@ void Alarm::TryInvoke() {
// If this is an early firing, re-schedule for later. This happens if
// Schedule() was called again before this firing had occurred.
- const platform::Clock::time_point now = now_function_();
+ const Clock::time_point now = now_function_();
if (now < alarm_time_) {
InvokeLater(now, alarm_time_);
return;
@@ -119,8 +119,11 @@ void Alarm::TryInvoke() {
// itself: a) calls any Alarm methods re-entrantly, or b) causes the
// destruction of this Alarm instance.
// WARNING: |this| is not valid after here!
- platform::TaskRunner::Task task = std::move(scheduled_task_);
+ TaskRunner::Task task = std::move(scheduled_task_);
task();
}
+// static
+constexpr Clock::time_point Alarm::kImmediately;
+
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/util/alarm.h b/chromium/third_party/openscreen/src/util/alarm.h
index dbf75714a56..551c0b76217 100644
--- a/chromium/third_party/openscreen/src/util/alarm.h
+++ b/chromium/third_party/openscreen/src/util/alarm.h
@@ -31,8 +31,7 @@ namespace openscreen {
// running the client's Task later; or c) runs the client's Task.
class Alarm {
public:
- Alarm(platform::ClockNowFunctionPtr now_function,
- platform::TaskRunner* task_runner);
+ Alarm(ClockNowFunctionPtr now_function, TaskRunner* task_runner);
~Alarm();
// The design requires that Alarm instances not be copied or moved.
@@ -44,20 +43,18 @@ class Alarm {
// Schedule the |functor| to be invoked at |alarm_time|. If this Alarm was
// already scheduled, the prior scheduling is canceled. The Functor can be any
// callable target (e.g., function, lambda-expression, std::bind result,
- // etc.).
+ // etc.). If |alarm_time| is on or before "now," such as kImmediately, it is
+ // scheduled to run as soon as possible.
template <typename Functor>
- inline void Schedule(Functor functor,
- platform::Clock::time_point alarm_time) {
- ScheduleWithTask(platform::TaskRunner::Task(std::move(functor)),
- alarm_time);
+ inline void Schedule(Functor functor, Clock::time_point alarm_time) {
+ ScheduleWithTask(TaskRunner::Task(std::move(functor)), alarm_time);
}
// Same as Schedule(), but invoke the functor at the given |delay| after right
// now.
template <typename Functor>
- inline void ScheduleFromNow(Functor functor,
- platform::Clock::duration delay) {
- ScheduleWithTask(platform::TaskRunner::Task(std::move(functor)),
+ inline void ScheduleFromNow(Functor functor, Clock::duration delay) {
+ ScheduleWithTask(TaskRunner::Task(std::move(functor)),
now_function_() + delay);
}
@@ -67,8 +64,10 @@ class Alarm {
// See comments for Schedule(). Generally, callers will want to call
// Schedule() instead of this, for more-convenient caller-side syntax, unless
// they already have a Task to pass-in.
- void ScheduleWithTask(platform::TaskRunner::Task task,
- platform::Clock::time_point alarm_time);
+ void ScheduleWithTask(TaskRunner::Task task, Clock::time_point alarm_time);
+
+ // A special time_point value representing "as soon as possible."
+ static constexpr Clock::time_point kImmediately = Clock::time_point::min();
private:
// A move-only functor that holds a raw pointer back to |this| and can be
@@ -77,20 +76,19 @@ class Alarm {
class CancelableFunctor;
// Posts a delayed call to TryInvoke() to the TaskRunner.
- void InvokeLater(platform::Clock::time_point now,
- platform::Clock::time_point fire_time);
+ void InvokeLater(Clock::time_point now, Clock::time_point fire_time);
// Examines whether to invoke the client's Task now; or try again later; or
// just do nothing. See class-level design comments.
void TryInvoke();
- const platform::ClockNowFunctionPtr now_function_;
- platform::TaskRunner* const task_runner_;
+ const ClockNowFunctionPtr now_function_;
+ TaskRunner* const task_runner_;
// This is the task the client wants to have run at a specific point-in-time.
// This is NOT the task that Alarm provides to the TaskRunner.
- platform::TaskRunner::Task scheduled_task_;
- platform::Clock::time_point alarm_time_{};
+ TaskRunner::Task scheduled_task_;
+ Clock::time_point alarm_time_{};
// When non-null, there is a task in the TaskRunner's queue that will call
// TryInvoke() some time in the future. This member is exclusively maintained
@@ -99,7 +97,7 @@ class Alarm {
// When the CancelableFunctor is scheduled to run. It may possibly execute
// later than this, if the TaskRunner is falling behind.
- platform::Clock::time_point next_fire_time_{};
+ Clock::time_point next_fire_time_{};
};
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/util/alarm_unittest.cc b/chromium/third_party/openscreen/src/util/alarm_unittest.cc
index 71f570a058a..5fad74bf7f4 100644
--- a/chromium/third_party/openscreen/src/util/alarm_unittest.cc
+++ b/chromium/third_party/openscreen/src/util/alarm_unittest.cc
@@ -15,30 +15,28 @@ namespace {
class AlarmTest : public testing::Test {
public:
- platform::FakeClock* clock() { return &clock_; }
- platform::TaskRunner* task_runner() { return &task_runner_; }
+ FakeClock* clock() { return &clock_; }
+ FakeTaskRunner* task_runner() { return &task_runner_; }
Alarm* alarm() { return &alarm_; }
private:
- platform::FakeClock clock_{platform::Clock::now()};
- platform::FakeTaskRunner task_runner_{&clock_};
- Alarm alarm_{&platform::FakeClock::now, &task_runner_};
+ FakeClock clock_{Clock::now()};
+ FakeTaskRunner task_runner_{&clock_};
+ Alarm alarm_{&FakeClock::now, &task_runner_};
};
TEST_F(AlarmTest, RunsTaskAsClockAdvances) {
- constexpr platform::Clock::duration kDelay = std::chrono::milliseconds(20);
+ constexpr Clock::duration kDelay = std::chrono::milliseconds(20);
- const platform::Clock::time_point alarm_time =
- platform::FakeClock::now() + kDelay;
- platform::Clock::time_point actual_run_time{};
- alarm()->Schedule([&]() { actual_run_time = platform::FakeClock::now(); },
- alarm_time);
+ const Clock::time_point alarm_time = FakeClock::now() + kDelay;
+ Clock::time_point actual_run_time{};
+ alarm()->Schedule([&]() { actual_run_time = FakeClock::now(); }, alarm_time);
// Confirm the lambda did not run immediately.
- ASSERT_EQ(platform::Clock::time_point{}, actual_run_time);
+ ASSERT_EQ(Clock::time_point{}, actual_run_time);
// Confirm the lambda does not run until the necessary delay has elapsed.
clock()->Advance(kDelay / 2);
- ASSERT_EQ(platform::Clock::time_point{}, actual_run_time);
+ ASSERT_EQ(Clock::time_point{}, actual_run_time);
// Confirm the lambda is called when the necessary delay has elapsed.
clock()->Advance(kDelay / 2);
@@ -49,17 +47,34 @@ TEST_F(AlarmTest, RunsTaskAsClockAdvances) {
ASSERT_EQ(alarm_time, actual_run_time);
}
+TEST_F(AlarmTest, RunsTaskImmediately) {
+ const Clock::time_point expected_run_time = FakeClock::now();
+ Clock::time_point actual_run_time{};
+ alarm()->Schedule([&]() { actual_run_time = FakeClock::now(); },
+ Alarm::kImmediately);
+ // Confirm the lambda did not run yet, since it should run asynchronously, in
+ // a separate TaskRunner task.
+ ASSERT_EQ(Clock::time_point{}, actual_run_time);
+
+ // Confirm the lambda runs without the clock having to tick forward.
+ task_runner()->RunTasksUntilIdle();
+ ASSERT_EQ(expected_run_time, actual_run_time);
+
+ // Confirm the lambda is only run once.
+ clock()->Advance(std::chrono::seconds(2));
+ ASSERT_EQ(expected_run_time, actual_run_time);
+}
+
TEST_F(AlarmTest, CancelsTaskWhenGoingOutOfScope) {
- constexpr platform::Clock::duration kDelay = std::chrono::milliseconds(20);
- constexpr platform::Clock::time_point kNever{};
+ constexpr Clock::duration kDelay = std::chrono::milliseconds(20);
+ constexpr Clock::time_point kNever{};
- platform::Clock::time_point actual_run_time{};
+ Clock::time_point actual_run_time{};
{
- Alarm scoped_alarm(&platform::FakeClock::now, task_runner());
- const platform::Clock::time_point alarm_time =
- platform::FakeClock::now() + kDelay;
- scoped_alarm.Schedule(
- [&]() { actual_run_time = platform::FakeClock::now(); }, alarm_time);
+ Alarm scoped_alarm(&FakeClock::now, task_runner());
+ const Clock::time_point alarm_time = FakeClock::now() + kDelay;
+ scoped_alarm.Schedule([&]() { actual_run_time = FakeClock::now(); },
+ alarm_time);
// |scoped_alarm| is destroyed.
}
@@ -70,31 +85,27 @@ TEST_F(AlarmTest, CancelsTaskWhenGoingOutOfScope) {
}
TEST_F(AlarmTest, Cancels) {
- constexpr platform::Clock::duration kDelay = std::chrono::milliseconds(20);
+ constexpr Clock::duration kDelay = std::chrono::milliseconds(20);
- const platform::Clock::time_point alarm_time =
- platform::FakeClock::now() + kDelay;
- platform::Clock::time_point actual_run_time{};
- alarm()->Schedule([&]() { actual_run_time = platform::FakeClock::now(); },
- alarm_time);
+ const Clock::time_point alarm_time = FakeClock::now() + kDelay;
+ Clock::time_point actual_run_time{};
+ alarm()->Schedule([&]() { actual_run_time = FakeClock::now(); }, alarm_time);
// Advance the clock for half the delay, and confirm the lambda has not run
// yet.
clock()->Advance(kDelay / 2);
- ASSERT_EQ(platform::Clock::time_point{}, actual_run_time);
+ ASSERT_EQ(Clock::time_point{}, actual_run_time);
// Cancel and then advance the clock well past the delay, and confirm the
// lambda has never run.
alarm()->Cancel();
clock()->Advance(kDelay * 100);
- ASSERT_EQ(platform::Clock::time_point{}, actual_run_time);
+ ASSERT_EQ(Clock::time_point{}, actual_run_time);
}
TEST_F(AlarmTest, CancelsAndRearms) {
- constexpr platform::Clock::duration kShorterDelay =
- std::chrono::milliseconds(10);
- constexpr platform::Clock::duration kLongerDelay =
- std::chrono::milliseconds(100);
+ constexpr Clock::duration kShorterDelay = std::chrono::milliseconds(10);
+ constexpr Clock::duration kLongerDelay = std::chrono::milliseconds(100);
// Run the test twice: Once when scheduling first with a long delay, then a
// shorter delay; and once when scheduling first with a short delay, then a
@@ -105,7 +116,7 @@ TEST_F(AlarmTest, CancelsAndRearms) {
const auto delay2 = do_longer_then_shorter ? kShorterDelay : kLongerDelay;
int count1 = 0;
- alarm()->Schedule([&]() { ++count1; }, platform::FakeClock::now() + delay1);
+ alarm()->Schedule([&]() { ++count1; }, FakeClock::now() + delay1);
// Advance the clock for half of |delay1|, and confirm the lambda that
// increments the variable does not run.
@@ -116,7 +127,7 @@ TEST_F(AlarmTest, CancelsAndRearms) {
// Schedule a different lambda, that increments a different variable, to run
// after |delay2|.
int count2 = 0;
- alarm()->Schedule([&]() { ++count2; }, platform::FakeClock::now() + delay2);
+ alarm()->Schedule([&]() { ++count2; }, FakeClock::now() + delay2);
// Confirm the second scheduling will fire at the right moment.
clock()->Advance(delay2 / 2);
diff --git a/chromium/third_party/openscreen/src/util/big_endian.h b/chromium/third_party/openscreen/src/util/big_endian.h
index 6c94ca5e6a6..b2067d7537e 100644
--- a/chromium/third_party/openscreen/src/util/big_endian.h
+++ b/chromium/third_party/openscreen/src/util/big_endian.h
@@ -26,43 +26,72 @@ inline bool IsBigEndianArchitecture() {
return !!bytes[0];
}
-// Returns the bytes of |x| in reverse order. This is only defined for 16-, 32-,
-// and 64-bit unsigned integers.
-template <typename Integer>
-Integer ByteSwap(Integer x);
+namespace internal {
+
+template <int size>
+struct MakeSizedUnsignedInteger;
+
+template <>
+struct MakeSizedUnsignedInteger<1> {
+ using type = uint8_t;
+};
+
+template <>
+struct MakeSizedUnsignedInteger<2> {
+ using type = uint16_t;
+};
+
+template <>
+struct MakeSizedUnsignedInteger<4> {
+ using type = uint32_t;
+};
+
+template <>
+struct MakeSizedUnsignedInteger<8> {
+ using type = uint64_t;
+};
+
+template <int size>
+inline typename MakeSizedUnsignedInteger<size>::type ByteSwap(
+ typename MakeSizedUnsignedInteger<size>::type x) {
+ static_assert(size <= 8,
+ "ByteSwap() specialization missing in " __FILE__
+ ". "
+ "Are you trying to use an integer larger than 64 bits?");
+}
template <>
-inline uint8_t ByteSwap(uint8_t x) {
+inline uint8_t ByteSwap<1>(uint8_t x) {
return x;
}
#if defined(__clang__) || defined(__GNUC__)
template <>
-inline uint64_t ByteSwap(uint64_t x) {
+inline uint64_t ByteSwap<8>(uint64_t x) {
return __builtin_bswap64(x);
}
template <>
-inline uint32_t ByteSwap(uint32_t x) {
+inline uint32_t ByteSwap<4>(uint32_t x) {
return __builtin_bswap32(x);
}
template <>
-inline uint16_t ByteSwap(uint16_t x) {
+inline uint16_t ByteSwap<2>(uint16_t x) {
return __builtin_bswap16(x);
}
#elif defined(_MSC_VER)
template <>
-inline uint64_t ByteSwap(uint64_t x) {
+inline uint64_t ByteSwap<8>(uint64_t x) {
return _byteswap_uint64(x);
}
template <>
-inline uint32_t ByteSwap(uint32_t x) {
+inline uint32_t ByteSwap<4>(uint32_t x) {
return _byteswap_ulong(x);
}
template <>
-inline uint16_t ByteSwap(uint16_t x) {
+inline uint16_t ByteSwap<2>(uint16_t x) {
return _byteswap_ushort(x);
}
@@ -71,20 +100,30 @@ inline uint16_t ByteSwap(uint16_t x) {
#include <byteswap.h>
template <>
-inline uint64_t ByteSwap(uint64_t x) {
+inline uint64_t ByteSwap<8>(uint64_t x) {
return bswap_64(x);
}
template <>
-inline uint32_t ByteSwap(uint32_t x) {
+inline uint32_t ByteSwap<4>(uint32_t x) {
return bswap_32(x);
}
template <>
-inline uint16_t ByteSwap(uint16_t x) {
+inline uint16_t ByteSwap<2>(uint16_t x) {
return bswap_16(x);
}
#endif
+} // namespace internal
+
+// Returns the bytes of |x| in reverse order. This is only defined for 16-, 32-,
+// and 64-bit unsigned integers.
+template <typename Integer>
+inline std::enable_if_t<std::is_unsigned<Integer>::value, Integer> ByteSwap(
+ Integer x) {
+ return internal::ByteSwap<sizeof(Integer)>(x);
+}
+
// Read a POD integer from |src| in big-endian byte order, returning the integer
// in native byte order.
template <typename Integer>
diff --git a/chromium/third_party/openscreen/src/util/crypto/certificate_utils.cc b/chromium/third_party/openscreen/src/util/crypto/certificate_utils.cc
index 1d6873f4485..844e60554b8 100644
--- a/chromium/third_party/openscreen/src/util/crypto/certificate_utils.cc
+++ b/chromium/third_party/openscreen/src/util/crypto/certificate_utils.cc
@@ -6,22 +6,31 @@
#include <openssl/asn1.h>
#include <openssl/bio.h>
+#include <openssl/bn.h>
#include <openssl/crypto.h>
#include <openssl/evp.h>
#include <openssl/rsa.h>
#include <openssl/ssl.h>
+#include <openssl/x509v3.h>
#include <time.h>
-#include <atomic>
#include <string>
#include "util/crypto/openssl_util.h"
#include "util/crypto/sha2.h"
+#include "util/logging.h"
namespace openscreen {
namespace {
+// These values are bit positions from RFC 5280 4.2.1.3 and will be passed to
+// ASN1_BIT_STRING_set_bit.
+enum KeyUsageBits {
+ kDigitalSignature = 0,
+ kKeyCertSign = 5,
+};
+
// Returns whether or not the certificate field successfully was added.
bool AddCertificateField(X509_NAME* certificate_name,
absl::string_view field,
@@ -41,12 +50,22 @@ bssl::UniquePtr<X509> CreateCertificateInternal(
absl::string_view name,
std::chrono::seconds certificate_duration,
EVP_PKEY key_pair,
- std::chrono::seconds time_since_unix_epoch) {
+ std::chrono::seconds time_since_unix_epoch,
+ bool make_ca,
+ X509* issuer,
+ EVP_PKEY* issuer_key) {
+ OSP_DCHECK((!!issuer) == (!!issuer_key));
bssl::UniquePtr<X509> certificate(X509_new());
+ if (!issuer) {
+ issuer = certificate.get();
+ }
+ if (!issuer_key) {
+ issuer_key = &key_pair;
+ }
// Serial numbers must be unique for this session. As a pretend CA, we should
// not issue certificates with the same serial number in the same session.
- static std::atomic_int serial_number(1);
+ static int serial_number(1);
if (ASN1_INTEGER_set(X509_get_serialNumber(certificate.get()),
serial_number++) != 1) {
return nullptr;
@@ -66,12 +85,36 @@ bssl::UniquePtr<X509> CreateCertificateInternal(
return nullptr;
}
- if ((X509_set_issuer_name(certificate.get(), certificate_name) != 1) ||
+ bssl::UniquePtr<ASN1_BIT_STRING> x(ASN1_BIT_STRING_new());
+ ASN1_BIT_STRING_set_bit(x.get(), KeyUsageBits::kDigitalSignature, 1);
+ if (make_ca) {
+ ASN1_BIT_STRING_set_bit(x.get(), KeyUsageBits::kKeyCertSign, 1);
+ }
+ if (X509_add1_ext_i2d(certificate.get(), NID_key_usage, x.get(), 0, 0) != 1) {
+ return nullptr;
+ }
+ if (make_ca) {
+ X509V3_CTX ctx;
+ X509V3_set_ctx_nodb(&ctx);
+ X509V3_set_ctx(&ctx, issuer, certificate.get(), nullptr, nullptr, 0);
+ bssl::UniquePtr<X509_EXTENSION> ex(
+ X509V3_EXT_nconf_nid(nullptr, &ctx, NID_basic_constraints,
+ const_cast<char*>("critical,CA:TRUE")));
+ if (!ex) {
+ return nullptr;
+ }
+ void* thing = X509V3_EXT_d2i(ex.get());
+ X509_add1_ext_i2d(certificate.get(), NID_basic_constraints, thing, 1, 0);
+ X509V3_EXT_free(NID_basic_constraints, thing);
+ }
+
+ X509_NAME* issuer_name = X509_get_subject_name(issuer);
+ if ((X509_set_issuer_name(certificate.get(), issuer_name) != 1) ||
(X509_set_pubkey(certificate.get(), &key_pair) != 1) ||
// Unlike all of the other BoringSSL methods here, X509_sign returns
// the size of the signature in bytes.
- (X509_sign(certificate.get(), &key_pair, EVP_sha256()) <= 0) ||
- (X509_verify(certificate.get(), &key_pair) != 1)) {
+ (X509_sign(certificate.get(), issuer_key, EVP_sha256()) <= 0) ||
+ (X509_verify(certificate.get(), issuer_key) != 1)) {
return nullptr;
}
@@ -80,20 +123,57 @@ bssl::UniquePtr<X509> CreateCertificateInternal(
} // namespace
-ErrorOr<bssl::UniquePtr<X509>> CreateCertificate(
+bssl::UniquePtr<EVP_PKEY> GenerateRsaKeyPair(int key_bits) {
+ bssl::UniquePtr<BIGNUM> prime(BN_new());
+ if (BN_set_word(prime.get(), RSA_F4) == 0) {
+ return nullptr;
+ }
+
+ bssl::UniquePtr<RSA> rsa(RSA_new());
+ if (RSA_generate_key_ex(rsa.get(), key_bits, prime.get(), nullptr) == 0) {
+ return nullptr;
+ }
+
+ bssl::UniquePtr<EVP_PKEY> pkey(EVP_PKEY_new());
+ if (EVP_PKEY_set1_RSA(pkey.get(), rsa.get()) == 0) {
+ return nullptr;
+ }
+
+ return pkey;
+}
+
+ErrorOr<bssl::UniquePtr<X509>> CreateSelfSignedX509Certificate(
absl::string_view name,
std::chrono::seconds duration,
const EVP_PKEY& key_pair,
std::chrono::seconds time_since_unix_epoch) {
bssl::UniquePtr<X509> certificate = CreateCertificateInternal(
- name, duration, key_pair, time_since_unix_epoch);
+ name, duration, key_pair, time_since_unix_epoch, false, nullptr, nullptr);
if (!certificate) {
return Error::Code::kCertificateCreationError;
}
return certificate;
}
-ErrorOr<std::vector<uint8_t>> ExportCertificate(const X509& certificate) {
+ErrorOr<bssl::UniquePtr<X509>> CreateSelfSignedX509CertificateForTest(
+ absl::string_view name,
+ std::chrono::seconds duration,
+ const EVP_PKEY& key_pair,
+ std::chrono::seconds time_since_unix_epoch,
+ bool make_ca,
+ X509* issuer,
+ EVP_PKEY* issuer_key) {
+ bssl::UniquePtr<X509> certificate =
+ CreateCertificateInternal(name, duration, key_pair, time_since_unix_epoch,
+ make_ca, issuer, issuer_key);
+ if (!certificate) {
+ return Error::Code::kCertificateCreationError;
+ }
+ return certificate;
+}
+
+ErrorOr<std::vector<uint8_t>> ExportX509CertificateToDer(
+ const X509& certificate) {
unsigned char* buffer = nullptr;
// Casting-away the const because the legacy i2d_X509() function is not
// const-correct.
@@ -121,4 +201,44 @@ ErrorOr<bssl::UniquePtr<X509>> ImportCertificate(const uint8_t* der_x509_cert,
return certificate;
}
+ErrorOr<bssl::UniquePtr<EVP_PKEY>> ImportRSAPrivateKey(
+ const uint8_t* der_rsa_private_key,
+ int key_length) {
+ if (!der_rsa_private_key || key_length == 0) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ RSA* rsa = RSA_private_key_from_bytes(der_rsa_private_key, key_length);
+ if (!rsa) {
+ return Error::Code::kRSAKeyParseError;
+ }
+ bssl::UniquePtr<EVP_PKEY> pkey(EVP_PKEY_new());
+ EVP_PKEY_assign_RSA(pkey.get(), rsa);
+ return pkey;
+}
+
+std::string GetSpkiTlv(X509* cert) {
+ int len = i2d_X509_PUBKEY(cert->cert_info->key, nullptr);
+ if (len <= 0) {
+ return {};
+ }
+ std::string x(len, 0);
+ uint8_t* data = reinterpret_cast<uint8_t*>(&x[0]);
+ if (!i2d_X509_PUBKEY(cert->cert_info->key, &data)) {
+ return {};
+ }
+ return x;
+}
+
+ErrorOr<uint64_t> ParseDerUint64(ASN1_INTEGER* asn1int) {
+ if (asn1int->length > 8 || asn1int->length == 0) {
+ return Error::Code::kParameterInvalid;
+ }
+ uint64_t result = 0;
+ for (int i = 0; i < asn1int->length; ++i) {
+ result = (result << 8) | asn1int->data[i];
+ }
+ return result;
+}
+
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/util/crypto/certificate_utils.h b/chromium/third_party/openscreen/src/util/crypto/certificate_utils.h
index 4bfe6e37a61..e60c28c3d26 100644
--- a/chromium/third_party/openscreen/src/util/crypto/certificate_utils.h
+++ b/chromium/third_party/openscreen/src/util/crypto/certificate_utils.h
@@ -5,10 +5,12 @@
#ifndef UTIL_CRYPTO_CERTIFICATE_UTILS_H_
#define UTIL_CRYPTO_CERTIFICATE_UTILS_H_
+#include <openssl/evp.h>
#include <openssl/x509.h>
#include <stdint.h>
#include <chrono>
+#include <string>
#include <vector>
#include "absl/strings/string_view.h"
@@ -18,23 +20,50 @@
namespace openscreen {
+// Generates a new RSA key pair with bit width |key_bits|.
+bssl::UniquePtr<EVP_PKEY> GenerateRsaKeyPair(int key_bits = 2048);
+
// Creates a new self-signed X509 certificate having the given |name| and
-// |duration| until expiration, and based on the given |key_pair|.
+// |duration| until expiration, and based on the given |key_pair|, which is
+// expected to contain a valid private key.
// |time_since_unix_epoch| is the current time.
-ErrorOr<bssl::UniquePtr<X509>> CreateCertificate(
+ErrorOr<bssl::UniquePtr<X509>> CreateSelfSignedX509Certificate(
+ absl::string_view name,
+ std::chrono::seconds duration,
+ const EVP_PKEY& key_pair,
+ std::chrono::seconds time_since_unix_epoch = GetWallTimeSinceUnixEpoch());
+
+// Creates a new X509 certificate having the given |name| and |duration| until
+// expiration, and based on the given |key_pair|. If |issuer| and |issuer_key|
+// are provided, they are used to set the issuer information, otherwise it will
+// be self-signed. |make_ca| determines whether additional extensions are added
+// to make it a valid certificate authority cert.
+ErrorOr<bssl::UniquePtr<X509>> CreateSelfSignedX509CertificateForTest(
absl::string_view name,
std::chrono::seconds duration,
const EVP_PKEY& key_pair,
- std::chrono::seconds time_since_unix_epoch =
- platform::GetWallTimeSinceUnixEpoch());
+ std::chrono::seconds time_since_unix_epoch = GetWallTimeSinceUnixEpoch(),
+ bool make_ca = false,
+ X509* issuer = nullptr,
+ EVP_PKEY* issuer_key = nullptr);
// Exports the given X509 certificate as its DER-encoded binary form.
-ErrorOr<std::vector<uint8_t>> ExportCertificate(const X509& certificate);
+ErrorOr<std::vector<uint8_t>> ExportX509CertificateToDer(
+ const X509& certificate);
// Parses a DER-encoded X509 certificate from its binary form.
ErrorOr<bssl::UniquePtr<X509>> ImportCertificate(const uint8_t* der_x509_cert,
int der_x509_cert_length);
+// Parses a DER-encoded RSAPrivateKey (RFC 3447).
+ErrorOr<bssl::UniquePtr<EVP_PKEY>> ImportRSAPrivateKey(
+ const uint8_t* der_rsa_private_key,
+ int key_length);
+
+std::string GetSpkiTlv(X509* cert);
+
+ErrorOr<uint64_t> ParseDerUint64(ASN1_INTEGER* asn1int);
+
} // namespace openscreen
#endif // UTIL_CRYPTO_CERTIFICATE_UTILS_H_
diff --git a/chromium/third_party/openscreen/src/util/crypto/certificate_utils_unittest.cc b/chromium/third_party/openscreen/src/util/crypto/certificate_utils_unittest.cc
index 7475756bce4..91f6f9cbd24 100644
--- a/chromium/third_party/openscreen/src/util/crypto/certificate_utils_unittest.cc
+++ b/chromium/third_party/openscreen/src/util/crypto/certificate_utils_unittest.cc
@@ -22,24 +22,12 @@ namespace {
constexpr char kName[] = "test.com";
constexpr auto kDuration = std::chrono::seconds(31556952);
-bssl::UniquePtr<EVP_PKEY> GenerateRsaKeypair() {
- bssl::UniquePtr<BIGNUM> prime(BN_new());
- EXPECT_NE(0, BN_set_word(prime.get(), RSA_F4));
-
- bssl::UniquePtr<RSA> rsa(RSA_new());
- EXPECT_NE(0, RSA_generate_key_ex(rsa.get(), 2048, prime.get(), nullptr));
-
- bssl::UniquePtr<EVP_PKEY> pkey(EVP_PKEY_new());
- EXPECT_NE(0, EVP_PKEY_set1_RSA(pkey.get(), rsa.get()));
-
- return pkey;
-}
-
TEST(CertificateUtilTest, CreatesValidCertificate) {
- bssl::UniquePtr<EVP_PKEY> pkey = GenerateRsaKeypair();
+ bssl::UniquePtr<EVP_PKEY> pkey = GenerateRsaKeyPair();
+ ASSERT_TRUE(pkey);
ErrorOr<bssl::UniquePtr<X509>> certificate =
- CreateCertificate(kName, kDuration, *pkey);
+ CreateSelfSignedX509Certificate(kName, kDuration, *pkey);
ASSERT_TRUE(certificate.is_value());
// Validate the generated certificate.
@@ -47,13 +35,14 @@ TEST(CertificateUtilTest, CreatesValidCertificate) {
}
TEST(CertificateUtilTest, ExportsAndImportsCertificate) {
- bssl::UniquePtr<EVP_PKEY> pkey = GenerateRsaKeypair();
+ bssl::UniquePtr<EVP_PKEY> pkey = GenerateRsaKeyPair();
+ ASSERT_TRUE(pkey);
ErrorOr<bssl::UniquePtr<X509>> certificate =
- CreateCertificate(kName, kDuration, *pkey);
+ CreateSelfSignedX509Certificate(kName, kDuration, *pkey);
ASSERT_TRUE(certificate.is_value());
ErrorOr<std::vector<uint8_t>> exported =
- ExportCertificate(*certificate.value());
+ ExportX509CertificateToDer(*certificate.value());
ASSERT_TRUE(exported.is_value()) << exported.error();
EXPECT_FALSE(exported.value().empty());
diff --git a/chromium/third_party/openscreen/src/util/crypto/digest_sign.cc b/chromium/third_party/openscreen/src/util/crypto/digest_sign.cc
new file mode 100644
index 00000000000..fa1ab24dfaa
--- /dev/null
+++ b/chromium/third_party/openscreen/src/util/crypto/digest_sign.cc
@@ -0,0 +1,31 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "util/crypto/digest_sign.h"
+
+namespace openscreen {
+
+ErrorOr<std::string> SignData(const EVP_MD* digest,
+ EVP_PKEY* private_key,
+ absl::Span<const uint8_t> data) {
+ bssl::ScopedEVP_MD_CTX ctx;
+ if (!EVP_DigestSignInit(ctx.get(), nullptr, digest, nullptr, private_key)) {
+ return Error::Code::kEVPInitializationError;
+ }
+ size_t signature_length = 0;
+ if ((EVP_DigestSign(ctx.get(), nullptr, &signature_length, data.data(),
+ data.size()) != 1) ||
+ signature_length == 0) {
+ return Error::Code::kEVPInitializationError;
+ }
+
+ std::string signature(signature_length, 0);
+ if (EVP_DigestSign(ctx.get(), reinterpret_cast<uint8_t*>(&signature[0]),
+ &signature_length, data.data(), data.size()) != 1) {
+ return Error::Code::kCreateSignatureFailed;
+ }
+ return signature;
+}
+
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/util/crypto/digest_sign.h b/chromium/third_party/openscreen/src/util/crypto/digest_sign.h
new file mode 100644
index 00000000000..cd722ef7318
--- /dev/null
+++ b/chromium/third_party/openscreen/src/util/crypto/digest_sign.h
@@ -0,0 +1,23 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef UTIL_CRYPTO_DIGEST_SIGN_H_
+#define UTIL_CRYPTO_DIGEST_SIGN_H_
+
+#include <openssl/evp.h>
+
+#include <string>
+
+#include "absl/types/span.h"
+#include "platform/base/error.h"
+
+namespace openscreen {
+
+ErrorOr<std::string> SignData(const EVP_MD* digest,
+ EVP_PKEY* private_key,
+ absl::Span<const uint8_t> data);
+
+} // namespace openscreen
+
+#endif // UTIL_CRYPTO_DIGEST_SIGN_H_
diff --git a/chromium/third_party/openscreen/src/util/hashing.h b/chromium/third_party/openscreen/src/util/hashing.h
new file mode 100644
index 00000000000..55c7cf4c4d7
--- /dev/null
+++ b/chromium/third_party/openscreen/src/util/hashing.h
@@ -0,0 +1,49 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef UTIL_HASHING_H_
+#define UTIL_HASHING_H_
+
+namespace openscreen {
+
+// Computes the aggregate hash of the provided hashable objects.
+// Seed must initially use a large prime between 2^63 and 2^64 as a starting
+// value, or the result of a previous call to this function.
+template <typename... T>
+uint64_t ComputeAggregateHash(uint64_t seed, const T&... objs) {
+ auto hash_combiner = [](uint64_t seed, uint64_t hash_value) -> uint64_t {
+ static const uint64_t kMultiplier = UINT64_C(0x9ddfea08eb382d69);
+ uint64_t a = (hash_value ^ seed) * kMultiplier;
+ a ^= (a >> 47);
+ uint64_t b = (seed ^ a) * kMultiplier;
+ b ^= (b >> 47);
+ b *= kMultiplier;
+ return b;
+ };
+
+ uint64_t result = seed;
+ std::vector<uint64_t> hashes{std::hash<T>()(objs)...};
+ for (uint64_t hash : hashes) {
+ result = hash_combiner(result, hash);
+ }
+ return result;
+}
+
+template <typename... T>
+uint64_t ComputeAggregateHash(const T&... objs) {
+ // This value is taken from absl::Hash implementation.
+ constexpr uint64_t default_seed = UINT64_C(0xc3a5c85c97cb3127);
+ return ComputeAggregateHash(default_seed, objs...);
+}
+
+struct PairHash {
+ template <typename TFirst, typename TSecond>
+ size_t operator()(const std::pair<TFirst, TSecond>& pair) const {
+ return ComputeAggregateHash(pair.first, pair.second);
+ }
+};
+
+} // namespace openscreen
+
+#endif // UTIL_HASHING_H_
diff --git a/chromium/third_party/openscreen/src/util/json/json_reader.cc b/chromium/third_party/openscreen/src/util/json/json_reader.cc
deleted file mode 100644
index e98f4aa60f5..00000000000
--- a/chromium/third_party/openscreen/src/util/json/json_reader.cc
+++ /dev/null
@@ -1,36 +0,0 @@
-// Copyright 2019 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "util/json/json_reader.h"
-
-#include <memory>
-#include <string>
-
-#include "json/value.h"
-#include "platform/base/error.h"
-#include "util/logging.h"
-
-namespace openscreen {
-
-JsonReader::JsonReader() {
- Json::CharReaderBuilder::strictMode(&builder_.settings_);
-}
-
-ErrorOr<Json::Value> JsonReader::Read(absl::string_view document) {
- if (document.empty()) {
- return ErrorOr<Json::Value>(Error::Code::kJsonParseError, "empty document");
- }
-
- Json::Value root_node;
- std::string error_msg;
- std::unique_ptr<Json::CharReader> reader(builder_.newCharReader());
- const bool succeeded =
- reader->parse(document.begin(), document.end(), &root_node, &error_msg);
- if (!succeeded) {
- return ErrorOr<Json::Value>(Error::Code::kJsonParseError, error_msg);
- }
-
- return root_node;
-}
-} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/util/json/json_reader.h b/chromium/third_party/openscreen/src/util/json/json_reader.h
deleted file mode 100644
index cb7cded00c7..00000000000
--- a/chromium/third_party/openscreen/src/util/json/json_reader.h
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2019 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#ifndef UTIL_JSON_JSON_READER_H_
-#define UTIL_JSON_JSON_READER_H_
-
-#include <memory>
-
-#include "absl/strings/string_view.h"
-#include "json/reader.h"
-
-namespace Json {
-class Value;
-}
-
-namespace openscreen {
-template <typename T>
-class ErrorOr;
-
-class JsonReader {
- public:
- JsonReader();
-
- ErrorOr<Json::Value> Read(absl::string_view document);
-
- private:
- Json::CharReaderBuilder builder_;
-};
-
-} // namespace openscreen
-
-#endif // UTIL_JSON_JSON_READER_H_ \ No newline at end of file
diff --git a/chromium/third_party/openscreen/src/util/json/json_writer.cc b/chromium/third_party/openscreen/src/util/json/json_serialization.cc
index 65130f8e58e..42ea34f49c4 100644
--- a/chromium/third_party/openscreen/src/util/json/json_writer.cc
+++ b/chromium/third_party/openscreen/src/util/json/json_serialization.cc
@@ -2,36 +2,56 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#include "util/json/json_writer.h"
+#include "util/json/json_serialization.h"
#include <memory>
#include <sstream>
#include <string>
#include <utility>
-#include "json/value.h"
+#include "json/reader.h"
+#include "json/writer.h"
#include "platform/base/error.h"
#include "util/logging.h"
namespace openscreen {
-JsonWriter::JsonWriter() {
-#ifndef _DEBUG
- // Default is to "pretty print" the output JSON in a human readable
- // format. On non-debug builds, we can remove pretty printing by simply
- // getting rid of all indentation.
- factory_["indentation"] = "";
-#endif
+namespace json {
+
+ErrorOr<Json::Value> Parse(absl::string_view document) {
+ Json::CharReaderBuilder builder;
+ Json::CharReaderBuilder::strictMode(&builder.settings_);
+ if (document.empty()) {
+ return ErrorOr<Json::Value>(Error::Code::kJsonParseError, "empty document");
+ }
+
+ Json::Value root_node;
+ std::string error_msg;
+ std::unique_ptr<Json::CharReader> reader(builder.newCharReader());
+ const bool succeeded =
+ reader->parse(document.begin(), document.end(), &root_node, &error_msg);
+ if (!succeeded) {
+ return ErrorOr<Json::Value>(Error::Code::kJsonParseError, error_msg);
+ }
+
+ return root_node;
}
-ErrorOr<std::string> JsonWriter::Write(const Json::Value& value) {
+ErrorOr<std::string> Stringify(const Json::Value& value) {
if (value.empty()) {
return ErrorOr<std::string>(Error::Code::kJsonWriteError, "Empty value");
}
- std::unique_ptr<Json::StreamWriter> const writer(factory_.newStreamWriter());
- std::stringstream stream;
+ Json::StreamWriterBuilder factory;
+#ifndef _DEBUG
+ // Default is to "pretty print" the output JSON in a human readable
+ // format. On non-debug builds, we can remove pretty printing by simply
+ // getting rid of all indentation.
+ factory["indentation"] = "";
+#endif
+
+ std::unique_ptr<Json::StreamWriter> const writer(factory.newStreamWriter());
+ std::ostringstream stream;
writer->write(value, &stream);
- stream << std::endl;
if (!stream) {
// Note: jsoncpp doesn't give us more information about what actually
@@ -43,4 +63,6 @@ ErrorOr<std::string> JsonWriter::Write(const Json::Value& value) {
return stream.str();
}
+
+} // namespace json
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/util/json/json_serialization.h b/chromium/third_party/openscreen/src/util/json/json_serialization.h
new file mode 100644
index 00000000000..903aeb6bb1e
--- /dev/null
+++ b/chromium/third_party/openscreen/src/util/json/json_serialization.h
@@ -0,0 +1,25 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef UTIL_JSON_JSON_SERIALIZATION_H_
+#define UTIL_JSON_JSON_SERIALIZATION_H_
+
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "json/value.h"
+
+namespace openscreen {
+template <typename T>
+class ErrorOr;
+
+namespace json {
+
+ErrorOr<Json::Value> Parse(absl::string_view value);
+ErrorOr<std::string> Stringify(const Json::Value& value);
+
+} // namespace json
+} // namespace openscreen
+
+#endif // UTIL_JSON_JSON_SERIALIZATION_H_
diff --git a/chromium/third_party/openscreen/src/util/json/json_reader_unittest.cc b/chromium/third_party/openscreen/src/util/json/json_serialization_unittest.cc
index a18cf2010d5..b94fe0e38a0 100644
--- a/chromium/third_party/openscreen/src/util/json/json_reader_unittest.cc
+++ b/chromium/third_party/openscreen/src/util/json/json_serialization_unittest.cc
@@ -2,11 +2,10 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#include "util/json/json_reader.h"
+#include "util/json/json_serialization.h"
#include <string>
-#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "platform/base/error.h"
@@ -18,21 +17,17 @@ void AssertError(ErrorOr<Value> error_or, Error::Code code) {
}
} // namespace
-TEST(JsonReaderTest, MalformedDocumentReturnsParseError) {
- JsonReader reader;
-
+TEST(JsonSerializationTest, MalformedDocumentReturnsParseError) {
const std::array<std::string, 4> kMalformedDocuments{
{"", "{", "{ foo: bar }", R"({"foo": "bar", "foo": baz})"}};
for (auto& document : kMalformedDocuments) {
- AssertError(reader.Read(document), Error::Code::kJsonParseError);
+ AssertError(json::Parse(document), Error::Code::kJsonParseError);
}
}
-TEST(JsonReaderTest, ValidEmptyDocumentParsedCorrectly) {
- JsonReader reader;
-
- const auto actual = reader.Read("{}");
+TEST(JsonSerializationTest, ValidEmptyDocumentParsedCorrectly) {
+ const auto actual = json::Parse("{}");
EXPECT_TRUE(actual.is_value());
EXPECT_EQ(actual.value().getMemberNames().size(), 0u);
@@ -41,13 +36,26 @@ TEST(JsonReaderTest, ValidEmptyDocumentParsedCorrectly) {
// Jsoncpp has its own suite of tests ensure that things are parsed correctly,
// so we only do some rudimentary checks here to make sure we didn't mangle
// the value.
-TEST(JsonReaderTest, ValidDocumentParsedCorrectly) {
- JsonReader reader;
-
- const auto actual = reader.Read(R"({"foo": "bar", "baz": 1337})");
+TEST(JsonSerializationTest, ValidDocumentParsedCorrectly) {
+ const auto actual = json::Parse(R"({"foo": "bar", "baz": 1337})");
EXPECT_TRUE(actual.is_value());
EXPECT_EQ(actual.value().getMemberNames().size(), 2u);
}
+TEST(JsonSerializationTest, NullValueReturnsError) {
+ const auto null_value = Json::Value();
+ const auto actual = json::Stringify(null_value);
+
+ EXPECT_TRUE(actual.is_error());
+ EXPECT_EQ(actual.error().code(), Error::Code::kJsonWriteError);
+}
+
+TEST(JsonSerializationTest, ValidValueReturnsString) {
+ const Json::Int64 value = 31337;
+ const auto actual = json::Stringify(value);
+
+ EXPECT_TRUE(actual.is_value());
+ EXPECT_EQ(actual.value(), "31337");
+}
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/util/json/json_value.cc b/chromium/third_party/openscreen/src/util/json/json_value.cc
new file mode 100644
index 00000000000..cfde5f84cf8
--- /dev/null
+++ b/chromium/third_party/openscreen/src/util/json/json_value.cc
@@ -0,0 +1,43 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "util/json/json_value.h"
+
+namespace openscreen {
+
+absl::optional<int> MaybeGetInt(const Json::Value& message,
+ const char* first,
+ const char* last) {
+ const Json::Value* value = message.find(first, last);
+ absl::optional<int> result;
+ if (value && value->isInt()) {
+ result = value->asInt();
+ }
+ return result;
+}
+
+absl::optional<absl::string_view> MaybeGetString(const Json::Value& message) {
+ if (message.isString()) {
+ const char* begin = nullptr;
+ const char* end = nullptr;
+ message.getString(&begin, &end);
+ if (begin && end >= begin) {
+ return absl::string_view(begin, end - begin);
+ }
+ }
+ return absl::nullopt;
+}
+
+absl::optional<absl::string_view> MaybeGetString(const Json::Value& message,
+ const char* first,
+ const char* last) {
+ const Json::Value* value = message.find(first, last);
+ absl::optional<absl::string_view> result;
+ if (value && value->isString()) {
+ return MaybeGetString(*value);
+ }
+ return result;
+}
+
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/util/json/json_value.h b/chromium/third_party/openscreen/src/util/json/json_value.h
new file mode 100644
index 00000000000..d41ea27a932
--- /dev/null
+++ b/chromium/third_party/openscreen/src/util/json/json_value.h
@@ -0,0 +1,28 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef UTIL_JSON_JSON_VALUE_H_
+#define UTIL_JSON_JSON_VALUE_H_
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "json/value.h"
+
+#define JSON_EXPAND_FIND_CONSTANT_ARGS(s) (s), ((s) + sizeof(s) - 1)
+
+namespace openscreen {
+
+absl::optional<int> MaybeGetInt(const Json::Value& message,
+ const char* first,
+ const char* last);
+
+absl::optional<absl::string_view> MaybeGetString(const Json::Value& message);
+
+absl::optional<absl::string_view> MaybeGetString(const Json::Value& message,
+ const char* first,
+ const char* last);
+
+} // namespace openscreen
+
+#endif // UTIL_JSON_JSON_VALUE_H_
diff --git a/chromium/third_party/openscreen/src/util/json/json_value_unittest.cc b/chromium/third_party/openscreen/src/util/json/json_value_unittest.cc
new file mode 100644
index 00000000000..04b57456279
--- /dev/null
+++ b/chromium/third_party/openscreen/src/util/json/json_value_unittest.cc
@@ -0,0 +1,55 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "util/json/json_value.h"
+
+#include "gtest/gtest.h"
+#include "platform/base/error.h"
+#include "util/json/json_serialization.h"
+
+namespace openscreen {
+
+TEST(JsonValueTest, GetInt) {
+ absl::string_view obj(R"!({"key1": 17, "key2": 32.3, "key3": "asdf"})!");
+ ErrorOr<Json::Value> value_or_error = json::Parse(obj);
+ ASSERT_TRUE(value_or_error);
+ Json::Value& value = value_or_error.value();
+ absl::optional<int> result1 =
+ MaybeGetInt(value, JSON_EXPAND_FIND_CONSTANT_ARGS("key1"));
+ absl::optional<int> result2 =
+ MaybeGetInt(value, JSON_EXPAND_FIND_CONSTANT_ARGS("key2"));
+ absl::optional<int> result3 =
+ MaybeGetInt(value, JSON_EXPAND_FIND_CONSTANT_ARGS("key42"));
+ EXPECT_FALSE(result2);
+ EXPECT_FALSE(result3);
+
+ ASSERT_TRUE(result1);
+ EXPECT_EQ(result1.value(), 17);
+}
+
+TEST(JsonValueTest, GetString) {
+ absl::string_view obj(
+ R"!({"key1": 17, "key2": 32.3, "key3": "asdf", "key4": ""})!");
+ ErrorOr<Json::Value> value_or_error = json::Parse(obj);
+ ASSERT_TRUE(value_or_error);
+ Json::Value& value = value_or_error.value();
+ absl::optional<absl::string_view> result1 =
+ MaybeGetString(value, JSON_EXPAND_FIND_CONSTANT_ARGS("key3"));
+ absl::optional<absl::string_view> result2 =
+ MaybeGetString(value, JSON_EXPAND_FIND_CONSTANT_ARGS("key2"));
+ absl::optional<absl::string_view> result3 =
+ MaybeGetString(value, JSON_EXPAND_FIND_CONSTANT_ARGS("key42"));
+ absl::optional<absl::string_view> result4 =
+ MaybeGetString(value, JSON_EXPAND_FIND_CONSTANT_ARGS("key4"));
+
+ EXPECT_FALSE(result2);
+ EXPECT_FALSE(result3);
+
+ ASSERT_TRUE(result1);
+ EXPECT_EQ(result1.value(), "asdf");
+ ASSERT_TRUE(result4);
+ EXPECT_EQ(result4.value(), "");
+}
+
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/util/json/json_writer.h b/chromium/third_party/openscreen/src/util/json/json_writer.h
deleted file mode 100644
index df37d9a067c..00000000000
--- a/chromium/third_party/openscreen/src/util/json/json_writer.h
+++ /dev/null
@@ -1,34 +0,0 @@
-// Copyright 2019 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#ifndef UTIL_JSON_JSON_WRITER_H_
-#define UTIL_JSON_JSON_WRITER_H_
-
-#include <memory>
-#include <string>
-
-#include "absl/strings/string_view.h"
-#include "json/writer.h"
-
-namespace Json {
-class Value;
-}
-
-namespace openscreen {
-template <typename T>
-class ErrorOr;
-
-class JsonWriter {
- public:
- JsonWriter();
-
- ErrorOr<std::string> Write(const Json::Value& value);
-
- private:
- Json::StreamWriterBuilder factory_;
-};
-
-} // namespace openscreen
-
-#endif // UTIL_JSON_JSON_WRITER_H_
diff --git a/chromium/third_party/openscreen/src/util/json/json_writer_unittest.cc b/chromium/third_party/openscreen/src/util/json/json_writer_unittest.cc
deleted file mode 100644
index 8b75c82e0fc..00000000000
--- a/chromium/third_party/openscreen/src/util/json/json_writer_unittest.cc
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2019 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "util/json/json_writer.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "platform/base/error.h"
-
-namespace openscreen {
-
-TEST(JsonWriterTest, NullValueReturnsError) {
- JsonWriter writer;
-
- const auto null_value = Json::Value();
- const auto actual = writer.Write(null_value);
-
- EXPECT_TRUE(actual.is_error());
- EXPECT_EQ(actual.error().code(), Error::Code::kJsonWriteError);
-}
-
-TEST(JsonWriterTest, ValidValueReturnsString) {
- JsonWriter writer;
-
- const Json::Int64 value = 31337;
- const auto actual = writer.Write(value);
-
- EXPECT_TRUE(actual.is_value());
- EXPECT_EQ(actual.value(), "31337\n");
-}
-} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/util/logging.h b/chromium/third_party/openscreen/src/util/logging.h
index 5487011189f..4c76040e684 100644
--- a/chromium/third_party/openscreen/src/util/logging.h
+++ b/chromium/third_party/openscreen/src/util/logging.h
@@ -6,11 +6,11 @@
#define UTIL_LOGGING_H_
#include <sstream>
+#include <utility>
#include "platform/api/logging.h"
namespace openscreen {
-namespace platform {
namespace internal {
// The stream-based logging macros below are adapted from Chromium's
@@ -21,7 +21,7 @@ class LogMessage {
: level_(level), file_(file), line_(line) {}
~LogMessage() {
- LogWithLevel(level_, file_, line_, stream_.str());
+ LogWithLevel(level_, file_, line_, std::move(stream_));
if (level_ == LogLevel::kFatal) {
Break();
}
@@ -37,7 +37,7 @@ class LogMessage {
// creating a copy should be safe.
const char* const file_;
const int line_;
- std::ostringstream stream_;
+ std::stringstream stream_;
};
// Used by the OSP_LAZY_STREAM macro to return void after evaluating an ostream
@@ -48,17 +48,15 @@ class Voidify {
};
} // namespace internal
-} // namespace platform
} // namespace openscreen
#define OSP_LAZY_STREAM(condition, stream) \
- !(condition) ? (void)0 : openscreen::platform::internal::Voidify() & (stream)
-#define OSP_LOG_IS_ON(level_enum) \
- openscreen::platform::IsLoggingOn( \
- openscreen::platform::LogLevel::level_enum, __FILE__)
-#define OSP_LOG_STREAM(level_enum) \
- openscreen::platform::internal::LogMessage( \
- openscreen::platform::LogLevel::level_enum, __FILE__, __LINE__) \
+ !(condition) ? (void)0 : openscreen::internal::Voidify() & (stream)
+#define OSP_LOG_IS_ON(level_enum) \
+ openscreen::IsLoggingOn(openscreen::LogLevel::level_enum, __FILE__)
+#define OSP_LOG_STREAM(level_enum) \
+ openscreen::internal::LogMessage(openscreen::LogLevel::level_enum, __FILE__, \
+ __LINE__) \
.stream()
#define OSP_VLOG \
diff --git a/chromium/third_party/openscreen/src/util/operation_loop.h b/chromium/third_party/openscreen/src/util/operation_loop.h
index 155fe0783d7..ddb7846f4be 100644
--- a/chromium/third_party/openscreen/src/util/operation_loop.h
+++ b/chromium/third_party/openscreen/src/util/operation_loop.h
@@ -15,8 +15,6 @@
namespace openscreen {
-using Clock = platform::Clock;
-
class OperationLoop {
public:
using OperationWithTimeout = std::function<void(Clock::duration)>;
diff --git a/chromium/third_party/openscreen/src/util/saturate_cast.h b/chromium/third_party/openscreen/src/util/saturate_cast.h
index db40660eeda..1bb373812f9 100644
--- a/chromium/third_party/openscreen/src/util/saturate_cast.h
+++ b/chromium/third_party/openscreen/src/util/saturate_cast.h
@@ -5,11 +5,21 @@
#ifndef UTIL_SATURATE_CAST_H_
#define UTIL_SATURATE_CAST_H_
+#include <cmath>
#include <limits>
#include <type_traits>
namespace openscreen {
+// Case 0: When To and From are the same type, saturate_cast<> is pass-through.
+template <typename To, typename From>
+constexpr std::enable_if_t<
+ std::is_same<std::remove_cv<To>, std::remove_cv<From>>::value,
+ To>
+saturate_cast(From from) {
+ return from;
+}
+
// Because of the way C++ signed versus unsigned comparison works (i.e., the
// type promotion strategy employed), extra care must be taken to range-check
// the input value. For example, if the current architecture is 32-bits, then
@@ -21,7 +31,7 @@ namespace openscreen {
// this case, the smaller of the two types will be promoted to match the
// larger's size, and a valid comparison will be made.
template <typename To, typename From>
-constexpr typename std::enable_if_t<
+constexpr std::enable_if_t<
std::is_integral<From>::value && std::is_integral<To>::value &&
(std::is_signed<From>::value == std::is_signed<To>::value),
To>
@@ -37,7 +47,7 @@ saturate_cast(From from) {
// Case 2: "From" is signed, but "To" is unsigned.
template <typename To, typename From>
-constexpr typename std::enable_if_t<
+constexpr std::enable_if_t<
std::is_integral<From>::value && std::is_integral<To>::value &&
std::is_signed<From>::value && !std::is_signed<To>::value,
To>
@@ -45,7 +55,7 @@ saturate_cast(From from) {
if (from <= From{0}) {
return To{0};
}
- if (static_cast<typename std::make_unsigned_t<From>>(from) >=
+ if (static_cast<std::make_unsigned_t<From>>(from) >=
std::numeric_limits<To>::max()) {
return std::numeric_limits<To>::max();
}
@@ -54,7 +64,7 @@ saturate_cast(From from) {
// Case 3: "From" is unsigned, but "To" is signed.
template <typename To, typename From>
-constexpr typename std::enable_if_t<
+constexpr std::enable_if_t<
std::is_integral<From>::value && std::is_integral<To>::value &&
!std::is_signed<From>::value && std::is_signed<To>::value,
To>
@@ -66,6 +76,70 @@ saturate_cast(From from) {
return static_cast<To>(from);
}
+// Case 4: "From" is a floating-point type, and "To" is an integer type (signed
+// or unsigned). The result is truncated, per the usual C++ float-to-int
+// conversion rules.
+template <typename To, typename From>
+constexpr std::enable_if_t<std::is_floating_point<From>::value &&
+ std::is_integral<To>::value,
+ To>
+saturate_cast(From from) {
+ // Note: It's invalid to compare the argument against
+ // std::numeric_limits<To>::max() because the latter, an integer value, will
+ // be type-promoted to the floating-point type. The problem is that the
+ // conversion is imprecise, as "max int" might not be exactly representable as
+ // a floating-point value (depending on the actual types of From and To).
+ //
+ // Thus, the strategy is to compare only floating-point values/constants to
+ // determine whether the bounds of the range of integers has been exceeded.
+ // Two assumptions here: 1) "To" is either unsigned, or is a 2's complement
+ // signed integer type. 2) "From" is a floating-point type that can exactly
+ // represent all powers of 2 within its value range.
+ static_assert((~To(1) + To(1)) == To(-1), "assumed 2's complement integers");
+ constexpr From kMaxIntPlusOne =
+ From(To(1) << (std::numeric_limits<To>::digits - 1)) * From(2);
+ constexpr From kMaxInt = kMaxIntPlusOne - 1;
+ // Note: In some cases, the kMaxInt constant will equal kMaxIntPlusOne because
+ // there isn't an exact floating-point representation for 2^N - 1. That said,
+ // the following upper-bound comparison is still valid because all
+ // floating-point values less than 2^N would also be less than 2^N - 1.
+ if (from >= kMaxInt) {
+ return std::numeric_limits<To>::max();
+ }
+ if (std::is_signed<To>::value) {
+ constexpr From kMinInt = -kMaxIntPlusOne;
+ if (from <= kMinInt) {
+ return std::numeric_limits<To>::min();
+ }
+ } else /* if To is unsigned */ {
+ if (from <= From(0)) {
+ return To(0);
+ }
+ }
+ return static_cast<To>(from);
+}
+
+// Like saturate_cast<>, but rounds to the nearest integer instead of
+// truncating.
+template <typename To, typename From>
+constexpr std::enable_if_t<std::is_floating_point<From>::value &&
+ std::is_integral<To>::value,
+ To>
+rounded_saturate_cast(From from) {
+ const To saturated = saturate_cast<To>(from);
+ if (saturated == std::numeric_limits<To>::min() ||
+ saturated == std::numeric_limits<To>::max()) {
+ return saturated;
+ }
+
+ static_assert(sizeof(To) <= sizeof(decltype(llround(from))),
+ "No version of lround() for the required range of values.");
+ if (sizeof(To) <= sizeof(decltype(lround(from)))) {
+ return static_cast<To>(lround(from));
+ }
+ return static_cast<To>(llround(from));
+}
+
} // namespace openscreen
#endif // UTIL_SATURATE_CAST_H_
diff --git a/chromium/third_party/openscreen/src/util/saturate_cast_unittest.cc b/chromium/third_party/openscreen/src/util/saturate_cast_unittest.cc
index 026d238c126..293578a09e9 100644
--- a/chromium/third_party/openscreen/src/util/saturate_cast_unittest.cc
+++ b/chromium/third_party/openscreen/src/util/saturate_cast_unittest.cc
@@ -175,5 +175,188 @@ TEST(SaturateCastTest, UnsignedToSigned64BitInteger) {
}
}
+TEST(SaturateCastTest, Float32ToSigned32) {
+ struct ValuePair {
+ float from;
+ int32_t to;
+ };
+ constexpr float kFloatMax = std::numeric_limits<float>::max();
+ // Note: kIntMax is one larger because float cannot represent the exact value.
+ constexpr float kIntMax =
+ static_cast<float>(std::numeric_limits<int32_t>::max());
+ constexpr float kIntMin = std::numeric_limits<int32_t>::min();
+ const ValuePair kValuePairs[] = {
+ {kFloatMax, std::numeric_limits<int32_t>::max()},
+ {std::nextafter(kIntMax, kFloatMax), std::numeric_limits<int32_t>::max()},
+ {kIntMax, std::numeric_limits<int32_t>::max()},
+ {std::nextafter(kIntMax, 0.f), 2147483520},
+ {42, 42},
+ {0, 0},
+ {-42, -42},
+ {std::nextafter(kIntMin, 0.f), -2147483520},
+ {kIntMin, std::numeric_limits<int32_t>::min()},
+ {std::nextafter(kIntMin, -kFloatMax),
+ std::numeric_limits<int32_t>::min()},
+ {-kFloatMax, std::numeric_limits<int32_t>::min()},
+ };
+ for (const ValuePair& value_pair : kValuePairs) {
+ EXPECT_EQ(value_pair.to, saturate_cast<int32_t>(value_pair.from));
+ }
+}
+
+TEST(SaturateCastTest, Float32ToSigned64) {
+ struct ValuePair {
+ float from;
+ int64_t to;
+ };
+ constexpr float kFloatMax = std::numeric_limits<float>::max();
+ // Note: kIntMax is one larger because float cannot represent the exact value.
+ constexpr float kIntMax =
+ static_cast<float>(std::numeric_limits<int64_t>::max());
+ constexpr float kIntMin = std::numeric_limits<int64_t>::min();
+ const ValuePair kValuePairs[] = {
+ {kFloatMax, std::numeric_limits<int64_t>::max()},
+ {std::nextafter(kIntMax, kFloatMax), std::numeric_limits<int64_t>::max()},
+ {kIntMax, std::numeric_limits<int64_t>::max()},
+ {std::nextafter(kIntMax, 0.f), INT64_C(9223371487098961920)},
+ {42, 42},
+ {0, 0},
+ {-42, -42},
+ {std::nextafter(kIntMin, 0.f), INT64_C(-9223371487098961920)},
+ {kIntMin, std::numeric_limits<int64_t>::min()},
+ {std::nextafter(kIntMin, -kFloatMax),
+ std::numeric_limits<int64_t>::min()},
+ {-kFloatMax, std::numeric_limits<int64_t>::min()},
+ };
+ for (const ValuePair& value_pair : kValuePairs) {
+ EXPECT_EQ(value_pair.to, saturate_cast<int64_t>(value_pair.from));
+ }
+}
+
+TEST(SaturateCastTest, Float64ToSigned32) {
+ struct ValuePair {
+ double from;
+ int32_t to;
+ };
+ constexpr double kDoubleMax = std::numeric_limits<double>::max();
+ constexpr double kIntMax = std::numeric_limits<int32_t>::max();
+ constexpr double kIntMin = std::numeric_limits<int32_t>::min();
+ const ValuePair kValuePairs[] = {
+ {kDoubleMax, std::numeric_limits<int32_t>::max()},
+ {std::nextafter(kIntMax, kDoubleMax),
+ std::numeric_limits<int32_t>::max()},
+ {kIntMax, std::numeric_limits<int32_t>::max()},
+ {std::nextafter(kIntMax, 0.0), std::numeric_limits<int32_t>::max() - 1},
+ {42, 42},
+ {0, 0},
+ {-42, -42},
+ {std::nextafter(kIntMin, 0.0), std::numeric_limits<int32_t>::min() + 1},
+ {kIntMin, std::numeric_limits<int32_t>::min()},
+ {std::nextafter(kIntMin, -kDoubleMax),
+ std::numeric_limits<int32_t>::min()},
+ {-kDoubleMax, std::numeric_limits<int32_t>::min()},
+ };
+ for (const ValuePair& value_pair : kValuePairs) {
+ EXPECT_EQ(value_pair.to, saturate_cast<int32_t>(value_pair.from));
+ }
+}
+
+TEST(SaturateCastTest, Float64ToSigned64) {
+ struct ValuePair {
+ double from;
+ int64_t to;
+ };
+ constexpr double kDoubleMax = std::numeric_limits<double>::max();
+ // Note: kIntMax is one larger because double cannot represent the exact
+ // value.
+ constexpr double kIntMax =
+ static_cast<double>(std::numeric_limits<int64_t>::max());
+ constexpr double kIntMin = std::numeric_limits<int64_t>::min();
+ const ValuePair kValuePairs[] = {
+ {kDoubleMax, std::numeric_limits<int64_t>::max()},
+ {std::nextafter(kIntMax, kDoubleMax),
+ std::numeric_limits<int64_t>::max()},
+ {kIntMax, std::numeric_limits<int64_t>::max()},
+ {std::nextafter(kIntMax, 0.0), INT64_C(9223372036854774784)},
+ {42, 42},
+ {0, 0},
+ {-42, -42},
+ {std::nextafter(kIntMin, 0.0), INT64_C(-9223372036854774784)},
+ {kIntMin, std::numeric_limits<int64_t>::min()},
+ {std::nextafter(kIntMin, -kDoubleMax),
+ std::numeric_limits<int64_t>::min()},
+ {-kDoubleMax, std::numeric_limits<int64_t>::min()},
+ };
+ for (const ValuePair& value_pair : kValuePairs) {
+ EXPECT_EQ(value_pair.to, saturate_cast<int64_t>(value_pair.from));
+ }
+}
+
+TEST(SaturateCastTest, Float32ToUnsigned64) {
+ struct ValuePair {
+ float from;
+ uint64_t to;
+ };
+ constexpr float kFloatMax = std::numeric_limits<float>::max();
+ // Note: kIntMax is one larger because float cannot represent the exact value.
+ constexpr float kIntMax =
+ static_cast<float>(std::numeric_limits<uint64_t>::max());
+ const ValuePair kValuePairs[] = {
+ {kFloatMax, std::numeric_limits<uint64_t>::max()},
+ {std::nextafter(kIntMax, kFloatMax),
+ std::numeric_limits<uint64_t>::max()},
+ {kIntMax, std::numeric_limits<uint64_t>::max()},
+ {std::nextafter(kIntMax, 0.f), UINT64_C(18446742974197923840)},
+ {42, 42},
+ {0, 0},
+ {-42, 0},
+ {-kFloatMax, 0},
+ };
+ for (const ValuePair& value_pair : kValuePairs) {
+ EXPECT_EQ(value_pair.to, saturate_cast<uint64_t>(value_pair.from));
+ }
+}
+
+TEST(SaturateCastTest, RoundingFloat32ToSigned64) {
+ struct ValuePair {
+ float from;
+ int64_t to;
+ };
+ constexpr float kFloatMax = std::numeric_limits<float>::max();
+ // Note: kIntMax is one larger because float cannot represent the exact value.
+ constexpr float kIntMax =
+ static_cast<float>(std::numeric_limits<int64_t>::max());
+ constexpr float kIntMin = std::numeric_limits<int64_t>::min();
+ const ValuePair kValuePairs[] = {
+ {kFloatMax, std::numeric_limits<int64_t>::max()},
+ {std::nextafter(kIntMax, kFloatMax), std::numeric_limits<int64_t>::max()},
+ {kIntMax, std::numeric_limits<int64_t>::max()},
+ {std::nextafter(kIntMax, 0.f), INT64_C(9223371487098961920)},
+ {41.9, 42},
+ {42, 42},
+ {42.6, 43},
+ {42.5, 43},
+ {42.4, 42},
+ {0.5, 1},
+ {0.1, 0},
+ {0, 0},
+ {-0.1, 0},
+ {-0.5, -1},
+ {-41.9, -42},
+ {-42, -42},
+ {-42.4, -42},
+ {-42.5, -43},
+ {-42.6, -43},
+ {std::nextafter(kIntMin, 0.f), INT64_C(-9223371487098961920)},
+ {kIntMin, std::numeric_limits<int64_t>::min()},
+ {std::nextafter(kIntMin, -kFloatMax),
+ std::numeric_limits<int64_t>::min()},
+ {-kFloatMax, std::numeric_limits<int64_t>::min()},
+ };
+ for (const ValuePair& value_pair : kValuePairs) {
+ EXPECT_EQ(value_pair.to, rounded_saturate_cast<int64_t>(value_pair.from));
+ }
+}
+
} // namespace
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/util/serial_delete_ptr.h b/chromium/third_party/openscreen/src/util/serial_delete_ptr.h
index d33781ee361..1aff24959d5 100644
--- a/chromium/third_party/openscreen/src/util/serial_delete_ptr.h
+++ b/chromium/third_party/openscreen/src/util/serial_delete_ptr.h
@@ -19,24 +19,30 @@ namespace openscreen {
template <typename Type, typename DeleterType>
class SerialDelete {
public:
- explicit SerialDelete(platform::TaskRunner* task_runner)
+ SerialDelete() : deleter_() {}
+
+ explicit SerialDelete(TaskRunner* task_runner)
: task_runner_(task_runner), deleter_() {
assert(task_runner);
}
template <typename DT>
- SerialDelete(platform::TaskRunner* task_runner, DT&& deleter)
+ SerialDelete(TaskRunner* task_runner, DT&& deleter)
: task_runner_(task_runner), deleter_(std::forward<DT>(deleter)) {
assert(task_runner);
}
void operator()(Type* pointer) const {
- // Deletion of the object depends on the task being run by the task runner.
- task_runner_->PostTask([pointer, deleter = deleter_] { deleter(pointer); });
+ if (task_runner_) {
+ // Deletion of the object depends on the task being run by the task
+ // runner.
+ task_runner_->PostTask(
+ [pointer, deleter = std::move(deleter_)] { deleter(pointer); });
+ }
}
private:
- platform::TaskRunner* task_runner_;
+ TaskRunner* task_runner_;
DeleterType deleter_;
};
@@ -46,21 +52,26 @@ template <typename Type, typename DeleterType = std::default_delete<Type>>
class SerialDeletePtr
: public std::unique_ptr<Type, SerialDelete<Type, DeleterType>> {
public:
- explicit SerialDeletePtr(platform::TaskRunner* task_runner) noexcept
+ SerialDeletePtr() noexcept
+ : std::unique_ptr<Type, SerialDelete<Type, DeleterType>>(
+ nullptr,
+ SerialDelete<Type, DeleterType>()) {}
+
+ explicit SerialDeletePtr(TaskRunner* task_runner) noexcept
: std::unique_ptr<Type, SerialDelete<Type, DeleterType>>(
nullptr,
SerialDelete<Type, DeleterType>(task_runner)) {
assert(task_runner);
}
- SerialDeletePtr(platform::TaskRunner* task_runner, std::nullptr_t) noexcept
+ SerialDeletePtr(TaskRunner* task_runner, std::nullptr_t) noexcept
: std::unique_ptr<Type, SerialDelete<Type, DeleterType>>(
nullptr,
SerialDelete<Type, DeleterType>(task_runner)) {
assert(task_runner);
}
- SerialDeletePtr(platform::TaskRunner* task_runner, Type* pointer) noexcept
+ SerialDeletePtr(TaskRunner* task_runner, Type* pointer) noexcept
: std::unique_ptr<Type, SerialDelete<Type, DeleterType>>(
pointer,
SerialDelete<Type, DeleterType>(task_runner)) {
@@ -68,7 +79,7 @@ class SerialDeletePtr
}
SerialDeletePtr(
- platform::TaskRunner* task_runner,
+ TaskRunner* task_runner,
Type* pointer,
typename std::conditional<std::is_reference<DeleterType>::value,
DeleterType,
@@ -80,7 +91,7 @@ class SerialDeletePtr
}
SerialDeletePtr(
- platform::TaskRunner* task_runner,
+ TaskRunner* task_runner,
Type* pointer,
typename std::remove_reference<DeleterType>::type&& deleter) noexcept
: std::unique_ptr<Type, SerialDelete<Type, DeleterType>>(
@@ -91,7 +102,7 @@ class SerialDeletePtr
};
template <typename Type, typename... Args>
-SerialDeletePtr<Type> MakeSerialDelete(platform::TaskRunner* task_runner,
+SerialDeletePtr<Type> MakeSerialDelete(TaskRunner* task_runner,
Args&&... args) {
return SerialDeletePtr<Type>(task_runner,
new Type(std::forward<Args>(args)...));
diff --git a/chromium/third_party/openscreen/src/util/serial_delete_ptr_unittest.cc b/chromium/third_party/openscreen/src/util/serial_delete_ptr_unittest.cc
index 4713814cd85..c8b899c5d36 100644
--- a/chromium/third_party/openscreen/src/util/serial_delete_ptr_unittest.cc
+++ b/chromium/third_party/openscreen/src/util/serial_delete_ptr_unittest.cc
@@ -10,10 +10,6 @@
namespace openscreen {
-using openscreen::platform::Clock;
-using openscreen::platform::FakeClock;
-using openscreen::platform::FakeTaskRunner;
-
class SerialDeletePtrTest : public ::testing::Test {
public:
SerialDeletePtrTest() : clock_(Clock::now()), task_runner_(&clock_) {}
diff --git a/chromium/third_party/openscreen/src/util/simple_fraction.cc b/chromium/third_party/openscreen/src/util/simple_fraction.cc
new file mode 100644
index 00000000000..b45e1982eff
--- /dev/null
+++ b/chromium/third_party/openscreen/src/util/simple_fraction.cc
@@ -0,0 +1,69 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "util/simple_fraction.h"
+
+#include <cmath>
+#include <limits>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "util/logging.h"
+
+namespace openscreen {
+
+// static
+ErrorOr<SimpleFraction> SimpleFraction::FromString(absl::string_view value) {
+ std::vector<absl::string_view> fields = absl::StrSplit(value, '/');
+ if (fields.size() != 1 && fields.size() != 2) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ int numerator;
+ int denominator = 1;
+ if (!absl::SimpleAtoi(fields[0], &numerator)) {
+ return Error::Code::kParameterInvalid;
+ }
+
+ if (fields.size() == 2) {
+ if (!absl::SimpleAtoi(fields[1], &denominator)) {
+ return Error::Code::kParameterInvalid;
+ }
+ }
+
+ return SimpleFraction{numerator, denominator};
+}
+
+std::string SimpleFraction::ToString() const {
+ if (denominator == 1) {
+ return std::to_string(numerator);
+ }
+ return absl::StrCat(numerator, "/", denominator);
+}
+
+bool SimpleFraction::operator==(const SimpleFraction& other) const {
+ return numerator == other.numerator && denominator == other.denominator;
+}
+
+bool SimpleFraction::operator!=(const SimpleFraction& other) const {
+ return !(*this == other);
+}
+
+bool SimpleFraction::is_defined() const {
+ return denominator != 0;
+}
+
+bool SimpleFraction::is_positive() const {
+ return is_defined() && (numerator >= 0) && (denominator > 0);
+}
+
+SimpleFraction::operator double() const {
+ if (denominator == 0) {
+ return nan("");
+ }
+ return static_cast<double>(numerator) / static_cast<double>(denominator);
+}
+
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/util/simple_fraction.h b/chromium/third_party/openscreen/src/util/simple_fraction.h
new file mode 100644
index 00000000000..f8ab50832f9
--- /dev/null
+++ b/chromium/third_party/openscreen/src/util/simple_fraction.h
@@ -0,0 +1,45 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef UTIL_SIMPLE_FRACTION_H_
+#define UTIL_SIMPLE_FRACTION_H_
+
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "platform/base/error.h"
+
+namespace openscreen {
+
+// SimpleFraction is used to represent simple (or "common") fractions, composed
+// of a rational number written a/b where a and b are both integers.
+
+// Note: Since SimpleFraction is a trivial type, it comes with a
+// default constructor and is copyable, as well as allowing static
+// initialization.
+
+// Some helpful notes on SimpleFraction assumptions/limitations:
+// 1. SimpleFraction does not perform reductions. 2/4 != 1/2, and -1/-1 != 1/1.
+// 2. denominator = 0 is considered undefined.
+// 3. numerator = saturates range to int min or int max
+// 4. A SimpleFraction is "positive" if and only if it is defined and at least
+// equal to zero. Since reductions are not performed, -1/-1 is negative.
+struct SimpleFraction {
+ static ErrorOr<SimpleFraction> FromString(absl::string_view value);
+ std::string ToString() const;
+
+ bool operator==(const SimpleFraction& other) const;
+ bool operator!=(const SimpleFraction& other) const;
+
+ bool is_defined() const;
+ bool is_positive() const;
+ explicit operator double() const;
+
+ int numerator = 0;
+ int denominator = 0;
+};
+
+} // namespace openscreen
+
+#endif // UTIL_SIMPLE_FRACTION_H_
diff --git a/chromium/third_party/openscreen/src/util/simple_fraction_unittest.cc b/chromium/third_party/openscreen/src/util/simple_fraction_unittest.cc
new file mode 100644
index 00000000000..7cdbfeeccad
--- /dev/null
+++ b/chromium/third_party/openscreen/src/util/simple_fraction_unittest.cc
@@ -0,0 +1,98 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "util/simple_fraction.h"
+
+#include <limits>
+
+#include "gtest/gtest.h"
+
+namespace openscreen {
+
+namespace {
+
+constexpr int kMin = std::numeric_limits<int>::min();
+constexpr int kMax = std::numeric_limits<int>::max();
+
+void ExpectFromStringEquals(absl::string_view s,
+ const SimpleFraction& expected) {
+ const ErrorOr<SimpleFraction> f = SimpleFraction::FromString(s);
+ EXPECT_TRUE(f.is_value());
+ EXPECT_EQ(expected, f.value());
+}
+
+void ExpectFromStringError(absl::string_view s) {
+ const auto f = SimpleFraction::FromString(s);
+ EXPECT_TRUE(f.is_error());
+}
+} // namespace
+
+TEST(SimpleFractionTest, FromStringParsesCorrectFractions) {
+ ExpectFromStringEquals("1/2", SimpleFraction{1, 2});
+ ExpectFromStringEquals("99/3", SimpleFraction{99, 3});
+ ExpectFromStringEquals("-1/2", SimpleFraction{-1, 2});
+ ExpectFromStringEquals("-13/-37", SimpleFraction{-13, -37});
+ ExpectFromStringEquals("1/0", SimpleFraction{1, 0});
+ ExpectFromStringEquals("1", SimpleFraction{1, 1});
+ ExpectFromStringEquals("0", SimpleFraction{0, 1});
+ ExpectFromStringEquals("-20", SimpleFraction{-20, 1});
+ ExpectFromStringEquals("100", SimpleFraction{100, 1});
+}
+
+TEST(SimpleFractionTest, FromStringErrorsOnInvalid) {
+ ExpectFromStringError("");
+ ExpectFromStringError("/");
+ ExpectFromStringError("1/");
+ ExpectFromStringError("/1");
+ ExpectFromStringError("888/");
+ ExpectFromStringError("not a fraction at all");
+}
+
+TEST(SimpleFractionTest, Equality) {
+ EXPECT_EQ((SimpleFraction{1, 2}), (SimpleFraction{1, 2}));
+ EXPECT_EQ((SimpleFraction{1, 0}), (SimpleFraction{1, 0}));
+ EXPECT_NE((SimpleFraction{1, 2}), (SimpleFraction{1, 3}));
+
+ // We currently don't do any reduction.
+ EXPECT_NE((SimpleFraction{2, 4}), (SimpleFraction{1, 2}));
+ EXPECT_NE((SimpleFraction{9, 10}), (SimpleFraction{-9, -10}));
+}
+
+TEST(SimpleFractionTest, Definition) {
+ EXPECT_TRUE((SimpleFraction{kMin, 1}).is_defined());
+ EXPECT_TRUE((SimpleFraction{kMax, 1}).is_defined());
+
+ EXPECT_FALSE((SimpleFraction{kMin, 0}).is_defined());
+ EXPECT_FALSE((SimpleFraction{kMax, 0}).is_defined());
+ EXPECT_FALSE((SimpleFraction{0, 0}).is_defined());
+ EXPECT_FALSE((SimpleFraction{-0, -0}).is_defined());
+}
+
+TEST(SimpleFractionTest, Positivity) {
+ EXPECT_TRUE((SimpleFraction{1234, 20}).is_positive());
+ EXPECT_TRUE((SimpleFraction{kMax - 1, 20}).is_positive());
+ EXPECT_TRUE((SimpleFraction{0, kMax}).is_positive());
+ EXPECT_TRUE((SimpleFraction{kMax, 1}).is_positive());
+
+ // Since C++ doesn't have a truly negative zero, this is positive.
+ EXPECT_TRUE((SimpleFraction{-0, 1}).is_positive());
+
+ EXPECT_FALSE((SimpleFraction{0, kMin}).is_positive());
+ EXPECT_FALSE((SimpleFraction{-0, -1}).is_positive());
+ EXPECT_FALSE((SimpleFraction{kMin + 1, 20}).is_positive());
+ EXPECT_FALSE((SimpleFraction{kMin, 1}).is_positive());
+ EXPECT_FALSE((SimpleFraction{kMin, 0}).is_positive());
+ EXPECT_FALSE((SimpleFraction{kMax, 0}).is_positive());
+ EXPECT_FALSE((SimpleFraction{0, 0}).is_positive());
+ EXPECT_FALSE((SimpleFraction{-0, -0}).is_positive());
+}
+
+TEST(SimpleFractionTest, CastToDouble) {
+ EXPECT_DOUBLE_EQ(0.0, static_cast<double>(SimpleFraction{0, 1}));
+ EXPECT_DOUBLE_EQ(1.0, static_cast<double>(SimpleFraction{1, 1}));
+ EXPECT_DOUBLE_EQ(1.0, static_cast<double>(SimpleFraction{kMax, kMax}));
+ EXPECT_DOUBLE_EQ(1.0, static_cast<double>(SimpleFraction{kMin, kMin}));
+}
+
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/util/stringprintf.cc b/chromium/third_party/openscreen/src/util/stringprintf.cc
new file mode 100644
index 00000000000..452c32824fc
--- /dev/null
+++ b/chromium/third_party/openscreen/src/util/stringprintf.cc
@@ -0,0 +1,21 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "util/stringprintf.h"
+
+#include <iomanip>
+#include <sstream>
+
+namespace openscreen {
+
+std::string HexEncode(absl::Span<const uint8_t> bytes) {
+ std::ostringstream hex_dump;
+ hex_dump << std::setfill('0') << std::hex;
+ for (const uint8_t byte : bytes) {
+ hex_dump << std::setw(2) << static_cast<int>(byte);
+ }
+ return hex_dump.str();
+}
+
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/util/stringprintf.h b/chromium/third_party/openscreen/src/util/stringprintf.h
index 93e5eb93677..df24ac042a5 100644
--- a/chromium/third_party/openscreen/src/util/stringprintf.h
+++ b/chromium/third_party/openscreen/src/util/stringprintf.h
@@ -5,7 +5,12 @@
#ifndef UTIL_STRINGPRINTF_H_
#define UTIL_STRINGPRINTF_H_
+#include <stdint.h>
+
#include <ostream>
+#include <string>
+
+#include "absl/types/span.h"
namespace openscreen {
@@ -36,6 +41,9 @@ void PrettyPrintAsciiHex(std::ostream& os, It first, It last) {
}
}
+// Returns a hex string representation of the given |bytes|.
+std::string HexEncode(absl::Span<const uint8_t> bytes);
+
} // namespace openscreen
#endif // UTIL_STRINGPRINTF_H_
diff --git a/chromium/third_party/openscreen/src/util/stringprintf_unittest.cc b/chromium/third_party/openscreen/src/util/stringprintf_unittest.cc
new file mode 100644
index 00000000000..df0d60f8b41
--- /dev/null
+++ b/chromium/third_party/openscreen/src/util/stringprintf_unittest.cc
@@ -0,0 +1,24 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "util/stringprintf.h"
+
+#include "gtest/gtest.h"
+
+namespace openscreen {
+namespace {
+
+TEST(HexEncode, ProducesEmptyStringFromEmptyByteArray) {
+ const uint8_t kSomeMemoryLocation = 0;
+ EXPECT_EQ("", HexEncode(absl::Span<const uint8_t>(&kSomeMemoryLocation, 0)));
+}
+
+TEST(HexEncode, ProducesHexStringsFromBytes) {
+ const uint8_t kMessage[] = "Hello world!";
+ const char kMessageInHex[] = "48656c6c6f20776f726c642100";
+ EXPECT_EQ(kMessageInHex, HexEncode(kMessage));
+}
+
+} // namespace
+} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/util/trace_logging.h b/chromium/third_party/openscreen/src/util/trace_logging.h
index cb454e1b296..527bfcb2ac1 100644
--- a/chromium/third_party/openscreen/src/util/trace_logging.h
+++ b/chromium/third_party/openscreen/src/util/trace_logging.h
@@ -28,29 +28,31 @@
#include "util/trace_logging/macro_support.h"
#undef INCLUDING_FROM_UTIL_TRACE_LOGGING_H_
-#define TRACE_SET_RESULT(result) \
- do { \
- if (TRACE_IS_ENABLED(openscreen::platform::TraceCategory::Value::Any)) { \
- openscreen::internal::ScopedTraceOperation::set_result(result); \
- } \
+#define TRACE_SET_RESULT(result) \
+ do { \
+ if (TRACE_IS_ENABLED(openscreen::TraceCategory::Value::kAny)) { \
+ openscreen::internal::ScopedTraceOperation::set_result(result); \
+ } \
} while (false)
#define TRACE_SET_HIERARCHY(ids) TRACE_SET_HIERARCHY_INTERNAL(__LINE__, ids)
-#define TRACE_HIERARCHY \
- (TRACE_IS_ENABLED(openscreen::platform::TraceCategory::Value::Any) \
- ? openscreen::internal::ScopedTraceOperation::hierarchy() \
- : openscreen::platform::TraceIdHierarchy::Empty())
-#define TRACE_CURRENT_ID \
- (TRACE_IS_ENABLED(openscreen::platform::TraceCategory::Value::Any) \
- ? openscreen::internal::ScopedTraceOperation::current_id() \
+#define TRACE_HIERARCHY \
+ (TRACE_IS_ENABLED(openscreen::TraceCategory::Value::kAny) \
+ ? openscreen::internal::ScopedTraceOperation::hierarchy() \
+ : openscreen::TraceIdHierarchy::Empty())
+#define TRACE_CURRENT_ID \
+ (TRACE_IS_ENABLED(openscreen::TraceCategory::Value::kAny) \
+ ? openscreen::internal::ScopedTraceOperation::current_id() \
: kEmptyTraceId)
-#define TRACE_ROOT_ID \
- (TRACE_IS_ENABLED(openscreen::platform::TraceCategory::Value::Any) \
- ? openscreen::internal::ScopedTraceOperation::root_id() \
+#define TRACE_ROOT_ID \
+ (TRACE_IS_ENABLED(openscreen::TraceCategory::Value::kAny) \
+ ? openscreen::internal::ScopedTraceOperation::root_id() \
: kEmptyTraceId)
-// Synchronous Trace Macro.
+// Synchronous Trace Macros.
#define TRACE_SCOPED(category, name, ...) \
TRACE_SCOPED_INTERNAL(__LINE__, category, name, ##__VA_ARGS__)
+#define TRACE_DEFAULT_SCOPED(category, ...) \
+ TRACE_SCOPED(category, __PRETTY_FUNCTION__, ##__VA_ARGS__)
// Asynchronous Trace Macros.
#define TRACE_ASYNC_START(category, name, ...) \
@@ -76,9 +78,9 @@ inline void DoNothingForTracing(Args... args) {}
#define TRACE_SET_RESULT(result) \
openscreen::internal::DoNothingForTracing(result)
#define TRACE_SET_HIERARCHY(ids) openscreen::internal::DoNothingForTracing(ids)
-#define TRACE_HIERARCHY openscreen::platform::TraceIdHierarchy::Empty()
-#define TRACE_CURRENT_ID openscreen::platform::kEmptyTraceId
-#define TRACE_ROOT_ID openscreen::platform::kEmptyTraceId
+#define TRACE_HIERARCHY openscreen::TraceIdHierarchy::Empty()
+#define TRACE_CURRENT_ID openscreen::kEmptyTraceId
+#define TRACE_ROOT_ID openscreen::kEmptyTraceId
#define TRACE_SCOPED(category, name, ...) \
openscreen::internal::DoNothingForTracing(category, name, ##__VA_ARGS__)
#define TRACE_ASYNC_START(category, name, ...) \
diff --git a/chromium/third_party/openscreen/src/util/trace_logging/macro_support.h b/chromium/third_party/openscreen/src/util/trace_logging/macro_support.h
index 30865f4c9c9..d265081f428 100644
--- a/chromium/third_party/openscreen/src/util/trace_logging/macro_support.h
+++ b/chromium/third_party/openscreen/src/util/trace_logging/macro_support.h
@@ -42,8 +42,8 @@
namespace openscreen {
namespace internal {
-inline bool IsTraceLoggingEnabled(platform::TraceCategory::Value category) {
- auto* const destination = platform::GetTracingDestination();
+inline bool IsTraceLoggingEnabled(TraceCategory::Value category) {
+ const CurrentTracingDestination destination;
return destination && destination->IsTraceLoggingEnabled(category);
}
@@ -59,7 +59,7 @@ inline bool IsTraceLoggingEnabled(platform::TraceCategory::Value category) {
tracing_storage, line)[sizeof(openscreen::internal::TraceIdSetter)]; \
TRACE_INTERNAL_IGNORE_UNUSED_VAR \
const auto TRACE_INTERNAL_UNIQUE_VAR_NAME(trace_ref_) = \
- TRACE_IS_ENABLED(openscreen::platform::TraceCategory::Value::Any) \
+ TRACE_IS_ENABLED(openscreen::TraceCategory::Value::kAny) \
? openscreen::internal::TraceInstanceHelper< \
openscreen::internal::TraceIdSetter>:: \
Create(TRACE_INTERNAL_CONCAT_CONST(tracing_storage, line), \
diff --git a/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.cc b/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.cc
index 7b19b74ddad..22c3f5b35c4 100644
--- a/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.cc
+++ b/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.cc
@@ -12,11 +12,6 @@
#if defined(ENABLE_TRACE_LOGGING)
-using openscreen::platform::kUnsetTraceId;
-using openscreen::platform::TraceCategory;
-using openscreen::platform::TraceId;
-using openscreen::platform::TraceIdHierarchy;
-
namespace openscreen {
namespace internal {
@@ -25,13 +20,13 @@ bool ScopedTraceOperation::TraceAsyncEnd(const uint32_t line,
const char* file,
TraceId id,
Error::Code e) {
- auto end_time = platform::Clock::now();
- auto* const current_platform = platform::GetTracingDestination();
- if (current_platform == nullptr) {
- return false;
+ auto end_time = Clock::now();
+ const CurrentTracingDestination destination;
+ if (destination) {
+ destination->LogAsyncEnd(line, file, end_time, id, e);
+ return true;
}
- current_platform->LogAsyncEnd(line, file, end_time, id, e);
- return true;
+ return false;
}
ScopedTraceOperation::ScopedTraceOperation(TraceId trace_id,
@@ -95,7 +90,7 @@ TraceLoggerBase::TraceLoggerBase(TraceCategory::Value category,
TraceId parent,
TraceId root)
: ScopedTraceOperation(current, parent, root),
- start_time_(platform::Clock::now()),
+ start_time_(Clock::now()),
result_(Error::Code::kNone),
name_(name),
file_name_(file),
@@ -116,24 +111,22 @@ TraceLoggerBase::TraceLoggerBase(TraceCategory::Value category,
ids.root) {}
SynchronousTraceLogger::~SynchronousTraceLogger() {
- auto* const current_platform = platform::GetTracingDestination();
- if (current_platform == nullptr) {
- return;
+ const CurrentTracingDestination destination;
+ if (destination) {
+ auto end_time = Clock::now();
+ destination->LogTrace(this->name_, this->line_number_, this->file_name_,
+ this->start_time_, end_time, this->to_hierarchy(),
+ this->result_);
}
- auto end_time = platform::Clock::now();
- current_platform->LogTrace(this->name_, this->line_number_, this->file_name_,
- this->start_time_, end_time, this->to_hierarchy(),
- this->result_);
}
AsynchronousTraceLogger::~AsynchronousTraceLogger() {
- auto* const current_platform = platform::GetTracingDestination();
- if (current_platform == nullptr) {
- return;
+ const CurrentTracingDestination destination;
+ if (destination) {
+ destination->LogAsyncStart(this->name_, this->line_number_,
+ this->file_name_, this->start_time_,
+ this->to_hierarchy());
}
- current_platform->LogAsyncStart(this->name_, this->line_number_,
- this->file_name_, this->start_time_,
- this->to_hierarchy());
}
TraceIdSetter::~TraceIdSetter() = default;
diff --git a/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.h b/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.h
index f8af42da926..692fbf55fe9 100644
--- a/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.h
+++ b/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations.h
@@ -7,7 +7,9 @@
#include <atomic>
#include <cstring>
+#include <memory>
#include <stack>
+#include <utility>
#include <vector>
#include "build/config/features.h"
@@ -34,19 +36,17 @@ class ScopedTraceOperation {
// Getters the current Trace Hierarchy. If the traces_ stack hasn't been
// created yet, return as if the empty root node is there.
- static platform::TraceId current_id() {
- return traces_ == nullptr ? platform::kEmptyTraceId
- : traces_->top()->trace_id_;
+ static TraceId current_id() {
+ return traces_ == nullptr ? kEmptyTraceId : traces_->top()->trace_id_;
}
- static platform::TraceId root_id() {
- return traces_ == nullptr ? platform::kEmptyTraceId
- : traces_->top()->root_id_;
+ static TraceId root_id() {
+ return traces_ == nullptr ? kEmptyTraceId : traces_->top()->root_id_;
}
- static platform::TraceIdHierarchy hierarchy() {
+ static TraceIdHierarchy hierarchy() {
if (traces_ == nullptr) {
- return platform::TraceIdHierarchy::Empty();
+ return TraceIdHierarchy::Empty();
}
return traces_->top()->to_hierarchy();
@@ -66,7 +66,7 @@ class ScopedTraceOperation {
// the ternary operator in the macros simpler.
static bool TraceAsyncEnd(const uint32_t line,
const char* file,
- platform::TraceId id,
+ TraceId id,
Error::Code e);
protected:
@@ -77,18 +77,16 @@ class ScopedTraceOperation {
virtual void SetTraceResult(Error::Code error) = 0;
// Constructor to set all trace id information.
- ScopedTraceOperation(platform::TraceId current_id = platform::kUnsetTraceId,
- platform::TraceId parent_id = platform::kUnsetTraceId,
- platform::TraceId root_id = platform::kUnsetTraceId);
+ ScopedTraceOperation(TraceId current_id = kUnsetTraceId,
+ TraceId parent_id = kUnsetTraceId,
+ TraceId root_id = kUnsetTraceId);
// Current TraceId information.
- platform::TraceId trace_id_;
- platform::TraceId parent_id_;
- platform::TraceId root_id_;
+ TraceId trace_id_;
+ TraceId parent_id_;
+ TraceId root_id_;
- platform::TraceIdHierarchy to_hierarchy() {
- return {trace_id_, parent_id_, root_id_};
- }
+ TraceIdHierarchy to_hierarchy() { return {trace_id_, parent_id_, root_id_}; }
private:
// NOTE: A std::vector is used for backing the stack because it provides the
@@ -113,26 +111,26 @@ class ScopedTraceOperation {
// The class which does actual trace logging.
class TraceLoggerBase : public ScopedTraceOperation {
public:
- TraceLoggerBase(platform::TraceCategory::Value category,
+ TraceLoggerBase(TraceCategory::Value category,
const char* name,
const char* file,
uint32_t line,
- platform::TraceId current = platform::kUnsetTraceId,
- platform::TraceId parent = platform::kUnsetTraceId,
- platform::TraceId root = platform::kUnsetTraceId);
+ TraceId current = kUnsetTraceId,
+ TraceId parent = kUnsetTraceId,
+ TraceId root = kUnsetTraceId);
- TraceLoggerBase(platform::TraceCategory::Value category,
+ TraceLoggerBase(TraceCategory::Value category,
const char* name,
const char* file,
uint32_t line,
- platform::TraceIdHierarchy ids);
+ TraceIdHierarchy ids);
protected:
// Set the result.
void SetTraceResult(Error::Code error) override { result_ = error; }
// Timestamp for when the object was created.
- platform::Clock::time_point start_time_;
+ Clock::time_point start_time_;
// Result of this operation.
Error::Code result_;
@@ -147,7 +145,7 @@ class TraceLoggerBase : public ScopedTraceOperation {
uint32_t line_number_;
// Category of this trace log.
- platform::TraceCategory::Value category_;
+ TraceCategory::Value category_;
private:
OSP_DISALLOW_COPY_AND_ASSIGN(TraceLoggerBase);
@@ -175,9 +173,9 @@ class AsynchronousTraceLogger : public TraceLoggerBase {
// Inserts a fake element into the ScopedTraceOperation stack to set
// the current TraceId Hierarchy manually.
-class TraceIdSetter : public ScopedTraceOperation {
+class TraceIdSetter final : public ScopedTraceOperation {
public:
- explicit TraceIdSetter(platform::TraceIdHierarchy ids)
+ explicit TraceIdSetter(TraceIdHierarchy ids)
: ScopedTraceOperation(ids.current, ids.parent, ids.root) {}
~TraceIdSetter() final;
diff --git a/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations_unittest.cc b/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations_unittest.cc
index 63fa1cba5d3..7e52a6f48c3 100644
--- a/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations_unittest.cc
+++ b/chromium/third_party/openscreen/src/util/trace_logging/scoped_trace_operations_unittest.cc
@@ -21,12 +21,9 @@ using ::testing::_;
using ::testing::DoAll;
using ::testing::Invoke;
-using platform::kEmptyTraceId;
-using platform::MockLoggingPlatform;
-
// These tests validate that parameters are passed correctly by using the Trace
// Internals.
-constexpr auto category = platform::TraceCategory::mDNS;
+constexpr auto category = TraceCategory::kMdns;
constexpr uint32_t line = 10;
TEST(TraceLoggingInternalTest, CreatingNoTraceObjectValid) {
@@ -38,15 +35,13 @@ TEST(TraceLoggingInternalTest, TestMacroStyleInitializationTrue) {
MockLoggingPlatform platform;
EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _))
.Times(1)
- .WillOnce(
- DoAll(Invoke(platform::ValidateTraceTimestampDiff<delay_in_ms>),
- Invoke(platform::ValidateTraceErrorCode<Error::Code::kNone>)));
+ .WillOnce(DoAll(Invoke(ValidateTraceTimestampDiff<delay_in_ms>),
+ Invoke(ValidateTraceErrorCode<Error::Code::kNone>)));
{
uint8_t temp[sizeof(SynchronousTraceLogger)];
- auto ptr = true ? TraceInstanceHelper<SynchronousTraceLogger>::Create(
- temp, category, "Name", __FILE__, line)
- : TraceInstanceHelper<SynchronousTraceLogger>::Empty();
+ auto ptr = TraceInstanceHelper<SynchronousTraceLogger>::Create(
+ temp, category, "Name", __FILE__, line);
std::this_thread::sleep_for(std::chrono::milliseconds(delay_in_ms));
auto ids = ScopedTraceOperation::hierarchy();
EXPECT_NE(ids.current, kEmptyTraceId);
@@ -62,10 +57,7 @@ TEST(TraceLoggingInternalTest, TestMacroStyleInitializationFalse) {
EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)).Times(0);
{
- uint8_t temp[sizeof(SynchronousTraceLogger)];
- auto ptr = false ? TraceInstanceHelper<SynchronousTraceLogger>::Create(
- temp, category, "Name", __FILE__, line)
- : TraceInstanceHelper<SynchronousTraceLogger>::Empty();
+ auto ptr = TraceInstanceHelper<SynchronousTraceLogger>::Empty();
auto ids = ScopedTraceOperation::hierarchy();
EXPECT_EQ(ids.current, kEmptyTraceId);
EXPECT_EQ(ids.parent, kEmptyTraceId);
@@ -81,7 +73,7 @@ TEST(TraceLoggingInternalTest, ExpectParametersPassedToResult) {
MockLoggingPlatform platform;
EXPECT_CALL(platform, LogTrace(testing::StrEq("Name"), line,
testing::StrEq(__FILE__), _, _, _, _))
- .WillOnce(Invoke(platform::ValidateTraceErrorCode<Error::Code::kNone>));
+ .WillOnce(Invoke(ValidateTraceErrorCode<Error::Code::kNone>));
{ SynchronousTraceLogger{category, "Name", __FILE__, line}; }
}
diff --git a/chromium/third_party/openscreen/src/util/trace_logging_unittest.cc b/chromium/third_party/openscreen/src/util/trace_logging_unittest.cc
index d34ed303f93..386b19e6424 100644
--- a/chromium/third_party/openscreen/src/util/trace_logging_unittest.cc
+++ b/chromium/third_party/openscreen/src/util/trace_logging_unittest.cc
@@ -13,7 +13,6 @@
#include "platform/test/trace_logging_helpers.h"
namespace openscreen {
-namespace platform {
namespace {
#if defined(ENABLE_TRACE_LOGGING)
@@ -38,37 +37,48 @@ using StrictMockLoggingPlatform = ::testing::StrictMock<MockLoggingPlatform>;
TEST(TraceLoggingTest, MacroCallScopedDoesntSegFault) {
StrictMockLoggingPlatform platform;
#if defined(ENABLE_TRACE_LOGGING)
- EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any))
+ EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny))
.Times(AtLeast(1));
EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)).Times(1);
#endif
- { TRACE_SCOPED(TraceCategory::Value::Any, "test"); }
+ { TRACE_SCOPED(TraceCategory::Value::kAny, "test"); }
+}
+
+TEST(TraceLoggingTest, MacroCallDefaultScopedDoesntSegFault) {
+ StrictMockLoggingPlatform platform;
+#if defined(ENABLE_TRACE_LOGGING)
+ EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny))
+ .Times(AtLeast(1));
+ EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)).Times(1);
+#endif
+ { TRACE_DEFAULT_SCOPED(TraceCategory::Value::kAny); }
}
TEST(TraceLoggingTest, MacroCallUnscopedDoesntSegFault) {
StrictMockLoggingPlatform platform;
#if defined(ENABLE_TRACE_LOGGING)
- EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any))
+ EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny))
.Times(AtLeast(1));
EXPECT_CALL(platform, LogAsyncStart(_, _, _, _, _)).Times(1);
#endif
- { TRACE_ASYNC_START(TraceCategory::Value::Any, "test"); }
+ { TRACE_ASYNC_START(TraceCategory::Value::kAny, "test"); }
}
TEST(TraceLoggingTest, MacroVariablesUniquelyNames) {
StrictMockLoggingPlatform platform;
#if defined(ENABLE_TRACE_LOGGING)
- EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any))
+ EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny))
.Times(AtLeast(1));
- EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)).Times(2);
+ EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)).Times(3);
EXPECT_CALL(platform, LogAsyncStart(_, _, _, _, _)).Times(2);
#endif
{
- TRACE_SCOPED(TraceCategory::Value::Any, "test1");
- TRACE_SCOPED(TraceCategory::Value::Any, "test2");
- TRACE_ASYNC_START(TraceCategory::Value::Any, "test3");
- TRACE_ASYNC_START(TraceCategory::Value::Any, "test4");
+ TRACE_SCOPED(TraceCategory::Value::kAny, "test1");
+ TRACE_SCOPED(TraceCategory::Value::kAny, "test2");
+ TRACE_ASYNC_START(TraceCategory::Value::kAny, "test3");
+ TRACE_ASYNC_START(TraceCategory::Value::kAny, "test4");
+ TRACE_DEFAULT_SCOPED(TraceCategory::Value::kAny);
}
}
@@ -76,7 +86,7 @@ TEST(TraceLoggingTest, ExpectTimestampsReflectDelay) {
constexpr uint32_t delay_in_ms = 50;
StrictMockLoggingPlatform platform;
#if defined(ENABLE_TRACE_LOGGING)
- EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any))
+ EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny))
.Times(AtLeast(1));
EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _))
.WillOnce(DoAll(Invoke(ValidateTraceTimestampDiff<delay_in_ms>),
@@ -84,7 +94,7 @@ TEST(TraceLoggingTest, ExpectTimestampsReflectDelay) {
#endif
{
- TRACE_SCOPED(TraceCategory::Value::Any, "Name");
+ TRACE_SCOPED(TraceCategory::Value::kAny, "Name");
std::this_thread::sleep_for(std::chrono::milliseconds(delay_in_ms));
}
}
@@ -93,14 +103,14 @@ TEST(TraceLoggingTest, ExpectErrorsPassedToResult) {
constexpr Error::Code result_code = Error::Code::kParseError;
StrictMockLoggingPlatform platform;
#if defined(ENABLE_TRACE_LOGGING)
- EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any))
+ EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny))
.Times(AtLeast(1));
EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _))
.WillOnce(Invoke(ValidateTraceErrorCode<result_code>));
#endif
{
- TRACE_SCOPED(TraceCategory::Value::Any, "Name");
+ TRACE_SCOPED(TraceCategory::Value::kAny, "Name");
TRACE_SET_RESULT(result_code);
}
}
@@ -108,14 +118,14 @@ TEST(TraceLoggingTest, ExpectErrorsPassedToResult) {
TEST(TraceLoggingTest, ExpectUnsetTraceIdNotSet) {
StrictMockLoggingPlatform platform;
#if defined(ENABLE_TRACE_LOGGING)
- EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any))
+ EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny))
.Times(AtLeast(1));
EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)).Times(1);
#endif
TraceIdHierarchy h = {kUnsetTraceId, kUnsetTraceId, kUnsetTraceId};
{
- TRACE_SCOPED(TraceCategory::Value::Any, "Name", h);
+ TRACE_SCOPED(TraceCategory::Value::kAny, "Name", h);
auto ids = TRACE_HIERARCHY;
EXPECT_NE(ids.current, kUnsetTraceId);
@@ -130,7 +140,7 @@ TEST(TraceLoggingTest, ExpectCreationWithIdsToWork) {
constexpr TraceId root = 0x84;
StrictMockLoggingPlatform platform;
#if defined(ENABLE_TRACE_LOGGING)
- EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any))
+ EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny))
.Times(AtLeast(1));
EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _))
.WillOnce(
@@ -141,7 +151,7 @@ TEST(TraceLoggingTest, ExpectCreationWithIdsToWork) {
{
TraceIdHierarchy h = {current, parent, root};
- TRACE_SCOPED(TraceCategory::Value::Any, "Name", h);
+ TRACE_SCOPED(TraceCategory::Value::kAny, "Name", h);
#if defined(ENABLE_TRACE_LOGGING)
auto ids = TRACE_HIERARCHY;
@@ -161,7 +171,7 @@ TEST(TraceLoggingTest, ExpectHirearchyToBeApplied) {
constexpr TraceId root = 0x84;
StrictMockLoggingPlatform platform;
#if defined(ENABLE_TRACE_LOGGING)
- EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any))
+ EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny))
.Times(AtLeast(1));
EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _))
.WillOnce(DoAll(
@@ -176,7 +186,7 @@ TEST(TraceLoggingTest, ExpectHirearchyToBeApplied) {
{
TraceIdHierarchy h = {current, parent, root};
- TRACE_SCOPED(TraceCategory::Value::Any, "Name", h);
+ TRACE_SCOPED(TraceCategory::Value::kAny, "Name", h);
#if defined(ENABLE_TRACE_LOGGING)
auto ids = TRACE_HIERARCHY;
EXPECT_EQ(ids.current, current);
@@ -184,7 +194,7 @@ TEST(TraceLoggingTest, ExpectHirearchyToBeApplied) {
EXPECT_EQ(ids.root, root);
#endif
- TRACE_SCOPED(TraceCategory::Value::Any, "Name");
+ TRACE_SCOPED(TraceCategory::Value::kAny, "Name");
#if defined(ENABLE_TRACE_LOGGING)
ids = TRACE_HIERARCHY;
EXPECT_NE(ids.current, current);
@@ -200,7 +210,7 @@ TEST(TraceLoggingTest, ExpectHirearchyToEndAfterScopeWhenSetWithSetter) {
constexpr TraceId root = 0x84;
StrictMockLoggingPlatform platform;
#if defined(ENABLE_TRACE_LOGGING)
- EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any))
+ EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny))
.Times(AtLeast(1));
EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _))
.WillOnce(DoAll(
@@ -213,7 +223,7 @@ TEST(TraceLoggingTest, ExpectHirearchyToEndAfterScopeWhenSetWithSetter) {
TraceIdHierarchy ids = {current, parent, root};
TRACE_SET_HIERARCHY(ids);
{
- TRACE_SCOPED(TraceCategory::Value::Any, "Name");
+ TRACE_SCOPED(TraceCategory::Value::kAny, "Name");
ids = TRACE_HIERARCHY;
#if defined(ENABLE_TRACE_LOGGING)
EXPECT_NE(ids.current, current);
@@ -231,13 +241,13 @@ TEST(TraceLoggingTest, ExpectHirearchyToEndAfterScopeWhenSetWithSetter) {
}
}
-TEST(TraceLoggingTest, ExpectHirearchyToEndAfterScope) {
+TEST(TraceLoggingTest, ExpectHierarchyToEndAfterScope) {
constexpr TraceId current = 0x32;
constexpr TraceId parent = 0x47;
constexpr TraceId root = 0x84;
StrictMockLoggingPlatform platform;
#if defined(ENABLE_TRACE_LOGGING)
- EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any))
+ EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny))
.Times(AtLeast(1));
EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _))
.WillOnce(DoAll(
@@ -252,9 +262,9 @@ TEST(TraceLoggingTest, ExpectHirearchyToEndAfterScope) {
{
TraceIdHierarchy ids = {current, parent, root};
- TRACE_SCOPED(TraceCategory::Value::Any, "Name", ids);
+ TRACE_SCOPED(TraceCategory::Value::kAny, "Name", ids);
{
- TRACE_SCOPED(TraceCategory::Value::Any, "Name");
+ TRACE_SCOPED(TraceCategory::Value::kAny, "Name");
ids = TRACE_HIERARCHY;
#if defined(ENABLE_TRACE_LOGGING)
EXPECT_NE(ids.current, current);
@@ -278,7 +288,7 @@ TEST(TraceLoggingTest, ExpectSetHierarchyToApply) {
constexpr TraceId root = 0x84;
StrictMockLoggingPlatform platform;
#if defined(ENABLE_TRACE_LOGGING)
- EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any))
+ EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny))
.Times(AtLeast(1));
EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _))
.WillOnce(DoAll(
@@ -297,7 +307,7 @@ TEST(TraceLoggingTest, ExpectSetHierarchyToApply) {
EXPECT_EQ(ids.root, root);
#endif
- TRACE_SCOPED(TraceCategory::Value::Any, "Name");
+ TRACE_SCOPED(TraceCategory::Value::kAny, "Name");
ids = TRACE_HIERARCHY;
#if defined(ENABLE_TRACE_LOGGING)
EXPECT_NE(ids.current, current);
@@ -310,12 +320,12 @@ TEST(TraceLoggingTest, ExpectSetHierarchyToApply) {
TEST(TraceLoggingTest, CheckTraceAsyncStartLogsCorrectly) {
StrictMockLoggingPlatform platform;
#if defined(ENABLE_TRACE_LOGGING)
- EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any))
+ EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny))
.Times(AtLeast(1));
EXPECT_CALL(platform, LogAsyncStart(_, _, _, _, _)).Times(1);
#endif
- { TRACE_ASYNC_START(TraceCategory::Value::Any, "Name"); }
+ { TRACE_ASYNC_START(TraceCategory::Value::kAny, "Name"); }
}
TEST(TraceLoggingTest, CheckTraceAsyncStartSetsHierarchy) {
@@ -324,7 +334,7 @@ TEST(TraceLoggingTest, CheckTraceAsyncStartSetsHierarchy) {
constexpr TraceId root = 84;
StrictMockLoggingPlatform platform;
#if defined(ENABLE_TRACE_LOGGING)
- EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any))
+ EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny))
.Times(AtLeast(1));
EXPECT_CALL(platform, LogAsyncStart(_, _, _, _, _))
.WillOnce(
@@ -336,7 +346,7 @@ TEST(TraceLoggingTest, CheckTraceAsyncStartSetsHierarchy) {
TraceIdHierarchy ids = {current, parent, root};
TRACE_SET_HIERARCHY(ids);
{
- TRACE_ASYNC_START(TraceCategory::Value::Any, "Name");
+ TRACE_ASYNC_START(TraceCategory::Value::kAny, "Name");
ids = TRACE_HIERARCHY;
#if defined(ENABLE_TRACE_LOGGING)
EXPECT_NE(ids.current, current);
@@ -359,14 +369,13 @@ TEST(TraceLoggingTest, CheckTraceAsyncEndLogsCorrectly) {
constexpr Error::Code result = Error::Code::kAgain;
StrictMockLoggingPlatform platform;
#if defined(ENABLE_TRACE_LOGGING)
- EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::Any))
+ EXPECT_CALL(platform, IsTraceLoggingEnabled(TraceCategory::Value::kAny))
.Times(AtLeast(1));
EXPECT_CALL(platform, LogAsyncEnd(_, _, _, id, result)).Times(1);
#endif
- TRACE_ASYNC_END(TraceCategory::Value::Any, id, result);
+ TRACE_ASYNC_END(TraceCategory::Value::kAny, id, result);
}
} // namespace
-} // namespace platform
} // namespace openscreen
diff --git a/chromium/third_party/openscreen/src/platform/impl/weak_ptr.h b/chromium/third_party/openscreen/src/util/weak_ptr.h
index 173c57e8f48..93b97e58c42 100644
--- a/chromium/third_party/openscreen/src/platform/impl/weak_ptr.h
+++ b/chromium/third_party/openscreen/src/util/weak_ptr.h
@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#ifndef PLATFORM_IMPL_WEAK_PTR_H_
-#define PLATFORM_IMPL_WEAK_PTR_H_
+#ifndef UTIL_WEAK_PTR_H_
+#define UTIL_WEAK_PTR_H_
#include <memory>
@@ -213,4 +213,4 @@ class WeakPtrFactory {
} // namespace openscreen
-#endif // PLATFORM_IMPL_WEAK_PTR_H_
+#endif // UTIL_WEAK_PTR_H_
diff --git a/chromium/third_party/openscreen/src/platform/impl/weak_ptr_unittest.cc b/chromium/third_party/openscreen/src/util/weak_ptr_unittest.cc
index b06574bb1c5..328802695d7 100644
--- a/chromium/third_party/openscreen/src/platform/impl/weak_ptr_unittest.cc
+++ b/chromium/third_party/openscreen/src/util/weak_ptr_unittest.cc
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#include "platform/impl/weak_ptr.h"
+#include "util/weak_ptr.h"
#include "gtest/gtest.h"
@@ -15,7 +15,7 @@ class SomeClass {
virtual int GetValue() const { return 42; }
};
-struct SomeSubclass : public SomeClass {
+struct SomeSubclass final : public SomeClass {
public:
~SomeSubclass() final = default;
int GetValue() const override { return 999; }