diff options
author | Allan Sandfeld Jensen <allan.jensen@qt.io> | 2018-12-10 16:19:40 +0100 |
---|---|---|
committer | Allan Sandfeld Jensen <allan.jensen@qt.io> | 2018-12-10 16:01:50 +0000 |
commit | 51f6c2793adab2d864b3d2b360000ef8db1d3e92 (patch) | |
tree | 835b3b4446b012c75e80177cef9fbe6972cc7dbe /chromium/net/dns | |
parent | 6036726eb981b6c4b42047513b9d3f4ac865daac (diff) | |
download | qtwebengine-chromium-51f6c2793adab2d864b3d2b360000ef8db1d3e92.tar.gz |
BASELINE: Update Chromium to 71.0.3578.93
Change-Id: I6a32086c33670e1b033f8b10e6bf1fd4da1d105d
Reviewed-by: Alexandru Croitor <alexandru.croitor@qt.io>
Diffstat (limited to 'chromium/net/dns')
53 files changed, 2321 insertions, 366 deletions
diff --git a/chromium/net/dns/BUILD.gn b/chromium/net/dns/BUILD.gn index d0410ac9849..d70a30f4ff9 100644 --- a/chromium/net/dns/BUILD.gn +++ b/chromium/net/dns/BUILD.gn @@ -31,7 +31,10 @@ source_set("dns") { sources += [ "address_sorter.h", "address_sorter_win.cc", + "dns_config.cc", + "dns_config_overrides.cc", "dns_config_service.cc", + "dns_config_service.h", "dns_config_service_win.cc", "dns_config_service_win.h", "dns_config_watcher_mac.cc", @@ -51,6 +54,8 @@ source_set("dns") { "host_cache.cc", "host_resolver.cc", "host_resolver_impl.cc", + "host_resolver_mdns_task.cc", + "host_resolver_mdns_task.h", "host_resolver_proc.cc", "host_resolver_proc.h", "host_resolver_source.h", @@ -161,11 +166,6 @@ source_set("host_resolver") { # TODO(crbug.com/874654): Remove once migrated to network service IPC. "//components/network_hints/browser", - # content/public/browser/resource_hints.h - # Deprecated and soon to be removed. - # TODO(crbug.com/875238): Remove once code is removed. - "//content/public/browser:browser_sources", - # headless/lib/browser/headless_url_request_context_getter.cc # URLRequestContext creation for headless. "//headless", @@ -197,7 +197,8 @@ source_set("host_resolver") { if (!is_nacl) { sources += [ - "dns_config_service.h", + "dns_config.h", + "dns_config_overrides.h", "host_cache.h", "host_resolver.h", "mapped_host_resolver.h", @@ -353,7 +354,6 @@ if (enable_net_mojo) { deps = [ "//base", "//net", - "//net:net_with_v8", ] public_deps = [ @@ -458,8 +458,14 @@ source_set("test_support") { ] if (enable_mdns) { - sources += [ "mock_mdns_socket_factory.cc" ] - public += [ "mock_mdns_socket_factory.h" ] + sources += [ + "mock_mdns_client.cc", + "mock_mdns_socket_factory.cc", + ] + public += [ + "mock_mdns_client.h", + "mock_mdns_socket_factory.h", + ] } deps = [ @@ -507,6 +513,30 @@ fuzzer_test("net_dns_record_fuzzer") { dict = "//net/data/fuzzer_dictionaries/net_dns_record_fuzzer.dict" } +fuzzer_test("net_dns_query_parse_fuzzer") { + sources = [ + "dns_query_parse_fuzzer.cc", + ] + deps = [ + "//base", + "//net", + "//net:net_fuzzer_test_support", + ] + dict = "//net/data/fuzzer_dictionaries/net_dns_record_fuzzer.dict" +} + +fuzzer_test("net_dns_response_fuzzer") { + sources = [ + "dns_response_fuzzer.cc", + ] + deps = [ + "//base", + "//net", + "//net:net_fuzzer_test_support", + ] + dict = "//net/data/fuzzer_dictionaries/net_dns_record_fuzzer.dict" +} + fuzzer_test("net_host_resolver_impl_fuzzer") { sources = [ "host_resolver_impl_fuzzer.cc", diff --git a/chromium/net/dns/address_sorter_posix_unittest.cc b/chromium/net/dns/address_sorter_posix_unittest.cc index da5ba524d15..123075670cb 100644 --- a/chromium/net/dns/address_sorter_posix_unittest.cc +++ b/chromium/net/dns/address_sorter_posix_unittest.cc @@ -114,7 +114,7 @@ class TestUDPClientSocket : public DatagramClientSocket { int Connect(const IPEndPoint& remote) override { if (connected_) return ERR_UNEXPECTED; - AddressMapping::const_iterator it = mapping_->find(remote.address()); + auto it = mapping_->find(remote.address()); if (it == mapping_->end()) return ERR_FAILED; connected_ = true; diff --git a/chromium/net/dns/dns_client.cc b/chromium/net/dns/dns_client.cc index 6158b0a07ac..c230b32cd64 100644 --- a/chromium/net/dns/dns_client.cc +++ b/chromium/net/dns/dns_client.cc @@ -9,7 +9,7 @@ #include "base/bind.h" #include "base/rand_util.h" #include "net/dns/address_sorter.h" -#include "net/dns/dns_config_service.h" +#include "net/dns/dns_config.h" #include "net/dns/dns_session.h" #include "net/dns/dns_socket_pool.h" #include "net/dns/dns_transaction.h" diff --git a/chromium/net/dns/dns_config.cc b/chromium/net/dns/dns_config.cc new file mode 100644 index 00000000000..4995467167b --- /dev/null +++ b/chromium/net/dns/dns_config.cc @@ -0,0 +1,101 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_config.h" + +#include <utility> + +#include "base/values.h" + +namespace net { + +// Default values are taken from glibc resolv.h except timeout which is set to +// |kDnsDefaultTimeoutMs|. +DnsConfig::DnsConfig() + : unhandled_options(false), + append_to_multi_label_name(true), + randomize_ports(false), + ndots(1), + timeout(kDnsDefaultTimeout), + attempts(2), + rotate(false), + use_local_ipv6(false) {} + +DnsConfig::DnsConfig(const DnsConfig& other) = default; + +DnsConfig::~DnsConfig() = default; + +bool DnsConfig::Equals(const DnsConfig& d) const { + return EqualsIgnoreHosts(d) && (hosts == d.hosts); +} + +bool DnsConfig::EqualsIgnoreHosts(const DnsConfig& d) const { + return (nameservers == d.nameservers) && (search == d.search) && + (unhandled_options == d.unhandled_options) && + (append_to_multi_label_name == d.append_to_multi_label_name) && + (ndots == d.ndots) && (timeout == d.timeout) && + (attempts == d.attempts) && (rotate == d.rotate) && + (use_local_ipv6 == d.use_local_ipv6) && + (dns_over_https_servers == d.dns_over_https_servers); +} + +void DnsConfig::CopyIgnoreHosts(const DnsConfig& d) { + nameservers = d.nameservers; + search = d.search; + unhandled_options = d.unhandled_options; + append_to_multi_label_name = d.append_to_multi_label_name; + ndots = d.ndots; + timeout = d.timeout; + attempts = d.attempts; + rotate = d.rotate; + use_local_ipv6 = d.use_local_ipv6; + dns_over_https_servers = d.dns_over_https_servers; +} + +std::unique_ptr<base::Value> DnsConfig::ToValue() const { + auto dict = std::make_unique<base::DictionaryValue>(); + + auto list = std::make_unique<base::ListValue>(); + for (size_t i = 0; i < nameservers.size(); ++i) + list->AppendString(nameservers[i].ToString()); + dict->Set("nameservers", std::move(list)); + + list = std::make_unique<base::ListValue>(); + for (size_t i = 0; i < search.size(); ++i) + list->AppendString(search[i]); + dict->Set("search", std::move(list)); + + dict->SetBoolean("unhandled_options", unhandled_options); + dict->SetBoolean("append_to_multi_label_name", append_to_multi_label_name); + dict->SetInteger("ndots", ndots); + dict->SetDouble("timeout", timeout.InSecondsF()); + dict->SetInteger("attempts", attempts); + dict->SetBoolean("rotate", rotate); + dict->SetBoolean("use_local_ipv6", use_local_ipv6); + dict->SetInteger("num_hosts", hosts.size()); + list = std::make_unique<base::ListValue>(); + for (auto& server : dns_over_https_servers) { + base::Value val(base::Value::Type::DICTIONARY); + base::DictionaryValue* dict; + val.GetAsDictionary(&dict); + dict->SetString("server_template", server.server_template); + dict->SetBoolean("use_post", server.use_post); + list->GetList().push_back(std::move(val)); + } + dict->Set("doh_servers", std::move(list)); + + return std::move(dict); +} + +DnsConfig::DnsOverHttpsServerConfig::DnsOverHttpsServerConfig( + const std::string& server_template, + bool use_post) + : server_template(server_template), use_post(use_post) {} + +bool DnsConfig::DnsOverHttpsServerConfig::operator==( + const DnsOverHttpsServerConfig& other) const { + return server_template == other.server_template && use_post == other.use_post; +} + +} // namespace net diff --git a/chromium/net/dns/dns_config.h b/chromium/net/dns/dns_config.h new file mode 100644 index 00000000000..225f9949cbd --- /dev/null +++ b/chromium/net/dns/dns_config.h @@ -0,0 +1,96 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_DNS_CONFIG_H_ +#define NET_DNS_DNS_CONFIG_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "base/time/time.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_export.h" +#include "net/dns/dns_hosts.h" + +namespace base { +class Value; +} + +namespace net { + +// Default to 1 second timeout (before exponential backoff). +constexpr base::TimeDelta kDnsDefaultTimeout = base::TimeDelta::FromSeconds(1); + +// DnsConfig stores configuration of the system resolver. +struct NET_EXPORT DnsConfig { + DnsConfig(); + DnsConfig(const DnsConfig& other); + ~DnsConfig(); + + bool Equals(const DnsConfig& d) const; + + bool EqualsIgnoreHosts(const DnsConfig& d) const; + + void CopyIgnoreHosts(const DnsConfig& src); + + // Returns a Value representation of |this|. For performance reasons, the + // Value only contains the number of hosts rather than the full list. + std::unique_ptr<base::Value> ToValue() const; + + bool IsValid() const { return !nameservers.empty(); } + + struct NET_EXPORT DnsOverHttpsServerConfig { + DnsOverHttpsServerConfig(const std::string& server_template, bool use_post); + + bool operator==(const DnsOverHttpsServerConfig& other) const; + + std::string server_template; + bool use_post; + }; + + // List of name server addresses. + std::vector<IPEndPoint> nameservers; + // Suffix search list; used on first lookup when number of dots in given name + // is less than |ndots|. + std::vector<std::string> search; + + DnsHosts hosts; + + // True if there are options set in the system configuration that are not yet + // supported by DnsClient. + bool unhandled_options; + + // AppendToMultiLabelName: is suffix search performed for multi-label names? + // True, except on Windows where it can be configured. + bool append_to_multi_label_name; + + // Indicates that source port randomization is required. This uses additional + // resources on some platforms. + bool randomize_ports; + + // Resolver options; see man resolv.conf. + + // Minimum number of dots before global resolution precedes |search|. + int ndots; + // Time between retransmissions, see res_state.retrans. + base::TimeDelta timeout; + // Maximum number of attempts, see res_state.retry. + int attempts; + // Round robin entries in |nameservers| for subsequent requests. + bool rotate; + + // Indicates system configuration uses local IPv6 connectivity, e.g., + // DirectAccess. This is exposed for HostResolver to skip IPv6 probes, + // as it may cause them to return incorrect results. + bool use_local_ipv6; + + // List of servers to query over HTTPS, queried in order + // (https://tools.ietf.org/id/draft-ietf-doh-dns-over-https-12.txt). + std::vector<DnsOverHttpsServerConfig> dns_over_https_servers; +}; + +} // namespace net + +#endif // NET_DNS_DNS_CONFIG_H_ diff --git a/chromium/net/dns/dns_config_overrides.cc b/chromium/net/dns/dns_config_overrides.cc new file mode 100644 index 00000000000..ae0e0545b83 --- /dev/null +++ b/chromium/net/dns/dns_config_overrides.cc @@ -0,0 +1,58 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_config_overrides.h" + +namespace net { + +DnsConfigOverrides::DnsConfigOverrides() = default; + +DnsConfigOverrides::DnsConfigOverrides(const DnsConfigOverrides& other) = + default; + +DnsConfigOverrides::~DnsConfigOverrides() = default; + +DnsConfigOverrides& DnsConfigOverrides::operator=( + const DnsConfigOverrides& other) = default; + +bool DnsConfigOverrides::operator==(const DnsConfigOverrides& other) const { + return nameservers == other.nameservers && search == other.search && + hosts == other.hosts && + append_to_multi_label_name == other.append_to_multi_label_name && + randomize_ports == other.randomize_ports && ndots == other.ndots && + timeout == other.timeout && attempts == other.attempts && + rotate == other.rotate && use_local_ipv6 == other.use_local_ipv6 && + dns_over_https_servers == other.dns_over_https_servers; +} + +DnsConfig DnsConfigOverrides::ApplyOverrides(const DnsConfig& config) const { + DnsConfig overridden(config); + + if (nameservers) + overridden.nameservers = nameservers.value(); + if (search) + overridden.search = search.value(); + if (hosts) + overridden.hosts = hosts.value(); + if (append_to_multi_label_name) + overridden.append_to_multi_label_name = append_to_multi_label_name.value(); + if (randomize_ports) + overridden.randomize_ports = randomize_ports.value(); + if (ndots) + overridden.ndots = ndots.value(); + if (timeout) + overridden.timeout = timeout.value(); + if (attempts) + overridden.attempts = attempts.value(); + if (rotate) + overridden.rotate = rotate.value(); + if (use_local_ipv6) + overridden.use_local_ipv6 = use_local_ipv6.value(); + if (dns_over_https_servers) + overridden.dns_over_https_servers = dns_over_https_servers.value(); + + return overridden; +} + +} // namespace net diff --git a/chromium/net/dns/dns_config_overrides.h b/chromium/net/dns/dns_config_overrides.h new file mode 100644 index 00000000000..6ef16085bcd --- /dev/null +++ b/chromium/net/dns/dns_config_overrides.h @@ -0,0 +1,55 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_DNS_CONFIG_OVERRIDES_H_ +#define NET_DNS_DNS_CONFIG_OVERRIDES_H_ + +#include <string> +#include <vector> + +#include "base/optional.h" +#include "base/time/time.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_export.h" +#include "net/dns/dns_config.h" +#include "net/dns/dns_hosts.h" + +namespace net { + +// Overriding values to be applied over a DnsConfig struct. +struct NET_EXPORT DnsConfigOverrides { + DnsConfigOverrides(); + DnsConfigOverrides(const DnsConfigOverrides& other); + ~DnsConfigOverrides(); + + DnsConfigOverrides& operator=(const DnsConfigOverrides& other); + + bool operator==(const DnsConfigOverrides& other) const; + + // Creates a new DnsConfig where any field with an overriding value in |this| + // is replaced with that overriding value. Any field without an overriding + // value (|base::nullopt|) will be copied as-is from |config|. + DnsConfig ApplyOverrides(const DnsConfig& config) const; + + // Overriding values. See same-named fields in DnsConfig for explanations. + base::Optional<std::vector<IPEndPoint>> nameservers; + base::Optional<std::vector<std::string>> search; + base::Optional<DnsHosts> hosts; + base::Optional<bool> append_to_multi_label_name; + base::Optional<bool> randomize_ports; + base::Optional<int> ndots; + base::Optional<base::TimeDelta> timeout; + base::Optional<int> attempts; + base::Optional<bool> rotate; + base::Optional<bool> use_local_ipv6; + base::Optional<std::vector<DnsConfig::DnsOverHttpsServerConfig>> + dns_over_https_servers; + + // Note no overriding value for |unhandled_options|. It is meta-configuration, + // and there should be no reason to override it. +}; + +} // namespace net + +#endif // NET_DNS_DNS_CONFIG_OVERRIDES_H_ diff --git a/chromium/net/dns/dns_config_service.cc b/chromium/net/dns/dns_config_service.cc index 6a291e99f41..96731cde39d 100644 --- a/chromium/net/dns/dns_config_service.cc +++ b/chromium/net/dns/dns_config_service.cc @@ -4,100 +4,13 @@ #include "net/dns/dns_config_service.h" -#include <utility> +#include <string> #include "base/logging.h" #include "base/metrics/histogram_macros.h" -#include "base/values.h" -#include "net/base/ip_endpoint.h" -#include "net/base/ip_pattern.h" namespace net { -// Default values are taken from glibc resolv.h except timeout which is set to -// |kDnsDefaultTimeoutMs|. -DnsConfig::DnsConfig() - : unhandled_options(false), - append_to_multi_label_name(true), - randomize_ports(false), - ndots(1), - timeout(base::TimeDelta::FromMilliseconds(kDnsDefaultTimeoutMs)), - attempts(2), - rotate(false), - use_local_ipv6(false) {} - -DnsConfig::DnsConfig(const DnsConfig& other) = default; - -DnsConfig::~DnsConfig() = default; - -bool DnsConfig::Equals(const DnsConfig& d) const { - return EqualsIgnoreHosts(d) && (hosts == d.hosts); -} - -bool DnsConfig::EqualsIgnoreHosts(const DnsConfig& d) const { - return (nameservers == d.nameservers) && - (search == d.search) && - (unhandled_options == d.unhandled_options) && - (append_to_multi_label_name == d.append_to_multi_label_name) && - (ndots == d.ndots) && - (timeout == d.timeout) && - (attempts == d.attempts) && - (rotate == d.rotate) && - (use_local_ipv6 == d.use_local_ipv6); -} - -void DnsConfig::CopyIgnoreHosts(const DnsConfig& d) { - nameservers = d.nameservers; - search = d.search; - unhandled_options = d.unhandled_options; - append_to_multi_label_name = d.append_to_multi_label_name; - ndots = d.ndots; - timeout = d.timeout; - attempts = d.attempts; - rotate = d.rotate; - use_local_ipv6 = d.use_local_ipv6; -} - -std::unique_ptr<base::Value> DnsConfig::ToValue() const { - auto dict = std::make_unique<base::DictionaryValue>(); - - auto list = std::make_unique<base::ListValue>(); - for (size_t i = 0; i < nameservers.size(); ++i) - list->AppendString(nameservers[i].ToString()); - dict->Set("nameservers", std::move(list)); - - list = std::make_unique<base::ListValue>(); - for (size_t i = 0; i < search.size(); ++i) - list->AppendString(search[i]); - dict->Set("search", std::move(list)); - - dict->SetBoolean("unhandled_options", unhandled_options); - dict->SetBoolean("append_to_multi_label_name", append_to_multi_label_name); - dict->SetInteger("ndots", ndots); - dict->SetDouble("timeout", timeout.InSecondsF()); - dict->SetInteger("attempts", attempts); - dict->SetBoolean("rotate", rotate); - dict->SetBoolean("use_local_ipv6", use_local_ipv6); - dict->SetInteger("num_hosts", hosts.size()); - list = std::make_unique<base::ListValue>(); - for (auto& server : dns_over_https_servers) { - base::Value val(base::Value::Type::DICTIONARY); - base::DictionaryValue* dict; - val.GetAsDictionary(&dict); - dict->SetString("server_template", server.server_template); - dict->SetBoolean("use_post", server.use_post); - list->GetList().push_back(std::move(val)); - } - dict->Set("doh_servers", std::move(list)); - - return std::move(dict); -} - -DnsConfig::DnsOverHttpsServerConfig::DnsOverHttpsServerConfig( - const std::string& server_template, - bool use_post) - : server_template(server_template), use_post(use_post) {} - DnsConfigService::DnsConfigService() : watch_failed_(false), have_config_(false), diff --git a/chromium/net/dns/dns_config_service.h b/chromium/net/dns/dns_config_service.h index 9b1cb0b7a22..88463146169 100644 --- a/chromium/net/dns/dns_config_service.h +++ b/chromium/net/dns/dns_config_service.h @@ -7,98 +7,18 @@ #include <map> #include <memory> -#include <string> -#include <vector> #include "base/macros.h" #include "base/threading/thread_checker.h" #include "base/time/time.h" #include "base/timer/timer.h" -// Needed on shared build with MSVS2010 to avoid multiple definitions of -// std::vector<IPEndPoint>. -#include "net/base/address_list.h" -#include "net/base/ip_endpoint.h" // win requires size of IPEndPoint #include "net/base/net_export.h" +#include "net/dns/dns_config.h" #include "net/dns/dns_hosts.h" #include "url/gurl.h" -namespace base { -class Value; -} - namespace net { -// Default to 1 second timeout (before exponential backoff). -const int64_t kDnsDefaultTimeoutMs = 1000; - -// DnsConfig stores configuration of the system resolver. -struct NET_EXPORT DnsConfig { - DnsConfig(); - DnsConfig(const DnsConfig& other); - ~DnsConfig(); - - bool Equals(const DnsConfig& d) const; - - bool EqualsIgnoreHosts(const DnsConfig& d) const; - - void CopyIgnoreHosts(const DnsConfig& src); - - // Returns a Value representation of |this|. For performance reasons, the - // Value only contains the number of hosts rather than the full list. - std::unique_ptr<base::Value> ToValue() const; - - bool IsValid() const { - return !nameservers.empty(); - } - - struct NET_EXPORT DnsOverHttpsServerConfig { - DnsOverHttpsServerConfig(const std::string& server_template, bool use_post); - - std::string server_template; - bool use_post; - }; - - // List of name server addresses. - std::vector<IPEndPoint> nameservers; - // Suffix search list; used on first lookup when number of dots in given name - // is less than |ndots|. - std::vector<std::string> search; - - DnsHosts hosts; - - // True if there are options set in the system configuration that are not yet - // supported by DnsClient. - bool unhandled_options; - - // AppendToMultiLabelName: is suffix search performed for multi-label names? - // True, except on Windows where it can be configured. - bool append_to_multi_label_name; - - // Indicates that source port randomization is required. This uses additional - // resources on some platforms. - bool randomize_ports; - - // Resolver options; see man resolv.conf. - - // Minimum number of dots before global resolution precedes |search|. - int ndots; - // Time between retransmissions, see res_state.retrans. - base::TimeDelta timeout; - // Maximum number of attempts, see res_state.retry. - int attempts; - // Round robin entries in |nameservers| for subsequent requests. - bool rotate; - - // Indicates system configuration uses local IPv6 connectivity, e.g., - // DirectAccess. This is exposed for HostResolver to skip IPv6 probes, - // as it may cause them to return incorrect results. - bool use_local_ipv6; - - // List of servers to query over HTTPS, queried in order - // (https://tools.ietf.org/id/draft-ietf-doh-dns-over-https-12.txt). - std::vector<DnsOverHttpsServerConfig> dns_over_https_servers; -}; - // Service for reading system DNS settings, on demand or when signalled by // internal watchers and NetworkChangeNotifier. class NET_EXPORT_PRIVATE DnsConfigService { diff --git a/chromium/net/dns/dns_config_service_posix.cc b/chromium/net/dns/dns_config_service_posix.cc index 06e9d7969c0..1b947c5b635 100644 --- a/chromium/net/dns/dns_config_service_posix.cc +++ b/chromium/net/dns/dns_config_service_posix.cc @@ -22,6 +22,7 @@ #include "build/build_config.h" #include "net/base/ip_address.h" #include "net/base/ip_endpoint.h" +#include "net/dns/dns_config.h" #include "net/dns/dns_hosts.h" #include "net/dns/dns_protocol.h" #include "net/dns/notify_watcher_mac.h" @@ -186,7 +187,7 @@ ConfigParsePosixResult ReadDnsConfig(DnsConfig* dns_config) { } #endif // defined(OS_MACOSX) && !defined(OS_IOS) // Override timeout value to match default setting on Windows. - dns_config->timeout = base::TimeDelta::FromMilliseconds(kDnsDefaultTimeoutMs); + dns_config->timeout = kDnsDefaultTimeout; return result; #else // defined(OS_ANDROID) dns_config->nameservers.clear(); diff --git a/chromium/net/dns/dns_config_service_posix.h b/chromium/net/dns/dns_config_service_posix.h index d27c65c98a5..81888cf8305 100644 --- a/chromium/net/dns/dns_config_service_posix.h +++ b/chromium/net/dns/dns_config_service_posix.h @@ -19,6 +19,7 @@ #include "net/dns/dns_config_service.h" namespace net { +struct DnsConfig; // Use DnsConfigService::CreateSystemService to use it outside of tests. namespace internal { diff --git a/chromium/net/dns/dns_config_service_posix_unittest.cc b/chromium/net/dns/dns_config_service_posix_unittest.cc index 241663f0543..82af9052f7c 100644 --- a/chromium/net/dns/dns_config_service_posix_unittest.cc +++ b/chromium/net/dns/dns_config_service_posix_unittest.cc @@ -14,6 +14,7 @@ #include "base/test/test_timeouts.h" #include "base/threading/platform_thread.h" #include "net/base/ip_address.h" +#include "net/dns/dns_config.h" #include "net/dns/dns_config_service_posix.h" #include "net/dns/dns_protocol.h" diff --git a/chromium/net/dns/dns_config_service_win_unittest.cc b/chromium/net/dns/dns_config_service_win_unittest.cc index 091c090a918..da062f1cc2f 100644 --- a/chromium/net/dns/dns_config_service_win_unittest.cc +++ b/chromium/net/dns/dns_config_service_win_unittest.cc @@ -7,6 +7,7 @@ #include "base/logging.h" #include "base/memory/free_deleter.h" #include "net/base/ip_address.h" +#include "net/base/ip_endpoint.h" #include "net/dns/dns_protocol.h" #include "testing/gtest/include/gtest/gtest.h" diff --git a/chromium/net/dns/dns_protocol.h b/chromium/net/dns/dns_protocol.h index 149316723a6..b022b91a594 100644 --- a/chromium/net/dns/dns_protocol.h +++ b/chromium/net/dns/dns_protocol.h @@ -76,12 +76,12 @@ static const uint16_t kDefaultPortMulticast = 5353; // On-the-wire header. All uint16_t are in network order. struct NET_EXPORT Header { - uint16_t id; - uint16_t flags; - uint16_t qdcount; - uint16_t ancount; - uint16_t nscount; - uint16_t arcount; + uint16_t id = 0; + uint16_t flags = 0; + uint16_t qdcount = 0; + uint16_t ancount = 0; + uint16_t nscount = 0; + uint16_t arcount = 0; }; #pragma pack(pop) @@ -141,6 +141,7 @@ static const uint8_t kRcodeREFUSED = 5; // // https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-12 static const uint16_t kFlagResponse = 0x8000; +static const uint16_t kFlagAA = 0x400; // Authoritative Answer - response flag. static const uint16_t kFlagRD = 0x100; // Recursion Desired - query flag. static const uint16_t kFlagTC = 0x200; // Truncated - server flag. diff --git a/chromium/net/dns/dns_query.cc b/chromium/net/dns/dns_query.cc index 0229d1fbbec..a5e9726b437 100644 --- a/chromium/net/dns/dns_query.cc +++ b/chromium/net/dns/dns_query.cc @@ -75,12 +75,52 @@ DnsQuery::DnsQuery(uint16_t id, } } +DnsQuery::DnsQuery(scoped_refptr<IOBufferWithSize> buffer) + : io_buffer_(std::move(buffer)) {} + DnsQuery::~DnsQuery() = default; std::unique_ptr<DnsQuery> DnsQuery::CloneWithNewId(uint16_t id) const { return base::WrapUnique(new DnsQuery(*this, id)); } +bool DnsQuery::Parse() { + if (io_buffer_ == nullptr || io_buffer_->data() == nullptr) { + return false; + } + // We should only parse the query once if the query is constructed from a raw + // buffer. If we have constructed the query from data or the query is already + // parsed after constructed from a raw buffer, |header_| is not null. + DCHECK(header_ == nullptr); + base::BigEndianReader reader(io_buffer_->data(), io_buffer_->size()); + dns_protocol::Header header; + if (!ReadHeader(&reader, &header)) { + return false; + } + if (header.flags & dns_protocol::kFlagResponse) { + return false; + } + if (header.qdcount > 1) { + VLOG(1) << "Not supporting parsing a DNS query with multiple questions."; + return false; + } + std::string qname; + if (!ReadName(&reader, &qname)) { + return false; + } + uint16_t qtype; + uint16_t qclass; + if (!reader.ReadU16(&qtype) || !reader.ReadU16(&qclass) || + qclass != dns_protocol::kClassIN) { + return false; + } + // |io_buffer_| now contains the raw packet of a valid DNS query, we just + // need to properly initialize |qname_size_| and |header_|. + qname_size_ = qname.size(); + header_ = reinterpret_cast<dns_protocol::Header*>(io_buffer_->data()); + return true; +} + uint16_t DnsQuery::id() const { return base::NetToHost16(header_->id); } @@ -113,4 +153,35 @@ DnsQuery::DnsQuery(const DnsQuery& orig, uint16_t id) { header_->id = base::HostToNet16(id); } +bool DnsQuery::ReadHeader(base::BigEndianReader* reader, + dns_protocol::Header* header) { + return ( + reader->ReadU16(&header->id) && reader->ReadU16(&header->flags) && + reader->ReadU16(&header->qdcount) && reader->ReadU16(&header->ancount) && + reader->ReadU16(&header->nscount) && reader->ReadU16(&header->arcount)); +} + +bool DnsQuery::ReadName(base::BigEndianReader* reader, std::string* out) { + DCHECK(out != nullptr); + out->clear(); + out->reserve(dns_protocol::kMaxNameLength); + uint8_t label_length; + if (!reader->ReadU8(&label_length)) { + return false; + } + out->append(reinterpret_cast<char*>(&label_length), 1); + while (label_length) { + base::StringPiece label; + if (!reader->ReadPiece(&label, label_length)) { + return false; + } + out->append(label.data(), label.size()); + if (!reader->ReadU8(&label_length)) { + return false; + } + out->append(reinterpret_cast<char*>(&label_length), 1); + } + return true; +} + } // namespace net diff --git a/chromium/net/dns/dns_query.h b/chromium/net/dns/dns_query.h index b68772e71a7..8e77ac57df8 100644 --- a/chromium/net/dns/dns_query.h +++ b/chromium/net/dns/dns_query.h @@ -15,13 +15,17 @@ #include "base/strings/string_piece.h" #include "net/base/net_export.h" +namespace base { +class BigEndianReader; +} // namespace base + namespace net { class OptRecordRdata; namespace dns_protocol { struct Header; -} +} // namespace dns_protocol class IOBufferWithSize; @@ -36,11 +40,22 @@ class NET_EXPORT_PRIVATE DnsQuery { const base::StringPiece& qname, uint16_t qtype, const OptRecordRdata* opt_rdata = nullptr); + + // Constructs an empty query from a raw packet in |buffer|. If the raw packet + // represents a valid DNS query in the wire format (RFC 1035), Parse() will + // populate the empty query. + DnsQuery(scoped_refptr<IOBufferWithSize> buffer); + ~DnsQuery(); // Clones |this| verbatim, with ID field of the header set to |id|. std::unique_ptr<DnsQuery> CloneWithNewId(uint16_t id) const; + // Returns true and populates the query if the internally stored raw packet + // can be parsed. This should only be called when DnsQuery is constructed from + // the raw buffer. + bool Parse(); + // DnsQuery field accessors. uint16_t id() const; base::StringPiece qname() const; @@ -50,7 +65,14 @@ class NET_EXPORT_PRIVATE DnsQuery { // response. base::StringPiece question() const; - // IOBuffer accessor to be used for writing out the query. + // Returns the size of the question section. + size_t question_size() const { + // QNAME + QTYPE + QCLASS + return qname_size_ + sizeof(uint16_t) + sizeof(uint16_t); + } + + // IOBuffer accessor to be used for writing out the query. The buffer has + // the same byte layout as the DNS query wire format. IOBufferWithSize* io_buffer() const { return io_buffer_.get(); } void set_flags(uint16_t flags); @@ -58,21 +80,21 @@ class NET_EXPORT_PRIVATE DnsQuery { private: DnsQuery(const DnsQuery& orig, uint16_t id); - // Returns the size of the question section. - size_t question_size() const { - // QNAME + QTYPE + QCLASS - return qname_size_ + sizeof(uint16_t) + sizeof(uint16_t); - } + bool ReadHeader(base::BigEndianReader* reader, dns_protocol::Header* out); + // After read, |out| is in the DNS format, e.g. + // "\x03""www""\x08""chromium""\x03""com""\x00". Use DNSDomainToString to + // convert to the dotted format "www.chromium.com" with no trailing dot. + bool ReadName(base::BigEndianReader* reader, std::string* out); // Size of the DNS name (*NOT* hostname) we are trying to resolve; used // to calculate offsets. - size_t qname_size_; + size_t qname_size_ = 0; // Contains query bytes to be consumed by higher level Write() call. scoped_refptr<IOBufferWithSize> io_buffer_; // Pointer to the dns header section. - dns_protocol::Header* header_; + dns_protocol::Header* header_ = nullptr; DISALLOW_COPY_AND_ASSIGN(DnsQuery); }; diff --git a/chromium/net/dns/dns_query_parse_fuzzer.cc b/chromium/net/dns/dns_query_parse_fuzzer.cc new file mode 100644 index 00000000000..f93621844c3 --- /dev/null +++ b/chromium/net/dns/dns_query_parse_fuzzer.cc @@ -0,0 +1,20 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <stddef.h> +#include <stdint.h> + +#include <memory> + +#include "net/base/io_buffer.h" +#include "net/dns/dns_query.h" + +// Entry point for LibFuzzer. +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + auto packet = base::MakeRefCounted<net::IOBufferWithSize>(size); + memcpy(packet->data(), data, size); + auto out = std::make_unique<net::DnsQuery>(packet); + out->Parse(); + return 0; +} diff --git a/chromium/net/dns/dns_query_unittest.cc b/chromium/net/dns/dns_query_unittest.cc index 7ee93008e88..0638bb184c0 100644 --- a/chromium/net/dns/dns_query_unittest.cc +++ b/chromium/net/dns/dns_query_unittest.cc @@ -4,6 +4,7 @@ #include "net/dns/dns_query.h" +#include "base/stl_util.h" #include "net/base/io_buffer.h" #include "net/dns/dns_protocol.h" #include "net/dns/record_rdata.h" @@ -20,9 +21,26 @@ std::tuple<char*, size_t> AsTuple(const IOBufferWithSize* buf) { return std::make_tuple(buf->data(), buf->size()); } +bool ParseAndCreateDnsQueryFromRawPacket(const uint8_t* data, + size_t length, + std::unique_ptr<DnsQuery>* out) { + auto packet = base::MakeRefCounted<IOBufferWithSize>(length); + memcpy(packet->data(), data, length); + out->reset(new DnsQuery(packet)); + return (*out)->Parse(); +} + +// This includes \0 at the end. +const char kQNameData[] = + "\x03" + "www" + "\x07" + "example" + "\x03" + "com"; + TEST(DnsQueryTest, Constructor) { // This includes \0 at the end. - const char qname_data[] = "\x03""www""\x07""example""\x03""com"; const uint8_t query_data[] = { // Header 0xbe, 0xef, 0x01, 0x00, // Flags -- set RD (recursion desired) bit. @@ -38,7 +56,7 @@ TEST(DnsQueryTest, Constructor) { 0x00, 0x01, // QCLASS: IN class. }; - base::StringPiece qname(qname_data, sizeof(qname_data)); + base::StringPiece qname(kQNameData, sizeof(kQNameData)); DnsQuery q1(0xbeef, qname, dns_protocol::kTypeA); EXPECT_EQ(dns_protocol::kTypeA, q1.qtype()); EXPECT_THAT(AsTuple(q1.io_buffer()), ElementsAreArray(query_data)); @@ -50,9 +68,7 @@ TEST(DnsQueryTest, Constructor) { } TEST(DnsQueryTest, Clone) { - // This includes \0 at the end. - const char qname_data[] = "\x03""www""\x07""example""\x03""com"; - base::StringPiece qname(qname_data, sizeof(qname_data)); + base::StringPiece qname(kQNameData, sizeof(kQNameData)); DnsQuery q1(0, qname, dns_protocol::kTypeA); EXPECT_EQ(0, q1.id()); @@ -64,14 +80,6 @@ TEST(DnsQueryTest, Clone) { } TEST(DnsQueryTest, EDNS0) { - // This includes \0 at the end. - const char qname_data[] = - "\x03" - "www" - "\x07" - "example" - "\x03" - "com"; const uint8_t query_data[] = { // Header 0xbe, 0xef, 0x01, 0x00, // Flags -- set RD (recursion desired) bit. @@ -96,7 +104,7 @@ TEST(DnsQueryTest, EDNS0) { 0xDE, 0xAD, 0xBE, 0xEF // OPT data }; - base::StringPiece qname(qname_data, sizeof(qname_data)); + base::StringPiece qname(kQNameData, sizeof(kQNameData)); OptRecordRdata opt_rdata; opt_rdata.AddOpt(OptRecordRdata::Opt(255, "\xde\xad\xbe\xef")); DnsQuery q1(0xbeef, qname, dns_protocol::kTypeA, &opt_rdata); @@ -109,6 +117,126 @@ TEST(DnsQueryTest, EDNS0) { EXPECT_EQ(question, q1.question()); } +TEST(DnsQueryParseTest, SingleQuestionForTypeARecord) { + const uint8_t query_data[] = { + 0x12, 0x34, // ID + 0x00, 0x00, // flags + 0x00, 0x01, // number of questions + 0x00, 0x00, // number of answer rr + 0x00, 0x00, // number of name server rr + 0x00, 0x00, // number of additional rr + 0x03, 'w', 'w', 'w', 0x07, 'e', 'x', 'a', + 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', + 0x00, // null label + 0x00, 0x01, // type A Record + 0x00, 0x01, // class IN + }; + std::unique_ptr<DnsQuery> query; + EXPECT_TRUE(ParseAndCreateDnsQueryFromRawPacket(query_data, + sizeof(query_data), &query)); + EXPECT_EQ(0x1234, query->id()); + base::StringPiece qname(kQNameData, sizeof(kQNameData)); + EXPECT_EQ(qname, query->qname()); + EXPECT_EQ(dns_protocol::kTypeA, query->qtype()); +} + +TEST(DnsQueryParseTest, SingleQuestionForTypeAAAARecord) { + const uint8_t query_data[] = { + 0x12, 0x34, // ID + 0x00, 0x00, // flags + 0x00, 0x01, // number of questions + 0x00, 0x00, // number of answer rr + 0x00, 0x00, // number of name server rr + 0x00, 0x00, // number of additional rr + 0x03, 'w', 'w', 'w', 0x07, 'e', 'x', 'a', + 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', + 0x00, // null label + 0x00, 0x1c, // type AAAA Record + 0x00, 0x01, // class IN + }; + std::unique_ptr<DnsQuery> query; + EXPECT_TRUE(ParseAndCreateDnsQueryFromRawPacket(query_data, + sizeof(query_data), &query)); + EXPECT_EQ(0x1234, query->id()); + base::StringPiece qname(kQNameData, sizeof(kQNameData)); + EXPECT_EQ(qname, query->qname()); + EXPECT_EQ(dns_protocol::kTypeAAAA, query->qtype()); +} + +const uint8_t kQueryTruncatedQuestion[] = { + 0x12, 0x34, // ID + 0x00, 0x00, // flags + 0x00, 0x02, // number of questions + 0x00, 0x00, // number of answer rr + 0x00, 0x00, // number of name server rr + 0x00, 0x00, // number of additional rr + 0x03, 'w', 'w', 'w', 0x07, 'e', 'x', 'a', + 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', + 0x00, // null label + 0x00, 0x01, // type A Record + 0x00, // class IN, truncated +}; + +const uint8_t kQueryTwoQuestions[] = { + 0x12, 0x34, // ID + 0x00, 0x00, // flags + 0x00, 0x02, // number of questions + 0x00, 0x00, // number of answer rr + 0x00, 0x00, // number of name server rr + 0x00, 0x00, // number of additional rr + 0x03, 'w', 'w', 'w', 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', + 0x03, 'c', 'o', 'm', + 0x00, // null label + 0x00, 0x01, // type A Record + 0x00, 0x01, // class IN + 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'o', 'r', 'g', + 0x00, // null label + 0x00, 0x1c, // type AAAA Record + 0x00, 0x01, // class IN +}; + +const uint8_t kQueryInvalidDNSDomainName1[] = { + 0x12, 0x34, // ID + 0x00, 0x00, // flags + 0x00, 0x01, // number of questions + 0x00, 0x00, // number of answer rr + 0x00, 0x00, // number of name server rr + 0x00, 0x00, // number of additional rr + 0x02, 'w', 'w', 'w', // wrong label length + 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', + 0x00, // null label + 0x00, 0x01, // type A Record + 0x00, 0x01, // class IN +}; + +const uint8_t kQueryInvalidDNSDomainName2[] = { + 0x12, 0x34, // ID + 0x00, 0x00, // flags + 0x00, 0x01, // number of questions + 0x00, 0x00, // number of answer rr + 0x00, 0x00, // number of name server rr + 0x00, 0x00, // number of additional rr + 0xc0, 0x02, // illegal name pointer + 0x00, 0x01, // type A Record + 0x00, 0x01, // class IN +}; + +TEST(DnsQueryParseTest, FailsInvalidQueries) { + const struct TestCase { + const uint8_t* data; + size_t size; + } testcases[] = { + {kQueryTruncatedQuestion, base::size(kQueryTruncatedQuestion)}, + {kQueryTwoQuestions, base::size(kQueryTwoQuestions)}, + {kQueryInvalidDNSDomainName1, base::size(kQueryInvalidDNSDomainName1)}, + {kQueryInvalidDNSDomainName2, base::size(kQueryInvalidDNSDomainName2)}}; + std::unique_ptr<DnsQuery> query; + for (const auto& testcase : testcases) { + EXPECT_FALSE(ParseAndCreateDnsQueryFromRawPacket(testcase.data, + testcase.size, &query)); + } +} + } // namespace } // namespace net diff --git a/chromium/net/dns/dns_response.cc b/chromium/net/dns/dns_response.cc index c1fa3de5619..7177dcafd35 100644 --- a/chromium/net/dns/dns_response.cc +++ b/chromium/net/dns/dns_response.cc @@ -5,6 +5,8 @@ #include "net/dns/dns_response.h" #include <limits> +#include <numeric> +#include <vector> #include "base/big_endian.h" #include "base/strings/string_util.h" @@ -16,6 +18,7 @@ #include "net/dns/dns_protocol.h" #include "net/dns/dns_query.h" #include "net/dns/dns_util.h" +#include "net/dns/record_rdata.h" namespace net { @@ -25,10 +28,16 @@ const size_t kHeaderSize = sizeof(dns_protocol::Header); const uint8_t kRcodeMask = 0xf; +// RFC 1035, Section 4.1.3. +// TYPE (2 bytes) + CLASS (2 bytes) + TTL (4 bytes) + RDLENGTH (2 bytes) +const size_t kResourceRecordSizeInBytesWithoutNameAndRData = 10; + } // namespace DnsResourceRecord::DnsResourceRecord() = default; +DnsResourceRecord::DnsResourceRecord(const DnsResourceRecord& other) = default; + DnsResourceRecord::~DnsResourceRecord() = default; DnsRecordParser::DnsRecordParser() : packet_(NULL), length_(0), cur_(0) { @@ -150,12 +159,87 @@ bool DnsRecordParser::SkipQuestion() { return true; } +DnsResponse::DnsResponse( + uint16_t id, + bool is_authoritative, + const std::vector<DnsResourceRecord>& answers, + const std::vector<DnsResourceRecord>& additional_records, + const base::Optional<DnsQuery>& query) { + bool has_query = query.has_value(); + dns_protocol::Header header; + header.id = id; + bool success = true; + if (has_query) { + success &= (id == query.value().id()); + DCHECK(success); + // DnsQuery only supports a single question. + header.qdcount = 1; + } + header.flags |= dns_protocol::kFlagResponse; + if (is_authoritative) { + header.flags |= dns_protocol::kFlagAA; + } + header.ancount = answers.size(); + header.arcount = additional_records.size(); + + // Response starts with the header and the question section (if any). + size_t response_size = has_query + ? sizeof(header) + query.value().question_size() + : sizeof(header); + // Add the size of all answers and additional records. + auto do_accumulation = [](size_t cur_size, const DnsResourceRecord& answer) { + bool has_final_dot = answer.name.back() == '.'; + // Depending on if answer.name in the dotted format has the final dot + // for the root domain or not, the corresponding DNS domain name format + // to be written to rdata is 1 byte (with dot) or 2 bytes larger in + // size. See RFC 1035, Section 3.1 and DNSDomainFromDot. + return cur_size + answer.name.size() + (has_final_dot ? 1 : 2) + + kResourceRecordSizeInBytesWithoutNameAndRData + answer.rdata.size(); + }; + response_size = std::accumulate(answers.begin(), answers.end(), response_size, + do_accumulation); + + response_size = + std::accumulate(additional_records.begin(), additional_records.end(), + response_size, do_accumulation); + + io_buffer_ = base::MakeRefCounted<IOBuffer>(response_size); + io_buffer_size_ = response_size; + base::BigEndianWriter writer(io_buffer_->data(), io_buffer_size_); + success &= WriteHeader(&writer, header); + DCHECK(success); + if (has_query) { + success &= WriteQuestion(&writer, query.value()); + DCHECK(success); + } + // Start the Answer section. + for (const auto& answer : answers) { + success &= WriteAnswer(&writer, answer, query); + DCHECK(success); + } + // Start the Additional section. + for (const auto& record : additional_records) { + success &= WriteRecord(&writer, record); + DCHECK(success); + } + if (!success) { + io_buffer_.reset(); + io_buffer_size_ = 0; + return; + } + if (has_query) { + InitParse(io_buffer_size_, query.value()); + } else { + InitParseWithoutQuery(io_buffer_size_); + } +} + DnsResponse::DnsResponse() : io_buffer_(base::MakeRefCounted<IOBuffer>(dns_protocol::kMaxUDPSize + 1)), io_buffer_size_(dns_protocol::kMaxUDPSize + 1) {} -DnsResponse::DnsResponse(IOBuffer* buffer, size_t size) - : io_buffer_(buffer), io_buffer_size_(size) {} +DnsResponse::DnsResponse(scoped_refptr<IOBuffer> buffer, size_t size) + : io_buffer_(std::move(buffer)), io_buffer_size_(size) {} DnsResponse::DnsResponse(size_t length) : io_buffer_(base::MakeRefCounted<IOBuffer>(length)), @@ -174,7 +258,7 @@ DnsResponse::~DnsResponse() = default; bool DnsResponse::InitParse(size_t nbytes, const DnsQuery& query) { // Response includes query, it should be at least that size. if (nbytes < static_cast<size_t>(query.io_buffer()->size()) || - nbytes >= io_buffer_size_) { + nbytes > io_buffer_size_) { return false; } @@ -200,7 +284,7 @@ bool DnsResponse::InitParse(size_t nbytes, const DnsQuery& query) { } bool DnsResponse::InitParseWithoutQuery(size_t nbytes) { - if (nbytes < kHeaderSize || nbytes >= io_buffer_size_) { + if (nbytes < kHeaderSize || nbytes > io_buffer_size_) { return false; } @@ -350,4 +434,45 @@ DnsResponse::Result DnsResponse::ParseToAddressList( return DNS_PARSE_OK; } +bool DnsResponse::WriteHeader(base::BigEndianWriter* writer, + const dns_protocol::Header& header) { + return writer->WriteU16(header.id) && writer->WriteU16(header.flags) && + writer->WriteU16(header.qdcount) && writer->WriteU16(header.ancount) && + writer->WriteU16(header.nscount) && writer->WriteU16(header.arcount); +} + +bool DnsResponse::WriteQuestion(base::BigEndianWriter* writer, + const DnsQuery& query) { + const base::StringPiece& question = query.question(); + return writer->WriteBytes(question.data(), question.size()); +} + +bool DnsResponse::WriteRecord(base::BigEndianWriter* writer, + const DnsResourceRecord& record) { + if (!RecordRdata::HasValidSize(record.rdata, record.type)) { + VLOG(1) << "Invalid RDATA size for a record."; + return false; + } + std::string domain_name; + if (!DNSDomainFromDot(record.name, &domain_name)) { + VLOG(1) << "Invalid dotted name."; + return false; + } + return writer->WriteBytes(domain_name.data(), domain_name.size()) && + writer->WriteU16(record.type) && writer->WriteU16(record.klass) && + writer->WriteU32(record.ttl) && + writer->WriteU16(record.rdata.size()) && + writer->WriteBytes(record.rdata.data(), record.rdata.size()); +} + +bool DnsResponse::WriteAnswer(base::BigEndianWriter* writer, + const DnsResourceRecord& answer, + const base::Optional<DnsQuery>& query) { + if (query.has_value() && answer.type != query.value().qtype()) { + VLOG(1) << "Mismatched answer resource record type and qtype."; + return false; + } + return WriteRecord(writer, answer); +} + } // namespace net diff --git a/chromium/net/dns/dns_response.h b/chromium/net/dns/dns_response.h index ee533c939c8..c7962955178 100644 --- a/chromium/net/dns/dns_response.h +++ b/chromium/net/dns/dns_response.h @@ -12,10 +12,15 @@ #include "base/macros.h" #include "base/memory/ref_counted.h" +#include "base/optional.h" #include "base/strings/string_piece.h" #include "base/time/time.h" #include "net/base/net_export.h" +namespace base { +class BigEndianWriter; +} // namespace base + namespace net { class AddressList; @@ -24,18 +29,19 @@ class IOBuffer; namespace dns_protocol { struct Header; -} +} // namespace dns_protocol // Structure representing a Resource Record as specified in RFC 1035, Section // 4.1.3. struct NET_EXPORT_PRIVATE DnsResourceRecord { DnsResourceRecord(); + explicit DnsResourceRecord(const DnsResourceRecord& other); ~DnsResourceRecord(); std::string name; // in dotted form - uint16_t type; - uint16_t klass; - uint32_t ttl; + uint16_t type = 0; + uint16_t klass = 0; + uint32_t ttl = 0; base::StringPiece rdata; // points to the original response buffer }; @@ -105,11 +111,19 @@ class NET_EXPORT_PRIVATE DnsResponse { // largest possible response, to detect malformed responses. DnsResponse(); + // Constructs a response message from |answers| and the originating |query|. + // After the successful construction, and the parser is also initialized. + DnsResponse(uint16_t id, + bool is_authoritative, + const std::vector<DnsResourceRecord>& answers, + const std::vector<DnsResourceRecord>& additional_records, + const base::Optional<DnsQuery>& query); + // Constructs a response buffer of given length. Used for TCP transactions. explicit DnsResponse(size_t length); - // Constructs a response taking ownership of the passed buffer. - DnsResponse(IOBuffer* buffer, size_t size); + // Constructs a response from the passed buffer. + DnsResponse(scoped_refptr<IOBuffer> buffer, size_t size); // Constructs a response from |data|. Used for testing purposes only! DnsResponse(const void* data, size_t length, size_t answer_offset); @@ -124,14 +138,17 @@ class NET_EXPORT_PRIVATE DnsResponse { size_t io_buffer_size() const { return io_buffer_size_; } // Assuming the internal buffer holds |nbytes| bytes, returns true iff the - // packet matches the |query| id and question. + // packet matches the |query| id and question. This should only be called if + // the response is constructed from a raw buffer. bool InitParse(size_t nbytes, const DnsQuery& query); // Assuming the internal buffer holds |nbytes| bytes, initialize the parser - // without matching it against an existing query. + // without matching it against an existing query. This should only be called + // if the response is constructed from a raw buffer. bool InitParseWithoutQuery(size_t nbytes); - // Returns true if response is valid, that is, after successful InitParse. + // Returns true if response is valid, that is, after successful InitParse, or + // after successful construction of a new response from data. bool IsValid() const; // All of the methods below are valid only if the response is valid. @@ -160,6 +177,15 @@ class NET_EXPORT_PRIVATE DnsResponse { Result ParseToAddressList(AddressList* addr_list, base::TimeDelta* ttl) const; private: + bool WriteHeader(base::BigEndianWriter* writer, + const dns_protocol::Header& header); + bool WriteQuestion(base::BigEndianWriter* writer, const DnsQuery& query); + bool WriteRecord(base::BigEndianWriter* wirter, + const DnsResourceRecord& record); + bool WriteAnswer(base::BigEndianWriter* wirter, + const DnsResourceRecord& answer, + const base::Optional<DnsQuery>& query); + // Convenience for header access. const dns_protocol::Header* header() const; diff --git a/chromium/net/dns/dns_response_fuzzer.cc b/chromium/net/dns/dns_response_fuzzer.cc new file mode 100644 index 00000000000..c53e179a5ba --- /dev/null +++ b/chromium/net/dns/dns_response_fuzzer.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <stddef.h> +#include <stdint.h> + +#include "base/strings/string_number_conversions.h" +#include "base/strings/string_piece.h" +#include "net/base/io_buffer.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/dns_query.h" +#include "net/dns/dns_response.h" + +// Entry point for LibFuzzer. +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + auto packet = base::MakeRefCounted<net::IOBufferWithSize>(size); + memcpy(packet->data(), data, size); + base::Optional<net::DnsQuery> query; + query.emplace(packet); + if (!query->Parse()) { + return 0; + } + net::DnsResponse response(query->id(), true /* is_authoritative */, + {} /* answers */, {} /* additional records */, + query); + std::string out = + base::HexEncode(response.io_buffer()->data(), response.io_buffer_size()); + return 0; +} diff --git a/chromium/net/dns/dns_response_unittest.cc b/chromium/net/dns/dns_response_unittest.cc index af3f1d9c66d..7ec9410d6d4 100644 --- a/chromium/net/dns/dns_response_unittest.cc +++ b/chromium/net/dns/dns_response_unittest.cc @@ -4,12 +4,16 @@ #include "net/dns/dns_response.h" +#include "base/big_endian.h" +#include "base/optional.h" #include "base/time/time.h" #include "net/base/address_list.h" #include "net/base/io_buffer.h" #include "net/dns/dns_protocol.h" #include "net/dns/dns_query.h" #include "net/dns/dns_test_util.h" +#include "net/dns/dns_util.h" +#include "net/dns/record_rdata.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { @@ -569,6 +573,373 @@ TEST(DnsResponseTest, ParseToAddressListFail) { } } +TEST(DnsResponseWriteTest, SingleARecordAnswer) { + const char response_data[] = { + 0x12, 0x34, // ID + 0x84, 0x00, // flags, response with authoritative answer + 0x00, 0x00, // number of questions + 0x00, 0x01, // number of answer rr + 0x00, 0x00, // number of name server rr + 0x00, 0x00, // number of additional rr + 0x03, 'w', 'w', 'w', 0x07, 'e', 'x', 'a', + 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', + 0x00, // null label + 0x00, 0x01, // type A Record + 0x00, 0x01, // class IN + 0x00, 0x00, 0x00, 0x78, // TTL, 120 seconds + 0x00, 0x04, // rdlength, 32 bits + 0xc0, 0xa8, 0x00, 0x01, // 192.168.0.1 + }; + net::DnsResourceRecord answer; + answer.name = "www.example.com"; + answer.type = dns_protocol::kTypeA; + answer.klass = dns_protocol::kClassIN; + answer.ttl = 120; // 120 seconds. + answer.rdata = base::StringPiece("\xc0\xa8\x00\x01", 4); + std::vector<DnsResourceRecord> answers(1, answer); + DnsResponse response(0x1234 /* response_id */, true /* is_authoritative*/, + answers, {} /* additional records */, base::nullopt); + ASSERT_NE(nullptr, response.io_buffer()); + EXPECT_TRUE(response.IsValid()); + std::string expected_response(response_data, sizeof(response_data)); + std::string actual_response(response.io_buffer()->data(), + response.io_buffer_size()); + EXPECT_EQ(expected_response, actual_response); +} + +TEST(DnsResponseWriteTest, SingleARecordAnswerWithFinalDotInName) { + const char response_data[] = { + 0x12, 0x34, // ID + 0x84, 0x00, // flags, response with authoritative answer + 0x00, 0x00, // number of questions + 0x00, 0x01, // number of answer rr + 0x00, 0x00, // number of name server rr + 0x00, 0x00, // number of additional rr + 0x03, 'w', 'w', 'w', 0x07, 'e', 'x', 'a', + 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', + 0x00, // null label + 0x00, 0x01, // type A Record + 0x00, 0x01, // class IN + 0x00, 0x00, 0x00, 0x78, // TTL, 120 seconds + 0x00, 0x04, // rdlength, 32 bits + 0xc0, 0xa8, 0x00, 0x01, // 192.168.0.1 + }; + net::DnsResourceRecord answer; + answer.name = "www.example.com."; // FQDN with the final dot. + answer.type = dns_protocol::kTypeA; + answer.klass = dns_protocol::kClassIN; + answer.ttl = 120; // 120 seconds. + answer.rdata = base::StringPiece("\xc0\xa8\x00\x01", 4); + std::vector<DnsResourceRecord> answers(1, answer); + DnsResponse response(0x1234 /* response_id */, true /* is_authoritative*/, + answers, {} /* additional records */, base::nullopt); + ASSERT_NE(nullptr, response.io_buffer()); + EXPECT_TRUE(response.IsValid()); + std::string expected_response(response_data, sizeof(response_data)); + std::string actual_response(response.io_buffer()->data(), + response.io_buffer_size()); + EXPECT_EQ(expected_response, actual_response); +} + +TEST(DnsResponseWriteTest, SingleARecordAnswerWithQuestion) { + const char response_data[] = { + 0x12, 0x34, // ID + 0x84, 0x00, // flags, response with authoritative answer + 0x00, 0x01, // number of questions + 0x00, 0x01, // number of answer rr + 0x00, 0x00, // number of name server rr + 0x00, 0x00, // number of additional rr + 0x03, 'w', 'w', 'w', 0x07, 'e', 'x', 'a', + 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', + 0x00, // null label + 0x00, 0x01, // type A Record + 0x00, 0x01, // class IN + 0x03, 'w', 'w', 'w', 0x07, 'e', 'x', 'a', + 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', + 0x00, // null label + 0x00, 0x01, // type A Record + 0x00, 0x01, // class IN + 0x00, 0x00, 0x00, 0x78, // TTL, 120 seconds + 0x00, 0x04, // rdlength, 32 bits + 0xc0, 0xa8, 0x00, 0x01, // 192.168.0.1 + }; + std::string dotted_name("www.example.com"); + std::string dns_name; + ASSERT_TRUE(DNSDomainFromDot(dotted_name, &dns_name)); + OptRecordRdata opt_rdata; + opt_rdata.AddOpt(OptRecordRdata::Opt(255, "\xde\xad\xbe\xef")); + base::Optional<DnsQuery> query; + query.emplace(0x1234 /* id */, dns_name, dns_protocol::kTypeA, &opt_rdata); + net::DnsResourceRecord answer; + answer.name = dotted_name; + answer.type = dns_protocol::kTypeA; + answer.klass = dns_protocol::kClassIN; + answer.ttl = 120; // 120 seconds. + answer.rdata = base::StringPiece("\xc0\xa8\x00\x01", 4); + std::vector<DnsResourceRecord> answers(1, answer); + DnsResponse response(0x1234 /* id */, true /* is_authoritative*/, answers, + {} /* additional records */, query); + ASSERT_NE(nullptr, response.io_buffer()); + EXPECT_TRUE(response.IsValid()); + std::string expected_response(response_data, sizeof(response_data)); + std::string actual_response(response.io_buffer()->data(), + response.io_buffer_size()); + EXPECT_EQ(expected_response, actual_response); +} + +TEST(DnsResponseWriteTest, + SingleAnswerWithQuestionConstructedFromSizeInflatedQuery) { + const char response_data[] = { + 0x12, 0x34, // ID + 0x84, 0x00, // flags, response with authoritative answer + 0x00, 0x01, // number of questions + 0x00, 0x01, // number of answer rr + 0x00, 0x00, // number of name server rr + 0x00, 0x00, // number of additional rr + 0x03, 'w', 'w', 'w', 0x07, 'e', 'x', 'a', + 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', + 0x00, // null label + 0x00, 0x01, // type A Record + 0x00, 0x01, // class IN + 0x03, 'w', 'w', 'w', 0x07, 'e', 'x', 'a', + 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', + 0x00, // null label + 0x00, 0x01, // type A Record + 0x00, 0x01, // class IN + 0x00, 0x00, 0x00, 0x78, // TTL, 120 seconds + 0x00, 0x04, // rdlength, 32 bits + 0xc0, 0xa8, 0x00, 0x01, // 192.168.0.1 + }; + std::string dotted_name("www.example.com"); + std::string dns_name; + ASSERT_TRUE(DNSDomainFromDot(dotted_name, &dns_name)); + size_t buf_size = + sizeof(dns_protocol::Header) + dns_name.size() + 2 /* qtype */ + + 2 /* qclass */ + + 10 /* extra bytes that inflate the internal buffer of a query */; + auto buf = base::MakeRefCounted<IOBufferWithSize>(buf_size); + memset(buf->data(), 0, buf->size()); + base::BigEndianWriter writer(buf->data(), buf_size); + writer.WriteU16(0x1234); // id + writer.WriteU16(0); // flags, is query + writer.WriteU16(1); // qdcount + writer.WriteU16(0); // ancount + writer.WriteU16(0); // nscount + writer.WriteU16(0); // arcount + writer.WriteBytes(dns_name.data(), dns_name.size()); // qname + writer.WriteU16(dns_protocol::kTypeA); // qtype + writer.WriteU16(dns_protocol::kClassIN); // qclass + // buf contains 10 extra zero bytes. + base::Optional<DnsQuery> query; + query.emplace(buf); + query->Parse(); + net::DnsResourceRecord answer; + answer.name = dotted_name; + answer.type = dns_protocol::kTypeA; + answer.klass = dns_protocol::kClassIN; + answer.ttl = 120; // 120 seconds. + answer.rdata = base::StringPiece("\xc0\xa8\x00\x01", 4); + std::vector<DnsResourceRecord> answers(1, answer); + DnsResponse response(0x1234 /* id */, true /* is_authoritative*/, answers, + {} /* additional records */, query); + ASSERT_NE(nullptr, response.io_buffer()); + EXPECT_TRUE(response.IsValid()); + std::string expected_response(response_data, sizeof(response_data)); + std::string actual_response(response.io_buffer()->data(), + response.io_buffer_size()); + EXPECT_EQ(expected_response, actual_response); +} + +TEST(DnsResponseWriteTest, SingleQuadARecordAnswer) { + const char response_data[] = { + 0x12, 0x34, // ID + 0x84, 0x00, // flags, response with authoritative answer + 0x00, 0x00, // number of questions + 0x00, 0x01, // number of answer rr + 0x00, 0x00, // number of name server rr + 0x00, 0x00, // number of additional rr + 0x03, 'w', 'w', 'w', 0x07, 'e', 'x', 'a', + 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', + 0x00, // null label + 0x00, 0x1c, // type AAAA Record + 0x00, 0x01, // class IN + 0x00, 0x00, 0x00, 0x78, // TTL, 120 seconds + 0x00, 0x10, // rdlength, 128 bits + 0xfd, 0x12, 0x34, 0x56, 0x78, 0x9a, 0x00, 0x01, // fd12:3456:789a:1::1 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + }; + net::DnsResourceRecord answer; + answer.name = "www.example.com"; + answer.type = dns_protocol::kTypeAAAA; + answer.klass = dns_protocol::kClassIN; + answer.ttl = 120; // 120 seconds. + answer.rdata = base::StringPiece( + "\xfd\x12\x34\x56\x78\x9a\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01", 16); + std::vector<DnsResourceRecord> answers(1, answer); + DnsResponse response(0x1234 /* id */, true /* is_authoritative*/, answers, + {} /* additional records */, base::nullopt); + ASSERT_NE(nullptr, response.io_buffer()); + EXPECT_TRUE(response.IsValid()); + std::string expected_response(response_data, sizeof(response_data)); + std::string actual_response(response.io_buffer()->data(), + response.io_buffer_size()); + EXPECT_EQ(expected_response, actual_response); +} + +TEST(DnsResponseWriteTest, + SingleARecordAnswerWithQuestionAndNsecAdditionalRecord) { + const char response_data[] = { + 0x12, 0x34, // ID + 0x84, 0x00, // flags, response with authoritative answer + 0x00, 0x01, // number of questions + 0x00, 0x01, // number of answer rr + 0x00, 0x00, // number of name server rr + 0x00, 0x01, // number of additional rr + 0x03, 'w', 'w', 'w', 0x07, 'e', 'x', 'a', + 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', + 0x00, // null label + 0x00, 0x01, // type A Record + 0x00, 0x01, // class IN + 0x03, 'w', 'w', 'w', 0x07, 'e', 'x', 'a', + 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', + 0x00, // null label + 0x00, 0x01, // type A Record + 0x00, 0x01, // class IN + 0x00, 0x00, 0x00, 0x78, // TTL, 120 seconds + 0x00, 0x04, // rdlength, 32 bits + 0xc0, 0xa8, 0x00, 0x01, // 192.168.0.1 + 0x03, 'w', 'w', 'w', 0x07, 'e', 'x', 'a', + 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', + 0x00, // null label + 0x00, 0x2f, // type NSEC Record + 0x00, 0x01, // class IN + 0x00, 0x00, 0x00, 0x78, // TTL, 120 seconds + 0x00, 0x05, // rdlength, 5 bytes + 0xc0, 0x0c, // pointer to the previous "www.example.com" + 0x00, 0x01, 0x40, // type bit map of type A: window block 0, bitmap + // length 1, bitmap with bit 1 set + }; + std::string dotted_name("www.example.com"); + std::string dns_name; + ASSERT_TRUE(DNSDomainFromDot(dotted_name, &dns_name)); + base::Optional<DnsQuery> query; + query.emplace(0x1234 /* id */, dns_name, dns_protocol::kTypeA); + net::DnsResourceRecord answer; + answer.name = dotted_name; + answer.type = dns_protocol::kTypeA; + answer.klass = dns_protocol::kClassIN; + answer.ttl = 120; // 120 seconds. + answer.rdata = base::StringPiece("\xc0\xa8\x00\x01", 4); + std::vector<DnsResourceRecord> answers(1, answer); + net::DnsResourceRecord additional_record; + additional_record.name = dotted_name; + additional_record.type = dns_protocol::kTypeNSEC; + additional_record.klass = dns_protocol::kClassIN; + additional_record.ttl = 120; // 120 seconds. + // Bitmap for "www.example.com" with type A set. + additional_record.rdata = base::StringPiece("\xc0\x0c\x00\x01\x40", 5); + std::vector<DnsResourceRecord> additional_records(1, additional_record); + DnsResponse response(0x1234 /* id */, true /* is_authoritative*/, answers, + additional_records, query); + ASSERT_NE(nullptr, response.io_buffer()); + EXPECT_TRUE(response.IsValid()); + std::string expected_response(response_data, sizeof(response_data)); + std::string actual_response(response.io_buffer()->data(), + response.io_buffer_size()); + EXPECT_EQ(expected_response, actual_response); +} + +TEST(DnsResponseWriteTest, TwoAnswersWithAAndQuadARecords) { + const char response_data[] = { + 0x12, 0x34, // ID + 0x84, 0x00, // flags, response with authoritative answer + 0x00, 0x00, // number of questions + 0x00, 0x02, // number of answer rr + 0x00, 0x00, // number of name server rr + 0x00, 0x00, // number of additional rr + 0x03, 'w', 'w', 'w', 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', + 0x03, 'c', 'o', 'm', + 0x00, // null label + 0x00, 0x01, // type A Record + 0x00, 0x01, // class IN + 0x00, 0x00, 0x00, 0x78, // TTL, 120 seconds + 0x00, 0x04, // rdlength, 32 bits + 0xc0, 0xa8, 0x00, 0x01, // 192.168.0.1 + 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'o', 'r', 'g', + 0x00, // null label + 0x00, 0x1c, // type AAAA Record + 0x00, 0x01, // class IN + 0x00, 0x00, 0x00, 0x3c, // TTL, 60 seconds + 0x00, 0x10, // rdlength, 128 bits + 0xfd, 0x12, 0x34, 0x56, 0x78, 0x9a, 0x00, 0x01, // fd12:3456:789a:1::1 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + }; + net::DnsResourceRecord answer1; + answer1.name = "www.example.com"; + answer1.type = dns_protocol::kTypeA; + answer1.klass = dns_protocol::kClassIN; + answer1.ttl = 120; // 120 seconds. + answer1.rdata = base::StringPiece("\xc0\xa8\x00\x01", 4); + net::DnsResourceRecord answer2; + answer2.name = "example.org"; + answer2.type = dns_protocol::kTypeAAAA; + answer2.klass = dns_protocol::kClassIN; + answer2.ttl = 60; + answer2.rdata = base::StringPiece( + "\xfd\x12\x34\x56\x78\x9a\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01", 16); + std::vector<DnsResourceRecord> answers(2); + answers[0] = answer1; + answers[1] = answer2; + DnsResponse response(0x1234 /* id */, true /* is_authoritative*/, answers, + {} /* additional records */, base::nullopt); + ASSERT_NE(nullptr, response.io_buffer()); + EXPECT_TRUE(response.IsValid()); + std::string expected_response(response_data, sizeof(response_data)); + std::string actual_response(response.io_buffer()->data(), + response.io_buffer_size()); + EXPECT_EQ(expected_response, actual_response); +} + +TEST(DnsResponseWriteTest, WrittenResponseCanBeParsed) { + std::string dotted_name("www.example.com"); + net::DnsResourceRecord answer; + answer.name = dotted_name; + answer.type = dns_protocol::kTypeA; + answer.klass = dns_protocol::kClassIN; + answer.ttl = 120; // 120 seconds. + answer.rdata = base::StringPiece("\xc0\xa8\x00\x01", 4); + std::vector<DnsResourceRecord> answers(1, answer); + net::DnsResourceRecord additional_record; + additional_record.name = dotted_name; + additional_record.type = dns_protocol::kTypeNSEC; + additional_record.klass = dns_protocol::kClassIN; + additional_record.ttl = 120; // 120 seconds. + additional_record.rdata = base::StringPiece("\xc0\x0c\x00\x01\x04", 5); + std::vector<DnsResourceRecord> additional_records(1, additional_record); + DnsResponse response(0x1234 /* response_id */, true /* is_authoritative*/, + answers, additional_records, base::nullopt); + ASSERT_NE(nullptr, response.io_buffer()); + EXPECT_TRUE(response.IsValid()); + EXPECT_EQ(1u, response.answer_count()); + EXPECT_EQ(1u, response.additional_answer_count()); + auto parser = response.Parser(); + net::DnsResourceRecord parsed_record; + EXPECT_TRUE(parser.ReadRecord(&parsed_record)); + // Answer with an A record. + EXPECT_EQ(answer.name, parsed_record.name); + EXPECT_EQ(answer.type, parsed_record.type); + EXPECT_EQ(answer.klass, parsed_record.klass); + EXPECT_EQ(answer.ttl, parsed_record.ttl); + EXPECT_EQ(answer.rdata, parsed_record.rdata); + // Additional NSEC record. + EXPECT_TRUE(parser.ReadRecord(&parsed_record)); + EXPECT_EQ(additional_record.name, parsed_record.name); + EXPECT_EQ(additional_record.type, parsed_record.type); + EXPECT_EQ(additional_record.klass, parsed_record.klass); + EXPECT_EQ(additional_record.ttl, parsed_record.ttl); + EXPECT_EQ(additional_record.rdata, parsed_record.rdata); +} + } // namespace } // namespace net diff --git a/chromium/net/dns/dns_session.cc b/chromium/net/dns/dns_session.cc index 85c7b09f622..0e8eab84ca9 100644 --- a/chromium/net/dns/dns_session.cc +++ b/chromium/net/dns/dns_session.cc @@ -20,7 +20,6 @@ #include "base/time/time.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" -#include "net/dns/dns_config_service.h" #include "net/dns/dns_socket_pool.h" #include "net/dns/dns_util.h" #include "net/log/net_log_event_type.h" diff --git a/chromium/net/dns/dns_session.h b/chromium/net/dns/dns_session.h index 803a64d641a..6212ff470a8 100644 --- a/chromium/net/dns/dns_session.h +++ b/chromium/net/dns/dns_session.h @@ -18,7 +18,7 @@ #include "net/base/net_export.h" #include "net/base/network_change_notifier.h" #include "net/base/rand_callback.h" -#include "net/dns/dns_config_service.h" +#include "net/dns/dns_config.h" #include "net/dns/dns_socket_pool.h" namespace base { diff --git a/chromium/net/dns/dns_test_util.cc b/chromium/net/dns/dns_test_util.cc index ac98c638731..acc3829f432 100644 --- a/chromium/net/dns/dns_test_util.cc +++ b/chromium/net/dns/dns_test_util.cc @@ -208,7 +208,7 @@ class MockTransactionFactory : public DnsTransactionFactory { void CompleteDelayedTransactions() { DelayedTransactionList old_delayed_transactions; old_delayed_transactions.swap(delayed_transactions_); - for (DelayedTransactionList::iterator it = old_delayed_transactions.begin(); + for (auto it = old_delayed_transactions.begin(); it != old_delayed_transactions.end(); ++it) { if (it->get()) (*it)->FinishDelayedTransaction(); diff --git a/chromium/net/dns/dns_test_util.h b/chromium/net/dns/dns_test_util.h index ffdceb90dab..5ebcdeb7478 100644 --- a/chromium/net/dns/dns_test_util.h +++ b/chromium/net/dns/dns_test_util.h @@ -13,7 +13,7 @@ #include <vector> #include "net/dns/dns_client.h" -#include "net/dns/dns_config_service.h" +#include "net/dns/dns_config.h" #include "net/dns/dns_protocol.h" namespace net { diff --git a/chromium/net/dns/dns_transaction.cc b/chromium/net/dns/dns_transaction.cc index a88449f6444..30d679c1bcf 100644 --- a/chromium/net/dns/dns_transaction.cc +++ b/chromium/net/dns/dns_transaction.cc @@ -40,6 +40,7 @@ #include "net/base/load_flags.h" #include "net/base/net_errors.h" #include "net/base/upload_bytes_element_reader.h" +#include "net/dns/dns_config.h" #include "net/dns/dns_protocol.h" #include "net/dns/dns_query.h" #include "net/dns/dns_response.h" @@ -524,7 +525,7 @@ class DnsHTTPAttempt : public DnsAttempt, public URLRequest::Delegate { buffer_->set_offset(0); if (size == 0u) return ERR_DNS_MALFORMED_RESPONSE; - response_ = std::make_unique<DnsResponse>(buffer_.get(), size + 1); + response_ = std::make_unique<DnsResponse>(buffer_, size + 1); if (!response_->InitParse(size, *query_)) return ERR_DNS_MALFORMED_RESPONSE; if (response_->rcode() == dns_protocol::kRcodeNXDOMAIN) diff --git a/chromium/net/dns/dns_transaction_unittest.cc b/chromium/net/dns/dns_transaction_unittest.cc index 910e0160112..779ae166008 100644 --- a/chromium/net/dns/dns_transaction_unittest.cc +++ b/chromium/net/dns/dns_transaction_unittest.cc @@ -27,6 +27,7 @@ #include "net/base/port_util.h" #include "net/base/upload_bytes_element_reader.h" #include "net/base/url_util.h" +#include "net/dns/dns_config.h" #include "net/dns/dns_protocol.h" #include "net/dns/dns_query.h" #include "net/dns/dns_response.h" diff --git a/chromium/net/dns/fuzzed_host_resolver.cc b/chromium/net/dns/fuzzed_host_resolver.cc index d93c89fefe2..8290c82a588 100644 --- a/chromium/net/dns/fuzzed_host_resolver.cc +++ b/chromium/net/dns/fuzzed_host_resolver.cc @@ -21,7 +21,7 @@ #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/dns/dns_client.h" -#include "net/dns/dns_config_service.h" +#include "net/dns/dns_config.h" #include "net/dns/dns_hosts.h" namespace net { diff --git a/chromium/net/dns/host_cache.cc b/chromium/net/dns/host_cache.cc index bb8e40a9ffb..e377c616a5b 100644 --- a/chromium/net/dns/host_cache.cc +++ b/chromium/net/dns/host_cache.cc @@ -43,8 +43,7 @@ const char kAddressesKey[] = "addresses"; bool AddressListFromListValue(const base::ListValue* value, AddressList* list) { list->clear(); - for (base::ListValue::const_iterator it = value->begin(); it != value->end(); - it++) { + for (auto it = value->begin(); it != value->end(); it++) { IPAddress address; std::string addr_string; if (!it->GetAsString(&addr_string) || @@ -290,8 +289,8 @@ void HostCache::ClearForHosts( bool changed = false; base::TimeTicks now = tick_clock_->NowTicks(); - for (EntryMap::iterator it = entries_.begin(); it != entries_.end();) { - EntryMap::iterator next_it = std::next(it); + for (auto it = entries_.begin(); it != entries_.end();) { + auto next_it = std::next(it); if (host_filter.Run(it->first.hostname)) { RecordErase(ERASE_CLEAR, now, it->second); diff --git a/chromium/net/dns/host_resolver.cc b/chromium/net/dns/host_resolver.cc index 77e55f346b2..24e2f4d1b71 100644 --- a/chromium/net/dns/host_resolver.cc +++ b/chromium/net/dns/host_resolver.cc @@ -11,7 +11,6 @@ #include "base/values.h" #include "net/base/net_errors.h" #include "net/dns/dns_client.h" -#include "net/dns/dns_config_service.h" #include "net/dns/host_cache.h" #include "net/dns/host_resolver_impl.h" @@ -128,6 +127,12 @@ bool HostResolver::GetNoIPv6OnWifi() { return false; } +void HostResolver::SetDnsConfigOverrides(const DnsConfigOverrides& overrides) { + // Should be overridden in any HostResolver implementation where this method + // may be called. + NOTREACHED(); +} + const std::vector<DnsConfig::DnsOverHttpsServerConfig>* HostResolver::GetDnsOverHttpsServersForTesting() const { return nullptr; diff --git a/chromium/net/dns/host_resolver.h b/chromium/net/dns/host_resolver.h index 0f9a85bdc16..4c615bef029 100644 --- a/chromium/net/dns/host_resolver.h +++ b/chromium/net/dns/host_resolver.h @@ -18,7 +18,7 @@ #include "net/base/host_port_pair.h" #include "net/base/prioritized_dispatcher.h" #include "net/base/request_priority.h" -#include "net/dns/dns_config_service.h" +#include "net/dns/dns_config.h" #include "net/dns/host_cache.h" #include "net/dns/host_resolver_source.h" @@ -30,6 +30,7 @@ namespace net { class AddressList; class DnsClient; +struct DnsConfigOverrides; class HostResolverImpl; class NetLog; class NetLogWithSource; @@ -330,9 +331,11 @@ class NET_EXPORT HostResolver { virtual void SetNoIPv6OnWifi(bool no_ipv6_on_wifi); virtual bool GetNoIPv6OnWifi(); + // Sets overriding configuration that will replace or add to configuration + // read from the system for DnsClient resolution. + virtual void SetDnsConfigOverrides(const DnsConfigOverrides& overrides); + virtual void SetRequestContext(URLRequestContext* request_context) {} - virtual void AddDnsOverHttpsServer(std::string spec, bool use_post) {} - virtual void ClearDnsOverHttpsServers() {} // Returns the currently configured DNS over HTTPS servers. Returns nullptr if // DNS over HTTPS is not enabled. diff --git a/chromium/net/dns/host_resolver_impl.cc b/chromium/net/dns/host_resolver_impl.cc index 1a9366beb99..357e0617b9e 100644 --- a/chromium/net/dns/host_resolver_impl.cc +++ b/chromium/net/dns/host_resolver_impl.cc @@ -59,13 +59,14 @@ #include "net/base/url_util.h" #include "net/dns/address_sorter.h" #include "net/dns/dns_client.h" -#include "net/dns/dns_config_service.h" #include "net/dns/dns_protocol.h" #include "net/dns/dns_reloader.h" #include "net/dns/dns_response.h" #include "net/dns/dns_transaction.h" #include "net/dns/dns_util.h" +#include "net/dns/host_resolver_mdns_task.h" #include "net/dns/host_resolver_proc.h" +#include "net/dns/mdns_client.h" #include "net/log/net_log.h" #include "net/log/net_log_capture_mode.h" #include "net/log/net_log_event_type.h" @@ -77,6 +78,10 @@ #include "net/socket/datagram_client_socket.h" #include "url/url_canon_ip.h" +#if BUILDFLAG(ENABLE_MDNS) +#include "net/dns/mdns_client_impl.h" +#endif + #if defined(OS_WIN) #include "net/base/winsock_init.h" #endif @@ -1513,7 +1518,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, // This will destroy the Job. CompleteRequests( MakeCacheEntry(OK, addr_list, HostCache::Entry::SOURCE_HOSTS), - base::TimeDelta()); + base::TimeDelta(), true /* allow_cache */); return true; } return false; @@ -1526,7 +1531,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, } bool is_running() const { - return is_dns_running() || is_proc_running(); + return is_dns_running() || is_mdns_running() || is_proc_running(); } private: @@ -1618,7 +1623,8 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, switch (key_.host_resolver_source) { case HostResolverSource::ANY: if (resolver_->HaveDnsConfig() && - !ResemblesMulticastDNSName(key_.hostname)) { + !ResemblesMulticastDNSName(key_.hostname) && + !(key_.host_resolver_flags & HOST_RESOLVER_CANONNAME)) { StartDnsTask(); } else { StartProcTask(); @@ -1634,6 +1640,9 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, StartDnsTask(); break; + case HostResolverSource::MULTICAST_DNS: + StartMdnsTask(); + break; } // Caution: Job::Start must not complete synchronously. @@ -1644,7 +1653,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, // TaskScheduler threads low, we will need to use an "inner" // PrioritizedDispatcher with tighter limits. void StartProcTask() { - DCHECK(!is_dns_running()); + DCHECK(!is_running()); proc_task_ = std::make_unique<ProcTask>( key_, resolver_->proc_params_, base::BindOnce(&Job::OnProcTaskComplete, base::Unretained(this), @@ -1694,7 +1703,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, // Don't store the |ttl| in cache since it's not obtained from the server. CompleteRequests( MakeCacheEntry(net_error, addr_list, HostCache::Entry::SOURCE_UNKNOWN), - ttl); + ttl, true /* allow_cache */); } void StartDnsTask() { @@ -1752,7 +1761,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, CompleteRequests( HostCache::Entry(net_error, AddressList(), HostCache::Entry::Source::SOURCE_UNKNOWN, ttl), - ttl); + ttl, true /* allow_cache */); } } @@ -1785,7 +1794,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, } else { CompleteRequests(MakeCacheEntryWithTTL(net_error, addr_list, HostCache::Entry::SOURCE_DNS, ttl), - bounded_ttl); + bounded_ttl, true /* allow_cache */); } } @@ -1802,6 +1811,50 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, dns_task_->StartSecondTransaction(); } + void StartMdnsTask() { + DCHECK(!is_running()); + + // No flags are supported for MDNS except + // HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6 (which is not actually an + // input flag). + DCHECK_EQ(0, key_.host_resolver_flags & + ~HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6); + + std::vector<HostResolver::DnsQueryType> query_types; + switch (key_.address_family) { + case ADDRESS_FAMILY_UNSPECIFIED: + query_types.push_back(HostResolver::DnsQueryType::A); + query_types.push_back(HostResolver::DnsQueryType::AAAA); + break; + case ADDRESS_FAMILY_IPV4: + query_types.push_back(HostResolver::DnsQueryType::A); + break; + case ADDRESS_FAMILY_IPV6: + query_types.push_back(HostResolver::DnsQueryType::AAAA); + break; + } + + mdns_task_ = std::make_unique<HostResolverMdnsTask>( + resolver_->GetOrCreateMdnsClient(), key_.hostname, query_types); + mdns_task_->Start( + base::BindOnce(&Job::OnMdnsTaskComplete, base::Unretained(this))); + } + + void OnMdnsTaskComplete(int error) { + DCHECK(is_mdns_running()); + // TODO(crbug.com/846423): Consider adding MDNS-specific logging. + + if (error != OK) { + CompleteRequestsWithError(error); + } else if (ContainsIcannNameCollisionIp(mdns_task_->result_addresses())) { + CompleteRequestsWithError(ERR_ICANN_NAME_COLLISION); + } else { + // MDNS uses a separate cache, so skip saving result to cache. + // TODO(crbug.com/846423): Consider merging caches. + CompleteRequestsWithoutCache(error, mdns_task_->result_addresses()); + } + } + URLRequestContext* url_request_context() override { return resolver_->url_request_context_; } @@ -1881,8 +1934,12 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, } // Performs Job's last rites. Completes all Requests. Deletes this. + // + // If not |allow_cache|, result will not be stored in the host cache, even if + // result would otherwise allow doing so. void CompleteRequests(const HostCache::Entry& entry, - base::TimeDelta ttl) { + base::TimeDelta ttl, + bool allow_cache) { CHECK(resolver_.get()); // This job must be removed from resolver's |jobs_| now to make room for a @@ -1894,6 +1951,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, if (is_running()) { proc_task_ = nullptr; KillDnsTask(); + mdns_task_ = nullptr; // Signal dispatcher that a slot has opened. resolver_->dispatcher_->OnJobFinished(); @@ -1923,7 +1981,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, bool did_complete = (entry.error() != ERR_NETWORK_CHANGED) && (entry.error() != ERR_HOST_RESOLVER_QUEUE_TOO_LARGE); - if (did_complete) + if (did_complete && allow_cache) resolver_->CacheResult(key_, entry, ttl); RecordJobHistograms(entry.error()); @@ -1955,11 +2013,17 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, } } + void CompleteRequestsWithoutCache(int error, const AddressList& addresses) { + CompleteRequests( + MakeCacheEntry(error, addresses, HostCache::Entry::SOURCE_UNKNOWN), + base::TimeDelta(), false /* allow_cache */); + } + // Convenience wrapper for CompleteRequests in case of failure. void CompleteRequestsWithError(int net_error) { CompleteRequests(HostCache::Entry(net_error, AddressList(), HostCache::Entry::SOURCE_UNKNOWN), - base::TimeDelta()); + base::TimeDelta(), true /* allow_cache */); } RequestPriority priority() const override { @@ -1973,6 +2037,8 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, bool is_dns_running() const { return !!dns_task_; } + bool is_mdns_running() const { return !!mdns_task_; } + bool is_proc_running() const { return !!proc_task_; } base::WeakPtr<HostResolverImpl> resolver_; @@ -2006,6 +2072,9 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, // Resolves the host using a DnsTransaction. std::unique_ptr<DnsTask> dns_task_; + // Resolves the host using MDnsClient. + std::unique_ptr<HostResolverMdnsTask> mdns_task_; + // All Requests waiting for the result of this Job. Some can be canceled. base::LinkedList<RequestImpl> requests_; @@ -2111,8 +2180,9 @@ void HostResolverImpl::SetDnsClient(std::unique_ptr<DnsClient> dns_client) { num_dns_failures_ < kMaximumDnsFailures) { DnsConfig dns_config; NetworkChangeNotifier::GetDnsConfig(&dns_config); - dns_config.dns_over_https_servers = dns_over_https_servers_; - dns_client_->SetConfig(dns_config); + DnsConfig overridden_config = + dns_config_overrides_.ApplyOverrides(dns_config); + dns_client_->SetConfig(overridden_config); num_dns_failures_ = 0; if (dns_client_->GetConfig()) UMA_HISTOGRAM_BOOLEAN("AsyncDNS.DnsClientEnabled", true); @@ -2262,34 +2332,29 @@ bool HostResolverImpl::GetNoIPv6OnWifi() { return assume_ipv6_failure_on_wifi_; } -void HostResolverImpl::SetRequestContext(URLRequestContext* context) { - if (context != url_request_context_) { - url_request_context_ = context; - } -} +void HostResolverImpl::SetDnsConfigOverrides( + const DnsConfigOverrides& overrides) { + if (dns_config_overrides_ == overrides) + return; -void HostResolverImpl::AddDnsOverHttpsServer(std::string uri_template, - bool use_post) { - dns_over_https_servers_.emplace_back(uri_template, use_post); + dns_config_overrides_ = overrides; if (dns_client_.get() && dns_client_->GetConfig()) UpdateDNSConfig(true); } -void HostResolverImpl::ClearDnsOverHttpsServers() { - if (dns_over_https_servers_.size() == 0) - return; - - dns_over_https_servers_.clear(); - - if (dns_client_.get() && dns_client_->GetConfig()) - UpdateDNSConfig(true); +void HostResolverImpl::SetRequestContext(URLRequestContext* context) { + if (context != url_request_context_) { + url_request_context_ = context; + } } const std::vector<DnsConfig::DnsOverHttpsServerConfig>* HostResolverImpl::GetDnsOverHttpsServersForTesting() const { - if (dns_over_https_servers_.empty()) + if (!dns_config_overrides_.dns_over_https_servers || + dns_config_overrides_.dns_over_https_servers.value().empty()) { return nullptr; - return &dns_over_https_servers_; + } + return &dns_config_overrides_.dns_over_https_servers.value(); } void HostResolverImpl::SetTickClockForTesting( @@ -2312,6 +2377,17 @@ void HostResolverImpl::SetHaveOnlyLoopbackAddresses(bool result) { } } +void HostResolverImpl::SetMdnsSocketFactoryForTesting( + std::unique_ptr<MDnsSocketFactory> socket_factory) { + DCHECK(!mdns_client_); + mdns_socket_factory_ = std::move(socket_factory); +} + +void HostResolverImpl::SetMdnsClientForTesting( + std::unique_ptr<MDnsClient> client) { + mdns_client_ = std::move(client); +} + void HostResolverImpl::SetTaskRunnerForTesting( scoped_refptr<base::TaskRunner> task_runner) { proc_task_runner_ = std::move(task_runner); @@ -2322,6 +2398,11 @@ int HostResolverImpl::Resolve(RequestImpl* request) { DCHECK(!request->job()); // Request may only be resolved once. DCHECK(!request->complete()); + // MDNS requests do not support skipping cache. + // TODO(crbug.com/846423): Either add support for skipping the MDNS cache, or + // merge to use the normal host cache for MDNS requests. + DCHECK(request->parameters().source != HostResolverSource::MULTICAST_DNS || + request->parameters().allow_cached_response); request->set_request_time(tick_clock_->NowTicks()); @@ -2523,16 +2604,14 @@ bool HostResolverImpl::ServeFromHosts(const Key& key, // necessary. if (key.address_family == ADDRESS_FAMILY_IPV6 || key.address_family == ADDRESS_FAMILY_UNSPECIFIED) { - DnsHosts::const_iterator it = hosts.find( - DnsHostsKey(hostname, ADDRESS_FAMILY_IPV6)); + auto it = hosts.find(DnsHostsKey(hostname, ADDRESS_FAMILY_IPV6)); if (it != hosts.end()) addresses->push_back(IPEndPoint(it->second, host_port)); } if (key.address_family == ADDRESS_FAMILY_IPV4 || key.address_family == ADDRESS_FAMILY_UNSPECIFIED) { - DnsHosts::const_iterator it = hosts.find( - DnsHostsKey(hostname, ADDRESS_FAMILY_IPV4)); + auto it = hosts.find(DnsHostsKey(hostname, ADDRESS_FAMILY_IPV4)); if (it != hosts.end()) addresses->push_back(IPEndPoint(it->second, host_port)); } @@ -2797,6 +2876,9 @@ void HostResolverImpl::UpdateDNSConfig(bool config_changed) { // TODO(szym): Remove once http://crbug.com/137914 is resolved. received_dns_config_ = dns_config.IsValid(); + + dns_config = dns_config_overrides_.ApplyOverrides(dns_config); + // Conservatively assume local IPv6 is needed when DnsConfig is not valid. use_local_ipv6_ = !dns_config.IsValid() || dns_config.use_local_ipv6; @@ -2809,7 +2891,6 @@ void HostResolverImpl::UpdateDNSConfig(bool config_changed) { // wasn't already a DnsConfig or it's the same one. DCHECK(config_changed || !dns_client_->GetConfig() || dns_client_->GetConfig()->Equals(dns_config)); - dns_config.dns_over_https_servers = dns_over_https_servers_; dns_client_->SetConfig(dns_config); if (dns_client_->GetConfig()) UMA_HISTOGRAM_BOOLEAN("AsyncDNS.DnsClientEnabled", true); @@ -2867,6 +2948,25 @@ void HostResolverImpl::OnDnsTaskResolve(int net_error) { std::abs(net_error)); } +MDnsClient* HostResolverImpl::GetOrCreateMdnsClient() { +#if BUILDFLAG(ENABLE_MDNS) + if (!mdns_client_) { + if (!mdns_socket_factory_) + mdns_socket_factory_ = std::make_unique<MDnsSocketFactoryImpl>(net_log_); + + mdns_client_ = MDnsClient::CreateDefault(); + mdns_client_->StartListening(mdns_socket_factory_.get()); + } + + DCHECK(mdns_client_->IsListening()); + return mdns_client_.get(); +#else + // Should not request MDNS resoltuion unless MDNS is enabled. + NOTREACHED(); + return nullptr; +#endif +} + HostResolverImpl::RequestImpl::~RequestImpl() { if (job_) job_->CancelRequest(this); diff --git a/chromium/net/dns/host_resolver_impl.h b/chromium/net/dns/host_resolver_impl.h index 955241e6956..6bbbde494e3 100644 --- a/chromium/net/dns/host_resolver_impl.h +++ b/chromium/net/dns/host_resolver_impl.h @@ -19,7 +19,8 @@ #include "base/timer/timer.h" #include "net/base/completion_once_callback.h" #include "net/base/network_change_notifier.h" -#include "net/dns/dns_config_service.h" +#include "net/dns/dns_config.h" +#include "net/dns/dns_config_overrides.h" #include "net/dns/host_cache.h" #include "net/dns/host_resolver.h" #include "net/dns/host_resolver_proc.h" @@ -35,6 +36,8 @@ namespace net { class AddressList; class DnsClient; class IPAddress; +class MDnsClient; +class MDnsSocketFactory; class NetLog; class NetLogWithSource; @@ -171,9 +174,9 @@ class NET_EXPORT HostResolverImpl void SetNoIPv6OnWifi(bool no_ipv6_on_wifi) override; bool GetNoIPv6OnWifi() override; + void SetDnsConfigOverrides(const DnsConfigOverrides& overrides) override; + void SetRequestContext(URLRequestContext* request_context) override; - void AddDnsOverHttpsServer(std::string uri_template, bool use_post) override; - void ClearDnsOverHttpsServers() override; const std::vector<DnsConfig::DnsOverHttpsServerConfig>* GetDnsOverHttpsServersForTesting() const override; @@ -187,6 +190,10 @@ class NET_EXPORT HostResolverImpl // Only allowed when the queue is empty. void SetMaxQueuedJobsForTesting(size_t value); + void SetMdnsSocketFactoryForTesting( + std::unique_ptr<MDnsSocketFactory> socket_factory); + void SetMdnsClientForTesting(std::unique_ptr<MDnsClient> client); + protected: // Callback from HaveOnlyLoopbackAddresses probe. void SetHaveOnlyLoopbackAddresses(bool result); @@ -345,6 +352,8 @@ class NET_EXPORT HostResolverImpl // and resulted in |net_error|. void OnDnsTaskResolve(int net_error); + MDnsClient* GetOrCreateMdnsClient(); + // Allows the tests to catch slots leaking out of the dispatcher. One // HostResolverImpl::Job could occupy multiple PrioritizedDispatcher job // slots. @@ -355,6 +364,11 @@ class NET_EXPORT HostResolverImpl // Cache of host resolution results. std::unique_ptr<HostCache> cache_; + // Used for multicast DNS tasks. Created on first use using + // GetOrCreateMndsClient(). + std::unique_ptr<MDnsSocketFactory> mdns_socket_factory_; + std::unique_ptr<MDnsClient> mdns_client_; + // Map from HostCache::Key to a Job. JobMap jobs_; @@ -376,6 +390,10 @@ class NET_EXPORT HostResolverImpl // to measure performance of DnsConfigService: http://crbug.com/125599 bool received_dns_config_; + // Overrides or adds to DNS configuration read from the system for DnsClient + // resolution. + DnsConfigOverrides dns_config_overrides_; + // Number of consecutive failures of DnsTask, counted when fallback succeeds. unsigned num_dns_failures_; @@ -401,7 +419,6 @@ class NET_EXPORT HostResolverImpl scoped_refptr<base::TaskRunner> proc_task_runner_; URLRequestContext* url_request_context_; - std::vector<DnsConfig::DnsOverHttpsServerConfig> dns_over_https_servers_; // Shared tick clock, overridden for testing. const base::TickClock* tick_clock_; diff --git a/chromium/net/dns/host_resolver_impl_unittest.cc b/chromium/net/dns/host_resolver_impl_unittest.cc index fdb1a0ab32f..fb433d78565 100644 --- a/chromium/net/dns/host_resolver_impl_unittest.cc +++ b/chromium/net/dns/host_resolver_impl_unittest.cc @@ -34,8 +34,11 @@ #include "net/base/mock_network_change_notifier.h" #include "net/base/net_errors.h" #include "net/dns/dns_client.h" +#include "net/dns/dns_config.h" #include "net/dns/dns_test_util.h" #include "net/dns/mock_host_resolver.h" +#include "net/dns/mock_mdns_client.h" +#include "net/dns/mock_mdns_socket_factory.h" #include "net/log/net_log_event_type.h" #include "net/log/net_log_source_type.h" #include "net/log/net_log_with_source.h" @@ -47,7 +50,11 @@ using net::test::IsError; using net::test::IsOk; +using ::testing::_; +using ::testing::Between; +using ::testing::ByMove; using ::testing::NotNull; +using ::testing::Return; namespace net { @@ -3012,6 +3019,274 @@ TEST_F(HostResolverImplTest, IsSpeculative_ResolveHost) { EXPECT_EQ(1u, proc_->GetCaptureList().size()); // No increase. } +#if BUILDFLAG(ENABLE_MDNS) +const uint8_t kMdnsResponseA[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x01, // 1 RR (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // "myhello.local." + 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o', 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x00, 0x00, 0x10, // TTL is 16 (seconds) + 0x00, 0x04, // RDLENGTH is 4 bytes. + 0x01, 0x02, 0x03, 0x04, // 1.2.3.4 +}; + +const uint8_t kMdnsResponseAAAA[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x01, // 1 RR (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // "myhello.local." + 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o', 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + + 0x00, 0x1C, // TYPE is AAAA. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x00, 0x00, 0x10, // TTL is 16 (seconds) + 0x00, 0x10, // RDLENGTH is 16 bytes. + + // 000a:0000:0000:0000:0001:0002:0003:0004 + 0x00, 0x0A, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x02, + 0x00, 0x03, 0x00, 0x04, +}; + +// An MDNS response indicating that the responder owns the hostname, but the +// specific requested type (AAAA) does not exist because the responder only has +// A addresses. +const uint8_t kMdnsResponseNsec[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x01, // 1 RR (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // "myhello.local." + 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o', 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + + 0x00, 0x2f, // TYPE is NSEC. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x00, 0x00, 0x10, // TTL is 16 (seconds) + 0x00, 0x06, // RDLENGTH is 6 bytes. + 0xc0, 0x0c, // Next Domain Name (always pointer back to name in MDNS) + 0x00, // Bitmap block number (always 0 in MDNS) + 0x02, // Bitmap length is 2 + 0x00, 0x08 // A type only +}; + +TEST_F(HostResolverImplTest, Mdns) { + auto socket_factory = std::make_unique<MockMDnsSocketFactory>(); + MockMDnsSocketFactory* socket_factory_ptr = socket_factory.get(); + resolver_->SetMdnsSocketFactoryForTesting(std::move(socket_factory)); + // 2 socket creations for every transaction. + EXPECT_CALL(*socket_factory_ptr, OnSendTo(_)).Times(4); + + HostResolver::ResolveHostParameters parameters; + parameters.source = HostResolverSource::MULTICAST_DNS; + + ResolveHostResponseHelper response(resolver_->CreateRequest( + HostPortPair("myhello.local", 80), NetLogWithSource(), parameters)); + + socket_factory_ptr->SimulateReceive(kMdnsResponseA, sizeof(kMdnsResponseA)); + socket_factory_ptr->SimulateReceive(kMdnsResponseAAAA, + sizeof(kMdnsResponseAAAA)); + + EXPECT_THAT(response.result_error(), IsOk()); + EXPECT_THAT( + response.request()->GetAddressResults().value().endpoints(), + testing::UnorderedElementsAre( + CreateExpected("1.2.3.4", 80), + CreateExpected("000a:0000:0000:0000:0001:0002:0003:0004", 80))); +} + +TEST_F(HostResolverImplTest, Mdns_AaaaOnly) { + auto socket_factory = std::make_unique<MockMDnsSocketFactory>(); + MockMDnsSocketFactory* socket_factory_ptr = socket_factory.get(); + resolver_->SetMdnsSocketFactoryForTesting(std::move(socket_factory)); + // 2 socket creations for every transaction. + EXPECT_CALL(*socket_factory_ptr, OnSendTo(_)).Times(2); + + HostResolver::ResolveHostParameters parameters; + parameters.dns_query_type = HostResolver::DnsQueryType::AAAA; + parameters.source = HostResolverSource::MULTICAST_DNS; + + ResolveHostResponseHelper response(resolver_->CreateRequest( + HostPortPair("myhello.local", 80), NetLogWithSource(), parameters)); + + socket_factory_ptr->SimulateReceive(kMdnsResponseAAAA, + sizeof(kMdnsResponseAAAA)); + + EXPECT_THAT(response.result_error(), IsOk()); + EXPECT_THAT(response.request()->GetAddressResults().value().endpoints(), + testing::ElementsAre(CreateExpected( + "000a:0000:0000:0000:0001:0002:0003:0004", 80))); +} + +// Test multicast DNS handling of NSEC responses (used for explicit negative +// response). +TEST_F(HostResolverImplTest, Mdns_Nsec) { + auto socket_factory = std::make_unique<MockMDnsSocketFactory>(); + MockMDnsSocketFactory* socket_factory_ptr = socket_factory.get(); + resolver_->SetMdnsSocketFactoryForTesting(std::move(socket_factory)); + // 2 socket creations for every transaction. + EXPECT_CALL(*socket_factory_ptr, OnSendTo(_)).Times(2); + + HostResolver::ResolveHostParameters parameters; + parameters.dns_query_type = HostResolver::DnsQueryType::AAAA; + parameters.source = HostResolverSource::MULTICAST_DNS; + + ResolveHostResponseHelper response(resolver_->CreateRequest( + HostPortPair("myhello.local", 80), NetLogWithSource(), parameters)); + + socket_factory_ptr->SimulateReceive(kMdnsResponseNsec, + sizeof(kMdnsResponseNsec)); + + EXPECT_THAT(response.result_error(), IsError(ERR_NAME_NOT_RESOLVED)); + EXPECT_FALSE(response.request()->GetAddressResults()); +} + +TEST_F(HostResolverImplTest, Mdns_NoResponse) { + auto socket_factory = std::make_unique<MockMDnsSocketFactory>(); + MockMDnsSocketFactory* socket_factory_ptr = socket_factory.get(); + resolver_->SetMdnsSocketFactoryForTesting(std::move(socket_factory)); + // 2 socket creations for every transaction. + EXPECT_CALL(*socket_factory_ptr, OnSendTo(_)).Times(4); + + // Add a little bit of extra fudge to the delay to allow reasonable + // flexibility for time > vs >= etc. We don't need to fail the test if we + // timeout at t=6001 instead of t=6000. + base::TimeDelta kSleepFudgeFactor = base::TimeDelta::FromMilliseconds(1); + + // Override the current thread task runner, so we can simulate the passage of + // time to trigger the timeout. + auto test_task_runner = base::MakeRefCounted<base::TestMockTimeTaskRunner>(); + base::ScopedClosureRunner task_runner_override_scoped_cleanup = + base::ThreadTaskRunnerHandle::OverrideForTesting(test_task_runner); + + HostResolver::ResolveHostParameters parameters; + parameters.source = HostResolverSource::MULTICAST_DNS; + + ResolveHostResponseHelper response(resolver_->CreateRequest( + HostPortPair("myhello.local", 80), NetLogWithSource(), parameters)); + + ASSERT_TRUE(test_task_runner->HasPendingTask()); + test_task_runner->FastForwardBy(MDnsTransaction::kTransactionTimeout + + kSleepFudgeFactor); + + EXPECT_THAT(response.result_error(), IsError(ERR_NAME_NOT_RESOLVED)); + EXPECT_FALSE(response.request()->GetAddressResults()); + + test_task_runner->FastForwardUntilNoTasksRemain(); +} + +// Test for a request for both A and AAAA results where results only exist for +// one type. +TEST_F(HostResolverImplTest, Mdns_PartialResults) { + auto socket_factory = std::make_unique<MockMDnsSocketFactory>(); + MockMDnsSocketFactory* socket_factory_ptr = socket_factory.get(); + resolver_->SetMdnsSocketFactoryForTesting(std::move(socket_factory)); + // 2 socket creations for every transaction. + EXPECT_CALL(*socket_factory_ptr, OnSendTo(_)).Times(4); + + // Add a little bit of extra fudge to the delay to allow reasonable + // flexibility for time > vs >= etc. We don't need to fail the test if we + // timeout at t=6001 instead of t=6000. + base::TimeDelta kSleepFudgeFactor = base::TimeDelta::FromMilliseconds(1); + + // Override the current thread task runner, so we can simulate the passage of + // time to trigger the timeout. + auto test_task_runner = base::MakeRefCounted<base::TestMockTimeTaskRunner>(); + base::ScopedClosureRunner task_runner_override_scoped_cleanup = + base::ThreadTaskRunnerHandle::OverrideForTesting(test_task_runner); + + HostResolver::ResolveHostParameters parameters; + parameters.source = HostResolverSource::MULTICAST_DNS; + + ResolveHostResponseHelper response(resolver_->CreateRequest( + HostPortPair("myhello.local", 80), NetLogWithSource(), parameters)); + + ASSERT_TRUE(test_task_runner->HasPendingTask()); + + socket_factory_ptr->SimulateReceive(kMdnsResponseA, sizeof(kMdnsResponseA)); + test_task_runner->FastForwardBy(MDnsTransaction::kTransactionTimeout + + kSleepFudgeFactor); + + EXPECT_THAT(response.result_error(), IsOk()); + EXPECT_THAT(response.request()->GetAddressResults().value().endpoints(), + testing::ElementsAre(CreateExpected("1.2.3.4", 80))); + + test_task_runner->FastForwardUntilNoTasksRemain(); +} + +TEST_F(HostResolverImplTest, Mdns_Cancel) { + auto socket_factory = std::make_unique<MockMDnsSocketFactory>(); + MockMDnsSocketFactory* socket_factory_ptr = socket_factory.get(); + resolver_->SetMdnsSocketFactoryForTesting(std::move(socket_factory)); + // 2 socket creations for every transaction. + EXPECT_CALL(*socket_factory_ptr, OnSendTo(_)).Times(4); + + HostResolver::ResolveHostParameters parameters; + parameters.source = HostResolverSource::MULTICAST_DNS; + + ResolveHostResponseHelper response(resolver_->CreateRequest( + HostPortPair("myhello.local", 80), NetLogWithSource(), parameters)); + + response.CancelRequest(); + + socket_factory_ptr->SimulateReceive(kMdnsResponseA, sizeof(kMdnsResponseA)); + socket_factory_ptr->SimulateReceive(kMdnsResponseAAAA, + sizeof(kMdnsResponseAAAA)); + + base::RunLoop().RunUntilIdle(); + EXPECT_FALSE(response.complete()); +} + +// Test for a two-transaction query where the first fails to start. The second +// should be cancelled. +TEST_F(HostResolverImplTest, Mdns_PartialFailure) { + // Setup a mock MDnsClient where the first transaction will always return + // |false| immediately on Start(). Second transaction may or may not be + // created, but if it is, Start() not expected to be called because the + // overall request should immediately fail. + auto transaction1 = std::make_unique<MockMDnsTransaction>(); + EXPECT_CALL(*transaction1, Start()).WillOnce(Return(false)); + auto transaction2 = std::make_unique<MockMDnsTransaction>(); + EXPECT_CALL(*transaction2, Start()).Times(0); + + auto client = std::make_unique<MockMDnsClient>(); + EXPECT_CALL(*client, CreateTransaction(_, _, _, _)) + .Times(Between(1, 2)) // Second transaction optionally created. + .WillOnce(Return(ByMove(std::move(transaction1)))) + .WillOnce(Return(ByMove(std::move(transaction2)))); + EXPECT_CALL(*client, IsListening()).WillRepeatedly(Return(true)); + resolver_->SetMdnsClientForTesting(std::move(client)); + + HostResolver::ResolveHostParameters parameters; + parameters.source = HostResolverSource::MULTICAST_DNS; + + ResolveHostResponseHelper response(resolver_->CreateRequest( + HostPortPair("myhello.local", 80), NetLogWithSource(), parameters)); + + EXPECT_THAT(response.result_error(), IsError(ERR_FAILED)); + EXPECT_FALSE(response.request()->GetAddressResults()); +} +#endif // BUILDFLAG(ENABLE_MDNS) + DnsConfig CreateValidDnsConfig() { IPAddress dns_ip(192, 168, 1, 0); DnsConfig config; @@ -5267,7 +5542,10 @@ TEST_F(HostResolverImplDnsTest, AddDnsOverHttpsServerAfterConfig) { resolver_->SetDnsClientEnabled(true); std::string server("https://dnsserver.example.net/dns-query{?dns}"); - resolver_->AddDnsOverHttpsServer(server, true); + DnsConfigOverrides overrides; + overrides.dns_over_https_servers.emplace( + {DnsConfig::DnsOverHttpsServerConfig(server, true)}); + resolver_->SetDnsConfigOverrides(overrides); base::DictionaryValue* config; auto value = resolver_->GetDnsConfigAsValue(); @@ -5297,7 +5575,10 @@ TEST_F(HostResolverImplDnsTest, AddDnsOverHttpsServerBeforeConfig) { CreateSerialResolver(); // To guarantee order of resolutions. resolver_->SetDnsClientEnabled(true); std::string server("https://dnsserver.example.net/dns-query{?dns}"); - resolver_->AddDnsOverHttpsServer(server, true); + DnsConfigOverrides overrides; + overrides.dns_over_https_servers.emplace( + {DnsConfig::DnsOverHttpsServerConfig(server, true)}); + resolver_->SetDnsConfigOverrides(overrides); notifier.mock_network_change_notifier()->SetConnectionType( NetworkChangeNotifier::CONNECTION_WIFI); @@ -5330,7 +5611,10 @@ TEST_F(HostResolverImplDnsTest, AddDnsOverHttpsServerBeforeClient) { test::ScopedMockNetworkChangeNotifier notifier; CreateSerialResolver(); // To guarantee order of resolutions. std::string server("https://dnsserver.example.net/dns-query{?dns}"); - resolver_->AddDnsOverHttpsServer(server, true); + DnsConfigOverrides overrides; + overrides.dns_over_https_servers.emplace( + {DnsConfig::DnsOverHttpsServerConfig(server, true)}); + resolver_->SetDnsConfigOverrides(overrides); notifier.mock_network_change_notifier()->SetConnectionType( NetworkChangeNotifier::CONNECTION_WIFI); @@ -5365,7 +5649,10 @@ TEST_F(HostResolverImplDnsTest, AddDnsOverHttpsServerAndThenRemove) { test::ScopedMockNetworkChangeNotifier notifier; CreateSerialResolver(); // To guarantee order of resolutions. std::string server("https://dns.example.com/"); - resolver_->AddDnsOverHttpsServer(server, true); + DnsConfigOverrides overrides; + overrides.dns_over_https_servers.emplace( + {DnsConfig::DnsOverHttpsServerConfig(server, true)}); + resolver_->SetDnsConfigOverrides(overrides); notifier.mock_network_change_notifier()->SetConnectionType( NetworkChangeNotifier::CONNECTION_WIFI); @@ -5394,7 +5681,7 @@ TEST_F(HostResolverImplDnsTest, AddDnsOverHttpsServerAndThenRemove) { EXPECT_TRUE(server_method->GetString("server_template", &server_template)); EXPECT_EQ(server_template, server); - resolver_->ClearDnsOverHttpsServers(); + resolver_->SetDnsConfigOverrides(DnsConfigOverrides()); value = resolver_->GetDnsConfigAsValue(); EXPECT_TRUE(value); if (!value) @@ -5407,4 +5694,183 @@ TEST_F(HostResolverImplDnsTest, AddDnsOverHttpsServerAndThenRemove) { EXPECT_EQ(doh_servers->GetSize(), 0u); } +TEST_F(HostResolverImplDnsTest, SetDnsConfigOverrides) { + DnsConfig original_config = CreateValidDnsConfig(); + ChangeDnsConfig(original_config); + + // Confirm pre-override state. + ASSERT_TRUE(original_config.Equals(*dns_client_->GetConfig())); + + DnsConfigOverrides overrides; + const std::vector<IPEndPoint> nameservers = { + CreateExpected("192.168.0.1", 92)}; + overrides.nameservers = nameservers; + const std::vector<std::string> search = {"str"}; + overrides.search = search; + const DnsHosts hosts = { + {DnsHostsKey("host", ADDRESS_FAMILY_IPV4), IPAddress(192, 168, 1, 1)}}; + overrides.hosts = hosts; + overrides.append_to_multi_label_name = false; + overrides.randomize_ports = true; + const int ndots = 5; + overrides.ndots = ndots; + const base::TimeDelta timeout = base::TimeDelta::FromSeconds(10); + overrides.timeout = timeout; + const int attempts = 20; + overrides.attempts = attempts; + overrides.rotate = true; + overrides.use_local_ipv6 = true; + const std::vector<DnsConfig::DnsOverHttpsServerConfig> + dns_over_https_servers = { + DnsConfig::DnsOverHttpsServerConfig("dns.example.com", true)}; + overrides.dns_over_https_servers = dns_over_https_servers; + + resolver_->SetDnsConfigOverrides(overrides); + + const DnsConfig* overridden_config = dns_client_->GetConfig(); + EXPECT_EQ(nameservers, overridden_config->nameservers); + EXPECT_EQ(search, overridden_config->search); + EXPECT_EQ(hosts, overridden_config->hosts); + EXPECT_FALSE(overridden_config->append_to_multi_label_name); + EXPECT_TRUE(overridden_config->randomize_ports); + EXPECT_EQ(ndots, overridden_config->ndots); + EXPECT_EQ(timeout, overridden_config->timeout); + EXPECT_EQ(attempts, overridden_config->attempts); + EXPECT_TRUE(overridden_config->rotate); + EXPECT_TRUE(overridden_config->use_local_ipv6); + EXPECT_EQ(dns_over_https_servers, overridden_config->dns_over_https_servers); +} + +TEST_F(HostResolverImplDnsTest, SetDnsConfigOverrides_PartialOverride) { + DnsConfig original_config = CreateValidDnsConfig(); + ChangeDnsConfig(original_config); + + // Confirm pre-override state. + ASSERT_TRUE(original_config.Equals(*dns_client_->GetConfig())); + + DnsConfigOverrides overrides; + const std::vector<IPEndPoint> nameservers = { + CreateExpected("192.168.0.2", 192)}; + overrides.nameservers = nameservers; + overrides.rotate = true; + + resolver_->SetDnsConfigOverrides(overrides); + + const DnsConfig* overridden_config = dns_client_->GetConfig(); + EXPECT_EQ(nameservers, overridden_config->nameservers); + EXPECT_EQ(original_config.search, overridden_config->search); + EXPECT_EQ(original_config.hosts, overridden_config->hosts); + EXPECT_TRUE(overridden_config->append_to_multi_label_name); + EXPECT_FALSE(overridden_config->randomize_ports); + EXPECT_EQ(original_config.ndots, overridden_config->ndots); + EXPECT_EQ(original_config.timeout, overridden_config->timeout); + EXPECT_EQ(original_config.attempts, overridden_config->attempts); + EXPECT_TRUE(overridden_config->rotate); + EXPECT_FALSE(overridden_config->use_local_ipv6); + EXPECT_EQ(original_config.dns_over_https_servers, + overridden_config->dns_over_https_servers); +} + +// Test that overridden configs are reapplied over a changed underlying system +// config. +TEST_F(HostResolverImplDnsTest, SetDnsConfigOverrides_NewConfig) { + DnsConfig original_config = CreateValidDnsConfig(); + ChangeDnsConfig(original_config); + + // Confirm pre-override state. + ASSERT_TRUE(original_config.Equals(*dns_client_->GetConfig())); + + DnsConfigOverrides overrides; + const std::vector<IPEndPoint> nameservers = { + CreateExpected("192.168.0.2", 192)}; + overrides.nameservers = nameservers; + + resolver_->SetDnsConfigOverrides(overrides); + ASSERT_EQ(nameservers, dns_client_->GetConfig()->nameservers); + + DnsConfig new_config = original_config; + new_config.attempts = 103; + ASSERT_NE(nameservers, new_config.nameservers); + ChangeDnsConfig(new_config); + + const DnsConfig* overridden_config = dns_client_->GetConfig(); + EXPECT_EQ(nameservers, overridden_config->nameservers); + EXPECT_EQ(new_config.attempts, overridden_config->attempts); +} + +TEST_F(HostResolverImplDnsTest, SetDnsConfigOverrides_ClearOverrides) { + DnsConfig original_config = CreateValidDnsConfig(); + ChangeDnsConfig(original_config); + + DnsConfigOverrides overrides; + overrides.attempts = 245; + resolver_->SetDnsConfigOverrides(overrides); + + ASSERT_FALSE(original_config.Equals(*dns_client_->GetConfig())); + + resolver_->SetDnsConfigOverrides(DnsConfigOverrides()); + EXPECT_TRUE(original_config.Equals(*dns_client_->GetConfig())); +} + +// Test that in-progress queries are cancelled on applying new DNS config +// overrides, same as receiving a new DnsConfig from the system. +TEST_F(HostResolverImplDnsTest, CancelQueriesOnSettingOverrides) { + ChangeDnsConfig(CreateValidDnsConfig()); + ResolveHostResponseHelper response(resolver_->CreateRequest( + HostPortPair("ok", 80), NetLogWithSource(), base::nullopt)); + ASSERT_FALSE(response.complete()); + + DnsConfigOverrides overrides; + overrides.attempts = 123; + resolver_->SetDnsConfigOverrides(overrides); + + EXPECT_THAT(response.result_error(), IsError(ERR_NETWORK_CHANGED)); +} + +// Queries should not be cancelled if equal overrides are set. +TEST_F(HostResolverImplDnsTest, CancelQueriesOnSettingOverrides_SameOverrides) { + ChangeDnsConfig(CreateValidDnsConfig()); + DnsConfigOverrides overrides; + overrides.attempts = 123; + resolver_->SetDnsConfigOverrides(overrides); + + ResolveHostResponseHelper response(resolver_->CreateRequest( + HostPortPair("ok", 80), NetLogWithSource(), base::nullopt)); + ASSERT_FALSE(response.complete()); + + resolver_->SetDnsConfigOverrides(overrides); + + EXPECT_THAT(response.result_error(), IsOk()); +} + +// Test that in-progress queries are cancelled on clearing DNS config overrides, +// same as receiving a new DnsConfig from the system. +TEST_F(HostResolverImplDnsTest, CancelQueriesOnClearingOverrides) { + ChangeDnsConfig(CreateValidDnsConfig()); + DnsConfigOverrides overrides; + overrides.attempts = 123; + resolver_->SetDnsConfigOverrides(overrides); + + ResolveHostResponseHelper response(resolver_->CreateRequest( + HostPortPair("ok", 80), NetLogWithSource(), base::nullopt)); + ASSERT_FALSE(response.complete()); + + resolver_->SetDnsConfigOverrides(DnsConfigOverrides()); + + EXPECT_THAT(response.result_error(), IsError(ERR_NETWORK_CHANGED)); +} + +// Queries should not be cancelled on clearing overrides if there were not any +// overrides. +TEST_F(HostResolverImplDnsTest, CancelQueriesOnClearingOverrides_NoOverrides) { + ChangeDnsConfig(CreateValidDnsConfig()); + ResolveHostResponseHelper response(resolver_->CreateRequest( + HostPortPair("ok", 80), NetLogWithSource(), base::nullopt)); + ASSERT_FALSE(response.complete()); + + resolver_->SetDnsConfigOverrides(DnsConfigOverrides()); + + EXPECT_THAT(response.result_error(), IsOk()); +} + } // namespace net diff --git a/chromium/net/dns/host_resolver_mdns_task.cc b/chromium/net/dns/host_resolver_mdns_task.cc new file mode 100644 index 00000000000..a2f7ba1be98 --- /dev/null +++ b/chromium/net/dns/host_resolver_mdns_task.cc @@ -0,0 +1,208 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/host_resolver_mdns_task.h" + +#include <algorithm> +#include <utility> + +#include "base/logging.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/record_parsed.h" +#include "net/dns/record_rdata.h" + +namespace net { + +class HostResolverMdnsTask::Transaction { + public: + Transaction(HostResolver::DnsQueryType query_type, HostResolverMdnsTask* task) + : query_type_(query_type), result_(ERR_IO_PENDING), task_(task) {} + + void Start() { + DCHECK_CALLED_ON_VALID_SEQUENCE(task_->sequence_checker_); + + // Should not be completed or running yet. + DCHECK_EQ(ERR_IO_PENDING, result_); + DCHECK(!async_transaction_); + + uint16_t rrtype; + switch (query_type_) { + case net::HostResolver::DnsQueryType::A: + rrtype = net::dns_protocol::kTypeA; + break; + case net::HostResolver::DnsQueryType::AAAA: + rrtype = net::dns_protocol::kTypeAAAA; + break; + default: + // Type not supported for MDNS. + NOTREACHED(); + return; + } + + // TODO(crbug.com/846423): Use |allow_cached_response| to set the + // QUERY_CACHE flag or not. + int flags = MDnsTransaction::SINGLE_RESULT | MDnsTransaction::QUERY_CACHE | + MDnsTransaction::QUERY_NETWORK; + // If |this| is destroyed, destruction of |internal_transaction_| should + // cancel and prevent invocation of OnComplete. + std::unique_ptr<MDnsTransaction> inner_transaction = + task_->mdns_client_->CreateTransaction( + rrtype, task_->hostname_, flags, + base::BindRepeating(&HostResolverMdnsTask::Transaction::OnComplete, + base::Unretained(this))); + + // Side effect warning: Start() may finish and invoke callbacks inline. + bool start_result = inner_transaction->Start(); + + if (!start_result) + task_->CompleteWithResult(ERR_FAILED, true /* post_needed */); + else if (result_ == ERR_IO_PENDING) + async_transaction_ = std::move(inner_transaction); + } + + bool IsDone() const { return result_ != ERR_IO_PENDING; } + bool IsError() const { + return IsDone() && result_ != OK && result_ != ERR_NAME_NOT_RESOLVED; + } + int result() const { return result_; } + + void Cancel() { + DCHECK_CALLED_ON_VALID_SEQUENCE(task_->sequence_checker_); + DCHECK_EQ(ERR_IO_PENDING, result_); + + result_ = ERR_FAILED; + async_transaction_ = nullptr; + } + + private: + void OnComplete(MDnsTransaction::Result result, const RecordParsed* parsed) { + DCHECK_CALLED_ON_VALID_SEQUENCE(task_->sequence_checker_); + DCHECK_EQ(ERR_IO_PENDING, result_); + + switch (result) { + case MDnsTransaction::RESULT_RECORD: + result_ = OK; + break; + case MDnsTransaction::RESULT_NO_RESULTS: + case MDnsTransaction::RESULT_NSEC: + result_ = ERR_NAME_NOT_RESOLVED; + break; + default: + // No other results should be possible with the request flags used. + NOTREACHED(); + } + + if (result_ == net::OK) { + switch (query_type_) { + case net::HostResolver::DnsQueryType::A: + task_->result_addresses_.push_back( + IPEndPoint(parsed->rdata<net::ARecordRdata>()->address(), 0)); + break; + case net::HostResolver::DnsQueryType::AAAA: + task_->result_addresses_.push_back( + IPEndPoint(parsed->rdata<net::AAAARecordRdata>()->address(), 0)); + break; + default: + NOTREACHED(); + } + } + + // If we don't have a saved async_transaction, it means OnComplete was + // invoked inline in MDnsTransaction::Start. Callbacks will need to be + // invoked via post. + task_->CheckCompletion(!async_transaction_); + } + + const HostResolver::DnsQueryType query_type_; + + // ERR_IO_PENDING until transaction completes (or is cancelled). + int result_; + + // Not saved until MDnsTransaction::Start completes to differentiate inline + // completion. + std::unique_ptr<MDnsTransaction> async_transaction_; + + // Back pointer. Expected to destroy |this| before destroying itself. + HostResolverMdnsTask* const task_; +}; + +HostResolverMdnsTask::HostResolverMdnsTask( + MDnsClient* mdns_client, + const std::string& hostname, + const std::vector<HostResolver::DnsQueryType>& query_types) + : mdns_client_(mdns_client), hostname_(hostname), weak_ptr_factory_(this) { + DCHECK(!query_types.empty()); + for (HostResolver::DnsQueryType query_type : query_types) { + transactions_.emplace_back(query_type, this); + } +} + +HostResolverMdnsTask::~HostResolverMdnsTask() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + transactions_.clear(); +} + +void HostResolverMdnsTask::Start(CompletionOnceCallback completion_callback) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + DCHECK(!completion_callback_); + + completion_callback_ = std::move(completion_callback); + + for (auto& transaction : transactions_) { + // Only start transaction if it is not already marked done. A transaction + // could be marked done before starting if it is preemptively canceled by + // a previously started transaction finishing with an error. + if (!transaction.IsDone()) + transaction.Start(); + } +} + +void HostResolverMdnsTask::CheckCompletion(bool post_needed) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + // Finish immediately if any transactions completed with an error. + auto found_error = + std::find_if(transactions_.begin(), transactions_.end(), + [](const Transaction& t) { return t.IsError(); }); + if (found_error != transactions_.end()) { + CompleteWithResult(found_error->result(), post_needed); + return; + } + + if (std::all_of(transactions_.begin(), transactions_.end(), + [](const Transaction& t) { return t.IsDone(); })) { + // Task is overall successful if any of the transactions found results. + int result = result_addresses_.empty() ? ERR_NAME_NOT_RESOLVED : OK; + + CompleteWithResult(result, post_needed); + return; + } +} + +void HostResolverMdnsTask::CompleteWithResult(int result, bool post_needed) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + // Cancel any incomplete async transactions. + for (auto& transaction : transactions_) { + if (!transaction.IsDone()) + transaction.Cancel(); + } + + if (post_needed) { + base::SequencedTaskRunnerHandle::Get()->PostTask( + FROM_HERE, + base::BindOnce( + [](base::WeakPtr<HostResolverMdnsTask> task, int result) { + if (task) + std::move(task->completion_callback_).Run(result); + }, + weak_ptr_factory_.GetWeakPtr(), result)); + } else { + std::move(completion_callback_).Run(result); + } +} + +} // namespace net diff --git a/chromium/net/dns/host_resolver_mdns_task.h b/chromium/net/dns/host_resolver_mdns_task.h new file mode 100644 index 00000000000..eebf5ffcf7d --- /dev/null +++ b/chromium/net/dns/host_resolver_mdns_task.h @@ -0,0 +1,66 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_HOST_RESOLVER_MDNS_TASK_H_ +#define NET_DNS_HOST_RESOLVER_MDNS_TASK_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "base/containers/unique_ptr_adapters.h" +#include "base/macros.h" +#include "base/memory/weak_ptr.h" +#include "base/sequence_checker.h" +#include "net/base/completion_once_callback.h" +#include "net/dns/host_resolver.h" +#include "net/dns/mdns_client.h" + +namespace net { + +// Representation of a single HostResolverImpl::Job task to resolve the hostname +// using multicast DNS transactions. Destruction cancels the task and prevents +// any callbacks from being invoked. +class HostResolverMdnsTask { + public: + // |mdns_client| must outlive |this|. + HostResolverMdnsTask( + MDnsClient* mdns_client, + const std::string& hostname, + const std::vector<HostResolver::DnsQueryType>& query_types); + ~HostResolverMdnsTask(); + + // Starts the task. |completion_callback| will be called asynchronously with + // results. + // + // Should only be called once. + void Start(CompletionOnceCallback completion_callback); + + const AddressList& result_addresses() { return result_addresses_; } + + private: + class Transaction; + + void CheckCompletion(bool post_needed); + void CompleteWithResult(int result, bool post_needed); + + MDnsClient* const mdns_client_; + + const std::string hostname_; + + AddressList result_addresses_; + std::vector<Transaction> transactions_; + + CompletionOnceCallback completion_callback_; + + SEQUENCE_CHECKER(sequence_checker_); + + base::WeakPtrFactory<HostResolverMdnsTask> weak_ptr_factory_; + + DISALLOW_COPY_AND_ASSIGN(HostResolverMdnsTask); +}; + +} // namespace net + +#endif // NET_DNS_HOST_RESOLVER_MDNS_TASK_H_ diff --git a/chromium/net/dns/host_resolver_source.h b/chromium/net/dns/host_resolver_source.h index 4985dfc1d9c..d873c074173 100644 --- a/chromium/net/dns/host_resolver_source.h +++ b/chromium/net/dns/host_resolver_source.h @@ -5,6 +5,8 @@ #ifndef NET_DNS_HOST_RESOLVER_SOURCE_H_ #define NET_DNS_HOST_RESOLVER_SOURCE_H_ +namespace net { + // Enumeration to specify the allowed results source for HostResolver // requests. enum class HostResolverSource { @@ -19,7 +21,10 @@ enum class HostResolverSource { // Results will only come from DNS queries. DNS, - // TODO(crbug.com/846423): Add MDNS support. + // Results will only come from Multicast DNS queries. + MULTICAST_DNS, }; +} // namespace net + #endif // NET_DNS_HOST_RESOLVER_SOURCE_H_ diff --git a/chromium/net/dns/mapped_host_resolver.cc b/chromium/net/dns/mapped_host_resolver.cc index 8a237dda0d8..24cfb5d8925 100644 --- a/chromium/net/dns/mapped_host_resolver.cc +++ b/chromium/net/dns/mapped_host_resolver.cc @@ -114,6 +114,20 @@ bool MappedHostResolver::GetNoIPv6OnWifi() { return impl_->GetNoIPv6OnWifi(); } +void MappedHostResolver::SetDnsConfigOverrides( + const DnsConfigOverrides& overrides) { + impl_->SetDnsConfigOverrides(overrides); +} + +void MappedHostResolver::SetRequestContext(URLRequestContext* request_context) { + impl_->SetRequestContext(request_context); +} + +const std::vector<DnsConfig::DnsOverHttpsServerConfig>* +MappedHostResolver::GetDnsOverHttpsServersForTesting() const { + return impl_->GetDnsOverHttpsServersForTesting(); +} + int MappedHostResolver::ApplyRules(RequestInfo* info) const { HostPortPair host_port(info->host_port_pair()); if (rules_.RewriteHost(&host_port)) { diff --git a/chromium/net/dns/mapped_host_resolver.h b/chromium/net/dns/mapped_host_resolver.h index 88edc985454..0d9633365b1 100644 --- a/chromium/net/dns/mapped_host_resolver.h +++ b/chromium/net/dns/mapped_host_resolver.h @@ -7,10 +7,12 @@ #include <memory> #include <string> +#include <vector> #include "net/base/completion_once_callback.h" #include "net/base/host_mapping_rules.h" #include "net/base/net_export.h" +#include "net/dns/dns_config.h" #include "net/dns/host_resolver.h" namespace net { @@ -65,15 +67,17 @@ class NET_EXPORT MappedHostResolver : public HostResolver { HostCache::EntryStaleness* stale_info, const NetLogWithSource& source_net_log) override; void SetDnsClientEnabled(bool enabled) override; - HostCache* GetHostCache() override; bool HasCached(base::StringPiece hostname, HostCache::Entry::Source* source_out, HostCache::EntryStaleness* stale_out) const override; - std::unique_ptr<base::Value> GetDnsConfigAsValue() const override; void SetNoIPv6OnWifi(bool no_ipv6_on_wifi) override; bool GetNoIPv6OnWifi() override; + void SetDnsConfigOverrides(const DnsConfigOverrides& overrides) override; + void SetRequestContext(URLRequestContext* request_context) override; + const std::vector<DnsConfig::DnsOverHttpsServerConfig>* + GetDnsOverHttpsServersForTesting() const override; private: class AlwaysErrorRequestImpl; diff --git a/chromium/net/dns/mdns_cache.cc b/chromium/net/dns/mdns_cache.cc index 5fe78914407..5ab62222315 100644 --- a/chromium/net/dns/mdns_cache.cc +++ b/chromium/net/dns/mdns_cache.cc @@ -55,7 +55,7 @@ MDnsCache::MDnsCache() = default; MDnsCache::~MDnsCache() = default; const RecordParsed* MDnsCache::LookupKey(const Key& key) { - RecordMap::iterator found = mdns_cache_.find(key); + auto found = mdns_cache_.find(key); if (found != mdns_cache_.end()) { return found->second.get(); } @@ -101,8 +101,7 @@ void MDnsCache::CleanupRecords( // impunity. if (now < next_expiration_) return; - for (RecordMap::iterator i = mdns_cache_.begin(); - i != mdns_cache_.end(); ) { + for (auto i = mdns_cache_.begin(); i != mdns_cache_.end();) { base::Time expiration = GetEffectiveExpiration(i->second.get()); if (now >= expiration) { record_removed_callback.Run(i->second.get()); @@ -125,7 +124,7 @@ void MDnsCache::FindDnsRecords(unsigned type, DCHECK(results); results->clear(); - RecordMap::const_iterator i = mdns_cache_.lower_bound(Key(type, name, "")); + auto i = mdns_cache_.lower_bound(Key(type, name, "")); for (; i != mdns_cache_.end(); ++i) { if (i->first.name() != name || (type != 0 && i->first.type() != type)) { @@ -144,7 +143,7 @@ void MDnsCache::FindDnsRecords(unsigned type, std::unique_ptr<const RecordParsed> MDnsCache::RemoveRecord( const RecordParsed* record) { Key key = Key::CreateFor(record); - RecordMap::iterator found = mdns_cache_.find(key); + auto found = mdns_cache_.find(key); if (found != mdns_cache_.end() && found->second.get() == record) { std::unique_ptr<const RecordParsed> result = std::move(found->second); diff --git a/chromium/net/dns/mdns_client.cc b/chromium/net/dns/mdns_client.cc index 864acfda4f2..f69014096cc 100644 --- a/chromium/net/dns/mdns_client.cc +++ b/chromium/net/dns/mdns_client.cc @@ -45,6 +45,9 @@ int Bind(const IPEndPoint& multicast_addr, } // namespace +const base::TimeDelta MDnsTransaction::kTransactionTimeout = + base::TimeDelta::FromSeconds(3); + // static std::unique_ptr<MDnsSocketFactory> MDnsSocketFactory::CreateDefault() { return std::unique_ptr<MDnsSocketFactory>(new MDnsSocketFactoryImpl); diff --git a/chromium/net/dns/mdns_client.h b/chromium/net/dns/mdns_client.h index 6348b1a64d5..8a6213fe00b 100644 --- a/chromium/net/dns/mdns_client.h +++ b/chromium/net/dns/mdns_client.h @@ -12,6 +12,7 @@ #include <vector> #include "base/callback.h" +#include "base/time/time.h" #include "net/base/ip_endpoint.h" #include "net/base/net_export.h" #include "net/dns/dns_query.h" @@ -32,6 +33,8 @@ class RecordParsed; // time out after a reasonable number of seconds. class NET_EXPORT MDnsTransaction { public: + static const base::TimeDelta kTransactionTimeout; + // Used to signify what type of result the transaction has received. enum Result { // Passed whenever a record is found. diff --git a/chromium/net/dns/mdns_client_impl.cc b/chromium/net/dns/mdns_client_impl.cc index d505c1a3180..adc5e0f85d4 100644 --- a/chromium/net/dns/mdns_client_impl.cc +++ b/chromium/net/dns/mdns_client_impl.cc @@ -31,7 +31,6 @@ namespace net { namespace { -const unsigned MDnsTransactionTimeoutSeconds = 3; // The fractions of the record's original TTL after which an active listener // (one that had |SetActiveRefresh(true)| called) will send a query to refresh // its cache. This happens both at 85% of the original TTL and again at 95% of @@ -48,7 +47,7 @@ void MDnsSocketFactoryImpl::CreateSockets( DCHECK(interfaces[i].second == ADDRESS_FAMILY_IPV4 || interfaces[i].second == ADDRESS_FAMILY_IPV6); std::unique_ptr<DatagramServerSocket> socket(CreateAndBindMDnsSocket( - interfaces[i].second, interfaces[i].first, nullptr)); + interfaces[i].second, interfaces[i].first, net_log_)); if (socket) sockets->push_back(std::move(socket)); } @@ -271,8 +270,7 @@ void MDnsClientImpl::Core::HandlePacket(DnsResponse* response, update_keys.insert(std::make_pair(update_key, update)); } - for (std::map<MDnsCache::Key, MDnsCache::UpdateType>::iterator i = - update_keys.begin(); i != update_keys.end(); i++) { + for (auto i = update_keys.begin(); i != update_keys.end(); i++) { const RecordParsed* record = cache_.LookupKey(i->first); if (!record) continue; @@ -298,8 +296,7 @@ void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed* record) { cache_.FindDnsRecords(0, record->name(), &records_to_remove, clock_->Now()); - for (std::vector<const RecordParsed*>::iterator i = records_to_remove.begin(); - i != records_to_remove.end(); i++) { + for (auto i = records_to_remove.begin(); i != records_to_remove.end(); i++) { if ((*i)->type() == dns_protocol::kTypeNSEC) continue; if (!rdata->GetBit((*i)->type())) { @@ -311,8 +308,7 @@ void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed* record) { } // Alert all listeners waiting for the nonexistent RR types. - ListenerMap::iterator i = - listeners_.upper_bound(ListenerKey(record->name(), 0)); + auto i = listeners_.upper_bound(ListenerKey(record->name(), 0)); for (; i != listeners_.end() && i->first.first == record->name(); i++) { if (!rdata->GetBit(i->first.second)) { for (auto& observer : *i->second) @@ -330,7 +326,7 @@ void MDnsClientImpl::Core::AlertListeners( MDnsCache::UpdateType update_type, const ListenerKey& key, const RecordParsed* record) { - ListenerMap::iterator listener_map_iterator = listeners_.find(key); + auto listener_map_iterator = listeners_.find(key); if (listener_map_iterator == listeners_.end()) return; for (auto& observer : *listener_map_iterator->second) @@ -350,7 +346,7 @@ void MDnsClientImpl::Core::AddListener( void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl* listener) { ListenerKey key(listener->GetName(), listener->GetType()); - ListenerMap::iterator observer_list_iterator = listeners_.find(key); + auto observer_list_iterator = listeners_.find(key); DCHECK(observer_list_iterator != listeners_.end()); DCHECK(observer_list_iterator->second->HasObserver(listener)); @@ -368,7 +364,7 @@ void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl* listener) { } void MDnsClientImpl::Core::CleanupObserverList(const ListenerKey& key) { - ListenerMap::iterator found = listeners_.find(key); + auto found = listeners_.find(key); if (found != listeners_.end() && !found->second->might_have_observers()) { listeners_.erase(found); } @@ -690,8 +686,7 @@ void MDnsTransactionImpl::ServeRecordsFromCache() { if (client_->core()) { client_->core()->QueryCache(rrtype_, name_, &records); - for (std::vector<const RecordParsed*>::iterator i = records.begin(); - i != records.end() && weak_this; ++i) { + for (auto i = records.begin(); i != records.end() && weak_this; ++i) { weak_this->TriggerCallback(MDnsTransaction::RESULT_RECORD, *i); } @@ -723,8 +718,7 @@ bool MDnsTransactionImpl::QueryAndListen() { timeout_.Reset(base::Bind(&MDnsTransactionImpl::SignalTransactionOver, AsWeakPtr())); base::ThreadTaskRunnerHandle::Get()->PostDelayedTask( - FROM_HERE, timeout_.callback(), - base::TimeDelta::FromSeconds(MDnsTransactionTimeoutSeconds)); + FROM_HERE, timeout_.callback(), kTransactionTimeout); return true; } diff --git a/chromium/net/dns/mdns_client_impl.h b/chromium/net/dns/mdns_client_impl.h index d9cd728341e..7230855f003 100644 --- a/chromium/net/dns/mdns_client_impl.h +++ b/chromium/net/dns/mdns_client_impl.h @@ -34,15 +34,20 @@ class OneShotTimer; namespace net { +class NetLog; + class MDnsSocketFactoryImpl : public MDnsSocketFactory { public: - MDnsSocketFactoryImpl() {} + MDnsSocketFactoryImpl() : net_log_(nullptr) {} + explicit MDnsSocketFactoryImpl(NetLog* net_log) : net_log_(net_log) {} ~MDnsSocketFactoryImpl() override {} void CreateSockets( std::vector<std::unique_ptr<DatagramServerSocket>>* sockets) override; private: + NetLog* const net_log_; + DISALLOW_COPY_AND_ASSIGN(MDnsSocketFactoryImpl); }; diff --git a/chromium/net/dns/mock_host_resolver.cc b/chromium/net/dns/mock_host_resolver.cc index 9d33d4cc747..b5382ab1a69 100644 --- a/chromium/net/dns/mock_host_resolver.cc +++ b/chromium/net/dns/mock_host_resolver.cc @@ -19,6 +19,8 @@ #include "base/strings/string_util.h" #include "base/threading/platform_thread.h" #include "base/threading/thread_task_runner_handle.h" +#include "base/time/default_tick_clock.h" +#include "base/time/tick_clock.h" #include "net/base/ip_address.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" @@ -282,12 +284,13 @@ int MockHostResolverBase::ResolveStaleFromCache( next_request_id_++; int rv = ResolveFromIPLiteralOrCache( info.host_port_pair(), info.address_family(), info.host_resolver_flags(), - HostResolverSource::ANY, info.allow_cached_response(), addresses); + HostResolverSource::ANY, info.allow_cached_response(), addresses, + stale_info); return rv; } void MockHostResolverBase::DetachRequest(size_t id) { - RequestMap::iterator it = requests_.find(id); + auto it = requests_.find(id); CHECK(it != requests_.end()); requests_.erase(it); } @@ -309,7 +312,7 @@ bool MockHostResolverBase::HasCached( void MockHostResolverBase::ResolveAllPending() { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); DCHECK(ondemand_mode_); - for (RequestMap::iterator i = requests_.begin(); i != requests_.end(); ++i) { + for (auto i = requests_.begin(); i != requests_.end(); ++i) { base::ThreadTaskRunnerHandle::Get()->PostTask( FROM_HERE, base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), i->first)); @@ -323,10 +326,13 @@ MockHostResolverBase::MockHostResolverBase(bool use_caching) ondemand_mode_(false), next_request_id_(1), num_resolve_(0), - num_resolve_from_cache_(0) { + num_resolve_from_cache_(0), + tick_clock_(base::DefaultTickClock::GetInstance()) { rules_map_[HostResolverSource::ANY] = CreateCatchAllHostResolverProc(); rules_map_[HostResolverSource::SYSTEM] = CreateCatchAllHostResolverProc(); rules_map_[HostResolverSource::DNS] = CreateCatchAllHostResolverProc(); + rules_map_[HostResolverSource::MULTICAST_DNS] = + CreateCatchAllHostResolverProc(); if (use_caching) { cache_.reset(new HostCache(kMaxCacheEntries)); @@ -405,9 +411,9 @@ int MockHostResolverBase::ResolveFromIPLiteralOrCache( HostCache::Key key(host.host(), requested_address_family, flags, source); const HostCache::Entry* entry; if (stale_info) - entry = cache_->LookupStale(key, base::TimeTicks::Now(), stale_info); + entry = cache_->LookupStale(key, tick_clock_->NowTicks(), stale_info); else - entry = cache_->Lookup(key, base::TimeTicks::Now()); + entry = cache_->Lookup(key, tick_clock_->NowTicks()); if (entry) { rv = entry->error(); if (rv == OK) @@ -435,7 +441,7 @@ int MockHostResolverBase::ResolveProc(const HostPortPair& host, ttl = base::TimeDelta::FromSeconds(kCacheEntryTTLSeconds); cache_->Set(key, HostCache::Entry(rv, addr, HostCache::Entry::SOURCE_UNKNOWN), - base::TimeTicks::Now(), ttl); + tick_clock_->NowTicks(), ttl); } if (rv == OK) *addresses = AddressList::CopyWithPort(addr, host.port()); @@ -443,7 +449,7 @@ int MockHostResolverBase::ResolveProc(const HostPortPair& host, } void MockHostResolverBase::ResolveNow(size_t id) { - RequestMap::iterator it = requests_.find(id); + auto it = requests_.find(id); if (it == requests_.end()) return; // was canceled @@ -462,14 +468,13 @@ void MockHostResolverBase::ResolveNow(size_t id) { //----------------------------------------------------------------------------- -RuleBasedHostResolverProc::Rule::Rule( - ResolverType resolver_type, - const std::string& host_pattern, - AddressFamily address_family, - HostResolverFlags host_resolver_flags, - const std::string& replacement, - const std::string& canonical_name, - int latency_ms) +RuleBasedHostResolverProc::Rule::Rule(ResolverType resolver_type, + const std::string& host_pattern, + AddressFamily address_family, + HostResolverFlags host_resolver_flags, + const std::string& replacement, + const std::string& canonical_name, + int latency_ms) : resolver_type(resolver_type), host_pattern(host_pattern), address_family(address_family), @@ -495,14 +500,9 @@ void RuleBasedHostResolverProc::AddRuleForAddressFamily( const std::string& replacement) { DCHECK(!replacement.empty()); HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | - HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; - Rule rule(Rule::kResolverTypeSystem, - host_pattern, - address_family, - flags, - replacement, - std::string(), - 0); + HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; + Rule rule(Rule::kResolverTypeSystem, host_pattern, address_family, flags, + replacement, std::string(), 0); AddRuleInternal(rule); } @@ -526,7 +526,7 @@ void RuleBasedHostResolverProc::AddIPLiteralRule( IPAddress ip_address; DCHECK(!ip_address.AssignFromIPLiteral(host_pattern)); HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | - HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; + HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; if (!canonical_name.empty()) flags |= HOST_RESOLVER_CANONNAME; @@ -541,42 +541,27 @@ void RuleBasedHostResolverProc::AddRuleWithLatency( int latency_ms) { DCHECK(!replacement.empty()); HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | - HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; - Rule rule(Rule::kResolverTypeSystem, - host_pattern, - ADDRESS_FAMILY_UNSPECIFIED, - flags, - replacement, - std::string(), - latency_ms); + HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; + Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED, + flags, replacement, std::string(), latency_ms); AddRuleInternal(rule); } void RuleBasedHostResolverProc::AllowDirectLookup( const std::string& host_pattern) { HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | - HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; - Rule rule(Rule::kResolverTypeSystem, - host_pattern, - ADDRESS_FAMILY_UNSPECIFIED, - flags, - std::string(), - std::string(), - 0); + HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; + Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED, + flags, std::string(), std::string(), 0); AddRuleInternal(rule); } void RuleBasedHostResolverProc::AddSimulatedFailure( const std::string& host_pattern) { HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | - HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; - Rule rule(Rule::kResolverTypeFail, - host_pattern, - ADDRESS_FAMILY_UNSPECIFIED, - flags, - std::string(), - std::string(), - 0); + HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; + Rule rule(Rule::kResolverTypeFail, host_pattern, ADDRESS_FAMILY_UNSPECIFIED, + flags, std::string(), std::string(), 0); AddRuleInternal(rule); } @@ -637,10 +622,9 @@ int RuleBasedHostResolverProc::Resolve(const std::string& host, #if defined(OS_WIN) EnsureWinsockInit(); #endif - return SystemHostResolverCall(effective_host, - address_family, - host_resolver_flags, - addrlist, os_error); + return SystemHostResolverCall(effective_host, address_family, + host_resolver_flags, addrlist, + os_error); case Rule::kResolverTypeIPLiteral: { AddressList raw_addr_list; int result = ParseAddressList( @@ -667,8 +651,8 @@ int RuleBasedHostResolverProc::Resolve(const std::string& host, } } } - return ResolveUsingPrevious(host, address_family, - host_resolver_flags, addrlist, os_error); + return ResolveUsingPrevious(host, address_family, host_resolver_flags, + addrlist, os_error); } RuleBasedHostResolverProc::~RuleBasedHostResolverProc() = default; diff --git a/chromium/net/dns/mock_host_resolver.h b/chromium/net/dns/mock_host_resolver.h index 892025d03b8..9cabbb4f188 100644 --- a/chromium/net/dns/mock_host_resolver.h +++ b/chromium/net/dns/mock_host_resolver.h @@ -22,6 +22,10 @@ #include "net/dns/host_resolver_proc.h" #include "net/dns/host_resolver_source.h" +namespace base { +class TickClock; +} // namespace base + namespace net { class HostCache; @@ -124,6 +128,7 @@ class MockHostResolverBase bool HasCached(base::StringPiece hostname, HostCache::Entry::Source* source_out, HostCache::EntryStaleness* stale_out) const override; + void SetDnsConfigOverrides(const DnsConfigOverrides& overrides) override {} // Detach cancelled request. void DetachRequest(size_t id); @@ -153,6 +158,10 @@ class MockHostResolverBase return last_request_priority_; } + void set_tick_clock(const base::TickClock* tick_clock) { + tick_clock_ = tick_clock; + } + protected: explicit MockHostResolverBase(bool use_caching); @@ -194,6 +203,8 @@ class MockHostResolverBase size_t num_resolve_; size_t num_resolve_from_cache_; + const base::TickClock* tick_clock_; + THREAD_CHECKER(thread_checker_); DISALLOW_COPY_AND_ASSIGN(MockHostResolverBase); diff --git a/chromium/net/dns/mock_mdns_client.cc b/chromium/net/dns/mock_mdns_client.cc new file mode 100644 index 00000000000..bbce796546d --- /dev/null +++ b/chromium/net/dns/mock_mdns_client.cc @@ -0,0 +1,17 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/mock_mdns_client.h" + +namespace net { + +MockMDnsTransaction::MockMDnsTransaction() = default; + +MockMDnsTransaction::~MockMDnsTransaction() = default; + +MockMDnsClient::MockMDnsClient() = default; + +MockMDnsClient::~MockMDnsClient() = default; + +} // namespace net diff --git a/chromium/net/dns/mock_mdns_client.h b/chromium/net/dns/mock_mdns_client.h new file mode 100644 index 00000000000..0670caa7991 --- /dev/null +++ b/chromium/net/dns/mock_mdns_client.h @@ -0,0 +1,48 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_MOCK_MDNS_CLIENT_H_ +#define NET_DNS_MOCK_MDNS_CLIENT_H_ + +#include <memory> +#include <string> + +#include "net/dns/mdns_client.h" +#include "testing/gmock/include/gmock/gmock.h" + +namespace net { + +class MockMDnsTransaction : public MDnsTransaction { + public: + MockMDnsTransaction(); + ~MockMDnsTransaction(); + + MOCK_METHOD0(Start, bool()); + MOCK_CONST_METHOD0(GetName, const std::string&()); + MOCK_CONST_METHOD0(GetType, uint16_t()); +}; + +class MockMDnsClient : public MDnsClient { + public: + MockMDnsClient(); + ~MockMDnsClient(); + + MOCK_METHOD3(CreateListener, + std::unique_ptr<MDnsListener>(uint16_t, + const std::string&, + MDnsListener::Delegate*)); + MOCK_METHOD4( + CreateTransaction, + std::unique_ptr<MDnsTransaction>(uint16_t, + const std::string&, + int, + const MDnsTransaction::ResultCallback&)); + MOCK_METHOD1(StartListening, bool(MDnsSocketFactory*)); + MOCK_METHOD0(StopListening, void()); + MOCK_CONST_METHOD0(IsListening, bool()); +}; + +} // namespace net + +#endif // NET_DNS_MOCK_MDNS_CLIENT_H_ diff --git a/chromium/net/dns/record_rdata.cc b/chromium/net/dns/record_rdata.cc index c8f86048232..86991b2287e 100644 --- a/chromium/net/dns/record_rdata.cc +++ b/chromium/net/dns/record_rdata.cc @@ -16,6 +16,26 @@ static const size_t kSrvRecordMinimumSize = 6; RecordRdata::RecordRdata() = default; +bool RecordRdata::HasValidSize(const base::StringPiece& data, uint16_t type) { + switch (type) { + case dns_protocol::kTypeSRV: + return data.size() >= kSrvRecordMinimumSize; + case dns_protocol::kTypeA: + return data.size() == IPAddress::kIPv4AddressSize; + case dns_protocol::kTypeAAAA: + return data.size() == IPAddress::kIPv6AddressSize; + case dns_protocol::kTypeCNAME: + case dns_protocol::kTypePTR: + case dns_protocol::kTypeTXT: + case dns_protocol::kTypeNSEC: + case dns_protocol::kTypeOPT: + return true; + default: + VLOG(1) << "Unsupported RDATA type."; + return false; + } +} + SrvRecordRdata::SrvRecordRdata() : priority_(0), weight_(0), port_(0) { } @@ -25,7 +45,7 @@ SrvRecordRdata::~SrvRecordRdata() = default; std::unique_ptr<SrvRecordRdata> SrvRecordRdata::Create( const base::StringPiece& data, const DnsRecordParser& parser) { - if (data.size() < kSrvRecordMinimumSize) + if (!HasValidSize(data, kType)) return std::unique_ptr<SrvRecordRdata>(); std::unique_ptr<SrvRecordRdata> rdata(new SrvRecordRdata); @@ -64,7 +84,7 @@ ARecordRdata::~ARecordRdata() = default; std::unique_ptr<ARecordRdata> ARecordRdata::Create( const base::StringPiece& data, const DnsRecordParser& parser) { - if (data.size() != IPAddress::kIPv4AddressSize) + if (!HasValidSize(data, kType)) return std::unique_ptr<ARecordRdata>(); std::unique_ptr<ARecordRdata> rdata(new ARecordRdata); @@ -91,7 +111,7 @@ AAAARecordRdata::~AAAARecordRdata() = default; std::unique_ptr<AAAARecordRdata> AAAARecordRdata::Create( const base::StringPiece& data, const DnsRecordParser& parser) { - if (data.size() != IPAddress::kIPv6AddressSize) + if (!HasValidSize(data, kType)) return std::unique_ptr<AAAARecordRdata>(); std::unique_ptr<AAAARecordRdata> rdata(new AAAARecordRdata); diff --git a/chromium/net/dns/record_rdata.h b/chromium/net/dns/record_rdata.h index dab5bf0b194..4e886a1f028 100644 --- a/chromium/net/dns/record_rdata.h +++ b/chromium/net/dns/record_rdata.h @@ -30,6 +30,10 @@ class NET_EXPORT_PRIVATE RecordRdata { public: virtual ~RecordRdata() {} + // Return true if |data| represents RDATA in the wire format with a valid size + // for the give |type|. + static bool HasValidSize(const base::StringPiece& data, uint16_t type); + virtual bool IsEqual(const RecordRdata* other) const = 0; virtual uint16_t Type() const = 0; diff --git a/chromium/net/dns/serial_worker.cc b/chromium/net/dns/serial_worker.cc index f650605bf9e..815c0d88e0e 100644 --- a/chromium/net/dns/serial_worker.cc +++ b/chromium/net/dns/serial_worker.cc @@ -11,7 +11,7 @@ namespace net { -SerialWorker::SerialWorker() : state_(IDLE) {} +SerialWorker::SerialWorker() : state_(IDLE), weak_factory_(this) {} SerialWorker::~SerialWorker() = default; @@ -19,11 +19,16 @@ void SerialWorker::WorkNow() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); switch (state_) { case IDLE: + // We are posting weak pointer to OnWorkJobFinished to avoid leak when + // PostTaskWithTraitsAndReply fails to post task back to the original + // task runner. In this case the callback is not destroyed, and the + // weak reference allows SerialWorker instance to be deleted. base::PostTaskWithTraitsAndReply( FROM_HERE, {base::MayBlock(), base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN}, base::BindOnce(&SerialWorker::DoWork, this), - base::BindOnce(&SerialWorker::OnWorkJobFinished, this)); + base::BindOnce(&SerialWorker::OnWorkJobFinished, + weak_factory_.GetWeakPtr())); state_ = WORKING; return; case WORKING: diff --git a/chromium/net/dns/serial_worker.h b/chromium/net/dns/serial_worker.h index 6d2571f161a..dd2a4f8939e 100644 --- a/chromium/net/dns/serial_worker.h +++ b/chromium/net/dns/serial_worker.h @@ -10,6 +10,7 @@ #include "base/compiler_specific.h" #include "base/macros.h" #include "base/memory/ref_counted.h" +#include "base/memory/weak_ptr.h" #include "base/sequence_checker.h" #include "base/task/task_traits.h" #include "net/base/net_export.h" @@ -74,6 +75,8 @@ class NET_EXPORT_PRIVATE SerialWorker State state_; + base::WeakPtrFactory<SerialWorker> weak_factory_; + DISALLOW_COPY_AND_ASSIGN(SerialWorker); }; |