diff options
| author | Luke Huang <huangluke@google.com> | 2019-06-24 13:28:44 +0800 |
|---|---|---|
| committer | Luke Huang <huangluke@google.com> | 2019-07-03 15:38:40 +0800 |
| commit | 6898d5bb33dbbcc8bb263c79435c4f79db61e52c (patch) | |
| tree | ff8240819d3d869e142154d3b1bd3c6b0d03b400 | |
| parent | 481ddf09b6cead845f65b29322a83079c8c350c9 (diff) | |
Use std::vector to store domains of nameservers and minor change
1.
Drop the old C style used to store domains.
Previously, resolv is limited to use 6 search domains with total 255 length.
(including zero padding)
After this change, the length of each domain could exactly be at most 255. (rfc 1035)
Also, invalid or duplicate domains will be dropped.
2. rename resolv_set_nameservers_for_net to resolv_set_nameservers
Bug: 135506574
Test: cd system/netd && atest
Change-Id: I94129ea521522c817d087332a7b467f616cc4895
| -rw-r--r-- | resolv/ResolverController.cpp | 10 | ||||
| -rw-r--r-- | resolv/getaddrinfo.cpp | 6 | ||||
| -rw-r--r-- | resolv/libnetd_resolv_test.cpp | 66 | ||||
| -rw-r--r-- | resolv/res_cache.cpp | 94 | ||||
| -rw-r--r-- | resolv/res_init.cpp | 29 | ||||
| -rw-r--r-- | resolv/res_query.cpp | 11 | ||||
| -rw-r--r-- | resolv/resolv_cache.h | 8 | ||||
| -rw-r--r-- | resolv/resolv_private.h | 25 | ||||
| -rw-r--r-- | resolv/resolver_test.cpp | 86 |
9 files changed, 189 insertions, 146 deletions
diff --git a/resolv/ResolverController.cpp b/resolv/ResolverController.cpp index ceb76393..7347a44a 100644 --- a/resolv/ResolverController.cpp +++ b/resolv/ResolverController.cpp @@ -228,9 +228,7 @@ int ResolverController::setResolverConfiguration( server_ptrs.push_back(resolverParams.servers[i].c_str()); } - std::string domains_str = android::base::Join(resolverParams.domains, " "); - - // TODO: Change resolv_set_nameservers_for_net() to use ResolverParamsParcel directly. + // TODO: Change resolv_set_nameservers() to use ResolverParamsParcel directly. res_params res_params = {}; res_params.sample_validity = resolverParams.sampleValiditySeconds; res_params.success_threshold = resolverParams.successThreshold; @@ -240,10 +238,10 @@ int ResolverController::setResolverConfiguration( res_params.retry_count = resolverParams.retryCount; LOG(VERBOSE) << "setDnsServers netId = " << resolverParams.netId - << ", numservers = " << resolverParams.domains.size(); + << ", numservers = " << resolverParams.servers.size(); - return -resolv_set_nameservers_for_net(resolverParams.netId, server_ptrs.data(), - server_ptrs.size(), domains_str.c_str(), &res_params); + return -resolv_set_nameservers(resolverParams.netId, server_ptrs.data(), server_ptrs.size(), + resolverParams.domains, &res_params); } int ResolverController::getResolverInfo(int32_t netId, std::vector<std::string>* servers, diff --git a/resolv/getaddrinfo.cpp b/resolv/getaddrinfo.cpp index 592b5a3c..9f43ec8b 100644 --- a/resolv/getaddrinfo.cpp +++ b/resolv/getaddrinfo.cpp @@ -1679,7 +1679,7 @@ static int res_queryN(const char* name, res_target* target, res_state res, int* * is detected. Error code, if any, is left in *herrno. */ static int res_searchN(const char* name, res_target* target, res_state res, int* herrno) { - const char *cp, *const *domain; + const char* cp; HEADER* hp; u_int dots; int ret, saved_herrno; @@ -1722,8 +1722,8 @@ static int res_searchN(const char* name, res_target* target, res_state res, int* */ _resolv_populate_res_for_net(res); - for (domain = (const char* const*) res->dnsrch; *domain && !done; domain++) { - ret = res_querydomainN(name, *domain, target, res, herrno); + for (const auto& domain : res->search_domains) { + ret = res_querydomainN(name, domain.c_str(), target, res, herrno); if (ret > 0) return ret; /* diff --git a/resolv/libnetd_resolv_test.cpp b/resolv/libnetd_resolv_test.cpp index 29984eb7..d6fd80d8 100644 --- a/resolv/libnetd_resolv_test.cpp +++ b/resolv/libnetd_resolv_test.cpp @@ -368,8 +368,8 @@ TEST_F(ResolvGetAddrInfoTest, AlphabeticalHostname_NoData) { dns.addMapping(v4_host_name, ns_type::ns_t_a, "1.2.3.3"); ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); dns.clearQueries(); // Want AAAA answer but DNS server has A answer only. @@ -395,8 +395,8 @@ TEST_F(ResolvGetAddrInfoTest, AlphabeticalHostname) { dns.addMapping(host_name, ns_type::ns_t_aaaa, v6addr); ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); static const struct TestConfig { int ai_family; @@ -430,8 +430,8 @@ TEST_F(ResolvGetAddrInfoTest, IllegalHostname) { ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); // Illegal hostname is verified by res_hnok() in system/netd/resolv/res_comp.cpp. static constexpr char const* illegalHostnames[] = { @@ -497,8 +497,8 @@ TEST_F(ResolvGetAddrInfoTest, ServerResponseError) { dns.setResponseProbability(0.0); // always ignore requests and response preset rcode ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); addrinfo* result = nullptr; const addrinfo hints = {.ai_family = AF_UNSPEC}; @@ -519,8 +519,8 @@ TEST_F(ResolvGetAddrInfoTest, ServerTimeout) { dns.setResponseProbability(0.0); // always ignore requests and don't response ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); addrinfo* result = nullptr; const addrinfo hints = {.ai_family = AF_UNSPEC}; @@ -543,8 +543,8 @@ TEST_F(ResolvGetAddrInfoTest, CnamesNoIpAddress) { ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); static const struct TestConfig { const char* name; @@ -582,8 +582,8 @@ TEST_F(ResolvGetAddrInfoTest, CnamesBrokenChainByIllegalCname) { ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); static const struct TestConfig { const char* name; @@ -642,8 +642,8 @@ TEST_F(ResolvGetAddrInfoTest, CnamesInfiniteLoop) { ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); for (const auto& family : {AF_INET, AF_INET6, AF_UNSPEC}) { SCOPED_TRACE(StringPrintf("family: %d", family)); @@ -670,8 +670,8 @@ TEST_F(GetHostByNameForNetContextTest, AlphabeticalHostname) { dns.addMapping(host_name, ns_type::ns_t_aaaa, v6addr); ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); static const struct TestConfig { int ai_family; @@ -704,8 +704,8 @@ TEST_F(GetHostByNameForNetContextTest, IllegalHostname) { ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); // Illegal hostname is verified by res_hnok() in system/netd/resolv/res_comp.cpp. static constexpr char const* illegalHostnames[] = { @@ -749,8 +749,8 @@ TEST_F(GetHostByNameForNetContextTest, NoData) { dns.addMapping(v4_host_name, ns_type::ns_t_a, "1.2.3.3"); ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); dns.clearQueries(); // Want AAAA answer but DNS server has A answer only. @@ -793,8 +793,8 @@ TEST_F(GetHostByNameForNetContextTest, ServerResponseError) { dns.setResponseProbability(0.0); // always ignore requests and response preset rcode ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); hostent* hp = nullptr; NetworkDnsEventReported event; @@ -814,8 +814,8 @@ TEST_F(GetHostByNameForNetContextTest, ServerTimeout) { dns.setResponseProbability(0.0); // always ignore requests and don't response ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); hostent* hp = nullptr; NetworkDnsEventReported event; @@ -836,8 +836,8 @@ TEST_F(GetHostByNameForNetContextTest, CnamesNoIpAddress) { ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); static const struct TestConfig { const char* name; @@ -870,8 +870,8 @@ TEST_F(GetHostByNameForNetContextTest, CnamesBrokenChainByIllegalCname) { ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); static const struct TestConfig { const char* name; @@ -929,8 +929,8 @@ TEST_F(GetHostByNameForNetContextTest, CnamesInfiniteLoop) { ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); for (const auto& family : {AF_INET, AF_INET6}) { SCOPED_TRACE(StringPrintf("family: %d", family)); @@ -946,7 +946,7 @@ TEST_F(GetHostByNameForNetContextTest, CnamesInfiniteLoop) { // Note that local host file function, files_getaddrinfo(), of resolv_getaddrinfo() // is not tested because it only returns a boolean (success or failure) without any error number. -// TODO: Simplify the DNS server configuration, DNSResponder and resolv_set_nameservers_for_net, as +// TODO: Simplify the DNS server configuration, DNSResponder and resolv_set_nameservers, as // ResolverTest does. // TODO: Add test for resolv_getaddrinfo(). // - DNS response message parsing. diff --git a/resolv/res_cache.cpp b/resolv/res_cache.cpp index abf90110..befcbed8 100644 --- a/resolv/res_cache.cpp +++ b/resolv/res_cache.cpp @@ -35,7 +35,10 @@ #include <stdlib.h> #include <string.h> #include <time.h> +#include <algorithm> #include <mutex> +#include <set> +#include <vector> #include <arpa/inet.h> #include <arpa/nameser.h> @@ -46,6 +49,7 @@ #include <android-base/logging.h> #include <android-base/parseint.h> +#include <android-base/strings.h> #include <android-base/thread_annotations.h> #include <android/multinetwork.h> // ResNsendFlags @@ -1142,8 +1146,7 @@ struct resolv_cache_info { int revision_id; // # times the nameservers have been replaced res_params params; struct res_stats nsstats[MAXNS]; - char defdname[MAXDNSRCHPATH]; - int dnsrch_offset[MAXDNSRCH + 1]; // offsets into defdname + std::vector<std::string> search_domains; int wait_for_pending_req_timeout_count; }; @@ -1719,12 +1722,35 @@ static void resolv_set_experiment_params(res_params* params) { } } -int resolv_set_nameservers_for_net(unsigned netid, const char** servers, const int numservers, - const char* domains, const res_params* params) { - char* cp; - int* offset; - struct addrinfo* nsaddrinfo[MAXNS]; +int resolv_set_nameservers(unsigned netid, const char** servers, int numservers, + const char* domains, const res_params* params) { + return resolv_set_nameservers(netid, servers, numservers, android::base::Split(domains, " "), + params); +} + +namespace { +// Returns valid domains without duplicates which are limited to max size |MAXDNSRCH|. +std::vector<std::string> filter_domains(const std::vector<std::string>& domains) { + std::set<std::string> tmp_set; + std::vector<std::string> res; + + std::copy_if(domains.begin(), domains.end(), std::back_inserter(res), + [&tmp_set](const std::string& str) { + return !(str.size() > MAXDNSRCHPATH - 1) && (tmp_set.insert(str).second); + }); + if (res.size() > MAXDNSRCH) { + LOG(WARNING) << __func__ << ": valid domains=" << res.size() + << ", but MAXDNSRCH=" << MAXDNSRCH; + res.resize(MAXDNSRCH); + } + return res; +} + +} // namespace + +int resolv_set_nameservers(unsigned netid, const char** servers, int numservers, + const std::vector<std::string>& domains, const res_params* params) { if (numservers > MAXNS) { LOG(ERROR) << __func__ << ": numservers=" << numservers << ", MAXNS=" << MAXNS; return E2BIG; @@ -1732,6 +1758,8 @@ int resolv_set_nameservers_for_net(unsigned netid, const char** servers, const i // Parse the addresses before actually locking or changing any state, in case there is an error. // As a side effect this also reduces the time the lock is kept. + // TODO: find a better way to replace addrinfo*, something like std::vector<SafeAddrinfo> + addrinfo* nsaddrinfo[MAXNS]; char sbuf[NI_MAXSERV]; snprintf(sbuf, sizeof(sbuf), "%u", NAMESERVER_PORT); for (int i = 0; i < numservers; i++) { @@ -1795,31 +1823,9 @@ int resolv_set_nameservers_for_net(unsigned netid, const char** servers, const i } } - // Always update the search paths, since determining whether they actually changed is - // complex due to the zero-padding, and probably not worth the effort. Cache-flushing - // however is not necessary, since the stored cache entries do contain the domain, not - // just the host name. - strlcpy(cache_info->defdname, domains, sizeof(cache_info->defdname)); - if ((cp = strchr(cache_info->defdname, '\n')) != NULL) *cp = '\0'; - LOG(INFO) << __func__ << ": domains=\"" << cache_info->defdname << "\""; - - cp = cache_info->defdname; - offset = cache_info->dnsrch_offset; - while (offset < cache_info->dnsrch_offset + MAXDNSRCH) { - while (*cp == ' ' || *cp == '\t') /* skip leading white space */ - cp++; - if (*cp == '\0') /* stop if nothing more to do */ - break; - *offset++ = cp - cache_info->defdname; /* record this search domain */ - while (*cp) { /* zero-terminate it */ - if (*cp == ' ' || *cp == '\t') { - *cp++ = '\0'; - break; - } - cp++; - } - } - *offset = -1; /* cache_info->dnsrch_offset has MAXDNSRCH+1 items */ + // Always update the search paths. Cache-flushing however is not necessary, + // since the stored cache entries do contain the domain, not just the host name. + cache_info->search_domains = filter_domains(domains); return 0; } @@ -1899,15 +1905,7 @@ void _resolv_populate_res_for_net(res_state statp) { } } statp->nscount = nserv; - // now do search domains. Note that we cache the offsets as this code runs alot - // but the setting/offset-computer only runs when set/changed - // WARNING: Don't use str*cpy() here, this string contains zeroes. - memcpy(statp->defdname, info->defdname, sizeof(statp->defdname)); - char** pp = statp->dnsrch; - int* p = info->dnsrch_offset; - while (pp < statp->dnsrch + MAXDNSRCH && *p != -1) { - *pp++ = &statp->defdname[0] + *p++; - } + statp->search_domains = info->search_domains; } } @@ -1979,17 +1977,9 @@ int android_net_res_stats_get_info_for_net(unsigned netid, int* nscount, memcpy(&servers[i], info->nsaddrinfo[i]->ai_addr, info->nsaddrinfo[i]->ai_addrlen); stats[i] = info->nsstats[i]; } - for (i = 0; i < MAXDNSRCH; i++) { - const char* cur_domain = info->defdname + info->dnsrch_offset[i]; - // dnsrch_offset[i] can either be -1 or point to an empty string to indicate the end - // of the search offsets. Checking for < 0 is not strictly necessary, but safer. - // TODO: Pass in a search domain array instead of a string to - // resolv_set_nameservers_for_net() and make this double check unnecessary. - if (info->dnsrch_offset[i] < 0 || - ((size_t) info->dnsrch_offset[i]) >= sizeof(info->defdname) || !cur_domain[0]) { - break; - } - strlcpy(domains[i], cur_domain, MAXDNSRCHPATH); + + for (i = 0; i < static_cast<int>(info->search_domains.size()); i++) { + strlcpy(domains[i], info->search_domains[i].c_str(), MAXDNSRCHPATH); } *dcount = i; *params = info->params; diff --git a/resolv/res_init.cpp b/resolv/res_init.cpp index 04ac3aaa..6ce51403 100644 --- a/resolv/res_init.cpp +++ b/resolv/res_init.cpp @@ -125,11 +125,7 @@ int res_ninit(res_state statp) { /* This function has to be reachable by res_data.c but not publicly. */ int res_vinit(res_state statp, int preinit) { - char *cp, **pp; - char buf[BUFSIZ]; int nserv = 0; /* number of nameserver records read from file */ - int havesearch = 0; - int dots; sockaddr_union u[2]; if ((statp->options & RES_INIT) != 0U) res_ndestroy(statp); @@ -162,31 +158,6 @@ int res_vinit(res_state statp, int preinit) { statp->nsort = 0; res_setservers(statp, u, nserv); - if (statp->defdname[0] == 0 && gethostname(buf, sizeof(statp->defdname) - 1) == 0 && - (cp = strchr(buf, '.')) != NULL) - strcpy(statp->defdname, cp + 1); - - /* find components of local domain that might be searched */ - if (havesearch == 0) { - pp = statp->dnsrch; - *pp++ = statp->defdname; - *pp = NULL; - - dots = 0; - for (cp = statp->defdname; *cp; cp++) dots += (*cp == '.'); - - cp = statp->defdname; - while (pp < statp->dnsrch + MAXDFLSRCH) { - if (dots < LOCALDOMAINPARTS) break; - cp = strchr(cp, '.') + 1; /* we know there is one */ - *pp++ = cp; - dots--; - } - *pp = NULL; - LOG(DEBUG) << __func__ << ": dnsrch list:"; - for (pp = statp->dnsrch; *pp; pp++) LOG(DEBUG) << "\t" << *pp; - } - if (nserv > 0) { statp->nscount = nserv; statp->options |= RES_INIT; diff --git a/resolv/res_query.cpp b/resolv/res_query.cpp index 13696c48..37767725 100644 --- a/resolv/res_query.cpp +++ b/resolv/res_query.cpp @@ -207,7 +207,7 @@ int res_nsearch(res_state statp, const char* name, /* domain name */ int* herrno) /* legacy and extended h_errno NETD_RESOLV_H_ERRNO_EXT_* */ { - const char *cp, *const *domain; + const char* cp; HEADER* hp = (HEADER*) (void*) answer; u_int dots; int ret, saved_herrno; @@ -251,12 +251,10 @@ int res_nsearch(res_state statp, const char* name, /* domain name */ */ _resolv_populate_res_for_net(statp); - for (domain = (const char* const*) statp->dnsrch; *domain && !done; domain++) { + for (const auto& domain : statp->search_domains) { + if (domain == "." || domain == "") ++root_on_list; - if (domain[0][0] == '\0' || (domain[0][0] == '.' && domain[0][1] == '\0')) - root_on_list++; - - ret = res_nquerydomain(statp, name, *domain, cl, type, answer, anslen, herrno); + ret = res_nquerydomain(statp, name, domain.c_str(), cl, type, answer, anslen, herrno); if (ret > 0) return ret; /* @@ -295,7 +293,6 @@ int res_nsearch(res_state statp, const char* name, /* domain name */ /* anything else implies that we're done */ done++; } - } } diff --git a/resolv/resolv_cache.h b/resolv/resolv_cache.h index 246aa910..1a27e4b0 100644 --- a/resolv/resolv_cache.h +++ b/resolv/resolv_cache.h @@ -64,8 +64,12 @@ int resolv_cache_add(unsigned netid, const void* query, int querylen, const void void _resolv_cache_query_failed(unsigned netid, const void* query, int querylen, uint32_t flags); // Sets name servers for a given network. -int resolv_set_nameservers_for_net(unsigned netid, const char** servers, int numservers, - const char* domains, const res_params* params); +int resolv_set_nameservers(unsigned netid, const char** servers, int numservers, + const std::vector<std::string>& domains, const res_params* params); + +// TODO: remove it after updating all callers. +int resolv_set_nameservers(unsigned netid, const char** servers, int numservers, + const char* domains, const res_params* params); // Creates the cache associated with the given network. int resolv_create_cache_for_net(unsigned netid); diff --git a/resolv/resolv_private.h b/resolv/resolv_private.h index 45ae337b..1e10849b 100644 --- a/resolv/resolv_private.h +++ b/resolv/resolv_private.h @@ -59,6 +59,7 @@ #include <resolv.h> #include <time.h> #include <string> +#include <vector> #include "netd_resolv/params.h" #include "netd_resolv/resolv.h" @@ -76,9 +77,6 @@ /* * Global defines and variables for resolver stub. */ -#define MAXDFLSRCH 3 /* # default domain levels to try */ -#define LOCALDOMAINPARTS 2 /* min levels in name that is "local" */ - #define RES_TIMEOUT 5000 /* min. milliseconds between retries */ #define MAXRESOLVSORT 10 /* number of net to sort on */ #define RES_MAXNDOTS 15 /* should reflect bit field size */ @@ -88,17 +86,16 @@ struct res_state_ext; struct __res_state { - unsigned netid; /* NetId: cache key and socket mark */ - uid_t uid; /* uid of the app that sent the DNS lookup */ - u_long options; /* option flags - see below. */ - int nscount; /* number of name srvers */ - struct sockaddr_in nsaddr_list[MAXNS]; /* address of name server */ -#define nsaddr nsaddr_list[0] /* for backward compatibility */ - u_short id; /* current message id */ - char* dnsrch[MAXDNSRCH + 1]; /* components of domain to search */ - char defdname[256]; /* default domain (deprecated) */ - unsigned ndots : 4; /* threshold for initial abs. query */ - unsigned nsort : 4; /* number of elements in sort_list[] */ + unsigned netid; // NetId: cache key and socket mark + uid_t uid; // uid of the app that sent the DNS lookup + u_long options; // option flags - see below. + int nscount; // number of name srvers + struct sockaddr_in nsaddr_list[MAXNS]; // address of name server +#define nsaddr nsaddr_list[0] // for backward compatibility + u_short id; // current message id + std::vector<std::string> search_domains; // domains to search + unsigned ndots : 4; // threshold for initial abs. query + unsigned nsort : 4; // number of elements in sort_list[] char unused[3]; struct { struct in_addr addr; diff --git a/resolv/resolver_test.cpp b/resolv/resolver_test.cpp index 6ca5e1a5..f2765fee 100644 --- a/resolv/resolver_test.cpp +++ b/resolv/resolver_test.cpp @@ -968,6 +968,92 @@ TEST_F(ResolverTest, SearchPathChange) { EXPECT_EQ("2001:db8::1:13", ToString(result)); } +namespace { + +std::vector<std::string> getResolverDomains(android::net::IDnsResolver* dnsResolverService, + unsigned netId) { + std::vector<std::string> res_servers; + std::vector<std::string> res_domains; + std::vector<std::string> res_tls_servers; + res_params res_params; + std::vector<ResolverStats> res_stats; + int wait_for_pending_req_timeout_count; + GetResolverInfo(dnsResolverService, netId, &res_servers, &res_domains, &res_tls_servers, + &res_params, &res_stats, &wait_for_pending_req_timeout_count); + return res_domains; +} + +} // namespace + +TEST_F(ResolverTest, SearchPathPrune) { + constexpr size_t DUPLICATED_DOMAIN_NUM = 3; + constexpr char listen_addr[] = "127.0.0.13"; + constexpr char domian_name1[] = "domain13.org."; + constexpr char domian_name2[] = "domain14.org."; + constexpr char host_name1[] = "test13.domain13.org."; + constexpr char host_name2[] = "test14.domain14.org."; + std::vector<std::string> servers = {listen_addr}; + + std::vector<std::string> testDomains1; + std::vector<std::string> testDomains2; + // Domain length should be <= 255 + // Max number of domains in search path is 6 + for (size_t i = 0; i < MAXDNSRCH + 1; i++) { + // Fill up with invalid domain + testDomains1.push_back(std::string(300, i + '0')); + // Fill up with valid but duplicated domain + testDomains2.push_back(StringPrintf("domain%zu.org", i % DUPLICATED_DOMAIN_NUM)); + } + + // Add valid domain used for query. + testDomains1.push_back(domian_name1); + + // Add valid domain twice used for query. + testDomains2.push_back(domian_name2); + testDomains2.push_back(domian_name2); + + const std::vector<DnsRecord> records = { + {host_name1, ns_type::ns_t_aaaa, "2001:db8::13"}, + {host_name2, ns_type::ns_t_aaaa, "2001:db8::1:13"}, + }; + test::DNSResponder dns(listen_addr); + StartDns(dns, records); + ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers, testDomains1)); + + const addrinfo hints = {.ai_family = AF_INET6}; + ScopedAddrinfo result = safe_getaddrinfo("test13", nullptr, &hints); + + EXPECT_TRUE(result != nullptr); + + EXPECT_EQ(1U, dns.queries().size()); + EXPECT_EQ(1U, GetNumQueries(dns, host_name1)); + EXPECT_EQ("2001:db8::13", ToString(result)); + + const auto& res_domains1 = getResolverDomains(mDnsClient.resolvService(), TEST_NETID); + // Expect 1 valid domain, invalid domains are removed. + ASSERT_EQ(1U, res_domains1.size()); + EXPECT_EQ(domian_name1, res_domains1[0]); + + dns.clearQueries(); + + ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers, testDomains2)); + + result = safe_getaddrinfo("test14", nullptr, &hints); + EXPECT_TRUE(result != nullptr); + + // (3 domains * 2 retries) + 1 success query = 7 + EXPECT_EQ(7U, dns.queries().size()); + EXPECT_EQ(1U, GetNumQueries(dns, host_name2)); + EXPECT_EQ("2001:db8::1:13", ToString(result)); + + const auto& res_domains2 = getResolverDomains(mDnsClient.resolvService(), TEST_NETID); + // Expect 4 valid domain, duplicate domains are removed. + EXPECT_EQ(DUPLICATED_DOMAIN_NUM + 1U, res_domains2.size()); + EXPECT_THAT( + std::vector<std::string>({"domain0.org", "domain1.org", "domain2.org", domian_name2}), + testing::ElementsAreArray(res_domains2)); +} + static std::string base64Encode(const std::vector<uint8_t>& input) { size_t out_len; EXPECT_EQ(1, EVP_EncodedLength(&out_len, input.size())); |
