diff options
| -rw-r--r-- | resolv/ResolverController.cpp | 10 | ||||
| -rw-r--r-- | resolv/getaddrinfo.cpp | 6 | ||||
| -rw-r--r-- | resolv/libnetd_resolv_test.cpp | 66 | ||||
| -rw-r--r-- | resolv/res_cache.cpp | 94 | ||||
| -rw-r--r-- | resolv/res_init.cpp | 29 | ||||
| -rw-r--r-- | resolv/res_query.cpp | 11 | ||||
| -rw-r--r-- | resolv/resolv_cache.h | 8 | ||||
| -rw-r--r-- | resolv/resolv_private.h | 25 | ||||
| -rw-r--r-- | resolv/resolver_test.cpp | 86 |
9 files changed, 189 insertions, 146 deletions
diff --git a/resolv/ResolverController.cpp b/resolv/ResolverController.cpp index ceb76393..7347a44a 100644 --- a/resolv/ResolverController.cpp +++ b/resolv/ResolverController.cpp @@ -228,9 +228,7 @@ int ResolverController::setResolverConfiguration( server_ptrs.push_back(resolverParams.servers[i].c_str()); } - std::string domains_str = android::base::Join(resolverParams.domains, " "); - - // TODO: Change resolv_set_nameservers_for_net() to use ResolverParamsParcel directly. + // TODO: Change resolv_set_nameservers() to use ResolverParamsParcel directly. res_params res_params = {}; res_params.sample_validity = resolverParams.sampleValiditySeconds; res_params.success_threshold = resolverParams.successThreshold; @@ -240,10 +238,10 @@ int ResolverController::setResolverConfiguration( res_params.retry_count = resolverParams.retryCount; LOG(VERBOSE) << "setDnsServers netId = " << resolverParams.netId - << ", numservers = " << resolverParams.domains.size(); + << ", numservers = " << resolverParams.servers.size(); - return -resolv_set_nameservers_for_net(resolverParams.netId, server_ptrs.data(), - server_ptrs.size(), domains_str.c_str(), &res_params); + return -resolv_set_nameservers(resolverParams.netId, server_ptrs.data(), server_ptrs.size(), + resolverParams.domains, &res_params); } int ResolverController::getResolverInfo(int32_t netId, std::vector<std::string>* servers, diff --git a/resolv/getaddrinfo.cpp b/resolv/getaddrinfo.cpp index 592b5a3c..9f43ec8b 100644 --- a/resolv/getaddrinfo.cpp +++ b/resolv/getaddrinfo.cpp @@ -1679,7 +1679,7 @@ static int res_queryN(const char* name, res_target* target, res_state res, int* * is detected. Error code, if any, is left in *herrno. */ static int res_searchN(const char* name, res_target* target, res_state res, int* herrno) { - const char *cp, *const *domain; + const char* cp; HEADER* hp; u_int dots; int ret, saved_herrno; @@ -1722,8 +1722,8 @@ static int res_searchN(const char* name, res_target* target, res_state res, int* */ _resolv_populate_res_for_net(res); - for (domain = (const char* const*) res->dnsrch; *domain && !done; domain++) { - ret = res_querydomainN(name, *domain, target, res, herrno); + for (const auto& domain : res->search_domains) { + ret = res_querydomainN(name, domain.c_str(), target, res, herrno); if (ret > 0) return ret; /* diff --git a/resolv/libnetd_resolv_test.cpp b/resolv/libnetd_resolv_test.cpp index 29984eb7..d6fd80d8 100644 --- a/resolv/libnetd_resolv_test.cpp +++ b/resolv/libnetd_resolv_test.cpp @@ -368,8 +368,8 @@ TEST_F(ResolvGetAddrInfoTest, AlphabeticalHostname_NoData) { dns.addMapping(v4_host_name, ns_type::ns_t_a, "1.2.3.3"); ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); dns.clearQueries(); // Want AAAA answer but DNS server has A answer only. @@ -395,8 +395,8 @@ TEST_F(ResolvGetAddrInfoTest, AlphabeticalHostname) { dns.addMapping(host_name, ns_type::ns_t_aaaa, v6addr); ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); static const struct TestConfig { int ai_family; @@ -430,8 +430,8 @@ TEST_F(ResolvGetAddrInfoTest, IllegalHostname) { ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); // Illegal hostname is verified by res_hnok() in system/netd/resolv/res_comp.cpp. static constexpr char const* illegalHostnames[] = { @@ -497,8 +497,8 @@ TEST_F(ResolvGetAddrInfoTest, ServerResponseError) { dns.setResponseProbability(0.0); // always ignore requests and response preset rcode ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); addrinfo* result = nullptr; const addrinfo hints = {.ai_family = AF_UNSPEC}; @@ -519,8 +519,8 @@ TEST_F(ResolvGetAddrInfoTest, ServerTimeout) { dns.setResponseProbability(0.0); // always ignore requests and don't response ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); addrinfo* result = nullptr; const addrinfo hints = {.ai_family = AF_UNSPEC}; @@ -543,8 +543,8 @@ TEST_F(ResolvGetAddrInfoTest, CnamesNoIpAddress) { ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); static const struct TestConfig { const char* name; @@ -582,8 +582,8 @@ TEST_F(ResolvGetAddrInfoTest, CnamesBrokenChainByIllegalCname) { ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); static const struct TestConfig { const char* name; @@ -642,8 +642,8 @@ TEST_F(ResolvGetAddrInfoTest, CnamesInfiniteLoop) { ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); for (const auto& family : {AF_INET, AF_INET6, AF_UNSPEC}) { SCOPED_TRACE(StringPrintf("family: %d", family)); @@ -670,8 +670,8 @@ TEST_F(GetHostByNameForNetContextTest, AlphabeticalHostname) { dns.addMapping(host_name, ns_type::ns_t_aaaa, v6addr); ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); static const struct TestConfig { int ai_family; @@ -704,8 +704,8 @@ TEST_F(GetHostByNameForNetContextTest, IllegalHostname) { ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); // Illegal hostname is verified by res_hnok() in system/netd/resolv/res_comp.cpp. static constexpr char const* illegalHostnames[] = { @@ -749,8 +749,8 @@ TEST_F(GetHostByNameForNetContextTest, NoData) { dns.addMapping(v4_host_name, ns_type::ns_t_a, "1.2.3.3"); ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); dns.clearQueries(); // Want AAAA answer but DNS server has A answer only. @@ -793,8 +793,8 @@ TEST_F(GetHostByNameForNetContextTest, ServerResponseError) { dns.setResponseProbability(0.0); // always ignore requests and response preset rcode ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); hostent* hp = nullptr; NetworkDnsEventReported event; @@ -814,8 +814,8 @@ TEST_F(GetHostByNameForNetContextTest, ServerTimeout) { dns.setResponseProbability(0.0); // always ignore requests and don't response ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); hostent* hp = nullptr; NetworkDnsEventReported event; @@ -836,8 +836,8 @@ TEST_F(GetHostByNameForNetContextTest, CnamesNoIpAddress) { ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); static const struct TestConfig { const char* name; @@ -870,8 +870,8 @@ TEST_F(GetHostByNameForNetContextTest, CnamesBrokenChainByIllegalCname) { ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); static const struct TestConfig { const char* name; @@ -929,8 +929,8 @@ TEST_F(GetHostByNameForNetContextTest, CnamesInfiniteLoop) { ASSERT_TRUE(dns.startServer()); const char* servers[] = {listen_addr}; - ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers), - mDefaultSearchDomains, &mDefaultParams_Binder)); + ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers), + mDefaultSearchDomains, &mDefaultParams_Binder)); for (const auto& family : {AF_INET, AF_INET6}) { SCOPED_TRACE(StringPrintf("family: %d", family)); @@ -946,7 +946,7 @@ TEST_F(GetHostByNameForNetContextTest, CnamesInfiniteLoop) { // Note that local host file function, files_getaddrinfo(), of resolv_getaddrinfo() // is not tested because it only returns a boolean (success or failure) without any error number. -// TODO: Simplify the DNS server configuration, DNSResponder and resolv_set_nameservers_for_net, as +// TODO: Simplify the DNS server configuration, DNSResponder and resolv_set_nameservers, as // ResolverTest does. // TODO: Add test for resolv_getaddrinfo(). // - DNS response message parsing. diff --git a/resolv/res_cache.cpp b/resolv/res_cache.cpp index abf90110..befcbed8 100644 --- a/resolv/res_cache.cpp +++ b/resolv/res_cache.cpp @@ -35,7 +35,10 @@ #include <stdlib.h> #include <string.h> #include <time.h> +#include <algorithm> #include <mutex> +#include <set> +#include <vector> #include <arpa/inet.h> #include <arpa/nameser.h> @@ -46,6 +49,7 @@ #include <android-base/logging.h> #include <android-base/parseint.h> +#include <android-base/strings.h> #include <android-base/thread_annotations.h> #include <android/multinetwork.h> // ResNsendFlags @@ -1142,8 +1146,7 @@ struct resolv_cache_info { int revision_id; // # times the nameservers have been replaced res_params params; struct res_stats nsstats[MAXNS]; - char defdname[MAXDNSRCHPATH]; - int dnsrch_offset[MAXDNSRCH + 1]; // offsets into defdname + std::vector<std::string> search_domains; int wait_for_pending_req_timeout_count; }; @@ -1719,12 +1722,35 @@ static void resolv_set_experiment_params(res_params* params) { } } -int resolv_set_nameservers_for_net(unsigned netid, const char** servers, const int numservers, - const char* domains, const res_params* params) { - char* cp; - int* offset; - struct addrinfo* nsaddrinfo[MAXNS]; +int resolv_set_nameservers(unsigned netid, const char** servers, int numservers, + const char* domains, const res_params* params) { + return resolv_set_nameservers(netid, servers, numservers, android::base::Split(domains, " "), + params); +} + +namespace { +// Returns valid domains without duplicates which are limited to max size |MAXDNSRCH|. +std::vector<std::string> filter_domains(const std::vector<std::string>& domains) { + std::set<std::string> tmp_set; + std::vector<std::string> res; + + std::copy_if(domains.begin(), domains.end(), std::back_inserter(res), + [&tmp_set](const std::string& str) { + return !(str.size() > MAXDNSRCHPATH - 1) && (tmp_set.insert(str).second); + }); + if (res.size() > MAXDNSRCH) { + LOG(WARNING) << __func__ << ": valid domains=" << res.size() + << ", but MAXDNSRCH=" << MAXDNSRCH; + res.resize(MAXDNSRCH); + } + return res; +} + +} // namespace + +int resolv_set_nameservers(unsigned netid, const char** servers, int numservers, + const std::vector<std::string>& domains, const res_params* params) { if (numservers > MAXNS) { LOG(ERROR) << __func__ << ": numservers=" << numservers << ", MAXNS=" << MAXNS; return E2BIG; @@ -1732,6 +1758,8 @@ int resolv_set_nameservers_for_net(unsigned netid, const char** servers, const i // Parse the addresses before actually locking or changing any state, in case there is an error. // As a side effect this also reduces the time the lock is kept. + // TODO: find a better way to replace addrinfo*, something like std::vector<SafeAddrinfo> + addrinfo* nsaddrinfo[MAXNS]; char sbuf[NI_MAXSERV]; snprintf(sbuf, sizeof(sbuf), "%u", NAMESERVER_PORT); for (int i = 0; i < numservers; i++) { @@ -1795,31 +1823,9 @@ int resolv_set_nameservers_for_net(unsigned netid, const char** servers, const i } } - // Always update the search paths, since determining whether they actually changed is - // complex due to the zero-padding, and probably not worth the effort. Cache-flushing - // however is not necessary, since the stored cache entries do contain the domain, not - // just the host name. - strlcpy(cache_info->defdname, domains, sizeof(cache_info->defdname)); - if ((cp = strchr(cache_info->defdname, '\n')) != NULL) *cp = '\0'; - LOG(INFO) << __func__ << ": domains=\"" << cache_info->defdname << "\""; - - cp = cache_info->defdname; - offset = cache_info->dnsrch_offset; - while (offset < cache_info->dnsrch_offset + MAXDNSRCH) { - while (*cp == ' ' || *cp == '\t') /* skip leading white space */ - cp++; - if (*cp == '\0') /* stop if nothing more to do */ - break; - *offset++ = cp - cache_info->defdname; /* record this search domain */ - while (*cp) { /* zero-terminate it */ - if (*cp == ' ' || *cp == '\t') { - *cp++ = '\0'; - break; - } - cp++; - } - } - *offset = -1; /* cache_info->dnsrch_offset has MAXDNSRCH+1 items */ + // Always update the search paths. Cache-flushing however is not necessary, + // since the stored cache entries do contain the domain, not just the host name. + cache_info->search_domains = filter_domains(domains); return 0; } @@ -1899,15 +1905,7 @@ void _resolv_populate_res_for_net(res_state statp) { } } statp->nscount = nserv; - // now do search domains. Note that we cache the offsets as this code runs alot - // but the setting/offset-computer only runs when set/changed - // WARNING: Don't use str*cpy() here, this string contains zeroes. - memcpy(statp->defdname, info->defdname, sizeof(statp->defdname)); - char** pp = statp->dnsrch; - int* p = info->dnsrch_offset; - while (pp < statp->dnsrch + MAXDNSRCH && *p != -1) { - *pp++ = &statp->defdname[0] + *p++; - } + statp->search_domains = info->search_domains; } } @@ -1979,17 +1977,9 @@ int android_net_res_stats_get_info_for_net(unsigned netid, int* nscount, memcpy(&servers[i], info->nsaddrinfo[i]->ai_addr, info->nsaddrinfo[i]->ai_addrlen); stats[i] = info->nsstats[i]; } - for (i = 0; i < MAXDNSRCH; i++) { - const char* cur_domain = info->defdname + info->dnsrch_offset[i]; - // dnsrch_offset[i] can either be -1 or point to an empty string to indicate the end - // of the search offsets. Checking for < 0 is not strictly necessary, but safer. - // TODO: Pass in a search domain array instead of a string to - // resolv_set_nameservers_for_net() and make this double check unnecessary. - if (info->dnsrch_offset[i] < 0 || - ((size_t) info->dnsrch_offset[i]) >= sizeof(info->defdname) || !cur_domain[0]) { - break; - } - strlcpy(domains[i], cur_domain, MAXDNSRCHPATH); + + for (i = 0; i < static_cast<int>(info->search_domains.size()); i++) { + strlcpy(domains[i], info->search_domains[i].c_str(), MAXDNSRCHPATH); } *dcount = i; *params = info->params; diff --git a/resolv/res_init.cpp b/resolv/res_init.cpp index 04ac3aaa..6ce51403 100644 --- a/resolv/res_init.cpp +++ b/resolv/res_init.cpp @@ -125,11 +125,7 @@ int res_ninit(res_state statp) { /* This function has to be reachable by res_data.c but not publicly. */ int res_vinit(res_state statp, int preinit) { - char *cp, **pp; - char buf[BUFSIZ]; int nserv = 0; /* number of nameserver records read from file */ - int havesearch = 0; - int dots; sockaddr_union u[2]; if ((statp->options & RES_INIT) != 0U) res_ndestroy(statp); @@ -162,31 +158,6 @@ int res_vinit(res_state statp, int preinit) { statp->nsort = 0; res_setservers(statp, u, nserv); - if (statp->defdname[0] == 0 && gethostname(buf, sizeof(statp->defdname) - 1) == 0 && - (cp = strchr(buf, '.')) != NULL) - strcpy(statp->defdname, cp + 1); - - /* find components of local domain that might be searched */ - if (havesearch == 0) { - pp = statp->dnsrch; - *pp++ = statp->defdname; - *pp = NULL; - - dots = 0; - for (cp = statp->defdname; *cp; cp++) dots += (*cp == '.'); - - cp = statp->defdname; - while (pp < statp->dnsrch + MAXDFLSRCH) { - if (dots < LOCALDOMAINPARTS) break; - cp = strchr(cp, '.') + 1; /* we know there is one */ - *pp++ = cp; - dots--; - } - *pp = NULL; - LOG(DEBUG) << __func__ << ": dnsrch list:"; - for (pp = statp->dnsrch; *pp; pp++) LOG(DEBUG) << "\t" << *pp; - } - if (nserv > 0) { statp->nscount = nserv; statp->options |= RES_INIT; diff --git a/resolv/res_query.cpp b/resolv/res_query.cpp index 13696c48..37767725 100644 --- a/resolv/res_query.cpp +++ b/resolv/res_query.cpp @@ -207,7 +207,7 @@ int res_nsearch(res_state statp, const char* name, /* domain name */ int* herrno) /* legacy and extended h_errno NETD_RESOLV_H_ERRNO_EXT_* */ { - const char *cp, *const *domain; + const char* cp; HEADER* hp = (HEADER*) (void*) answer; u_int dots; int ret, saved_herrno; @@ -251,12 +251,10 @@ int res_nsearch(res_state statp, const char* name, /* domain name */ */ _resolv_populate_res_for_net(statp); - for (domain = (const char* const*) statp->dnsrch; *domain && !done; domain++) { + for (const auto& domain : statp->search_domains) { + if (domain == "." || domain == "") ++root_on_list; - if (domain[0][0] == '\0' || (domain[0][0] == '.' && domain[0][1] == '\0')) - root_on_list++; - - ret = res_nquerydomain(statp, name, *domain, cl, type, answer, anslen, herrno); + ret = res_nquerydomain(statp, name, domain.c_str(), cl, type, answer, anslen, herrno); if (ret > 0) return ret; /* @@ -295,7 +293,6 @@ int res_nsearch(res_state statp, const char* name, /* domain name */ /* anything else implies that we're done */ done++; } - } } diff --git a/resolv/resolv_cache.h b/resolv/resolv_cache.h index 246aa910..1a27e4b0 100644 --- a/resolv/resolv_cache.h +++ b/resolv/resolv_cache.h @@ -64,8 +64,12 @@ int resolv_cache_add(unsigned netid, const void* query, int querylen, const void void _resolv_cache_query_failed(unsigned netid, const void* query, int querylen, uint32_t flags); // Sets name servers for a given network. -int resolv_set_nameservers_for_net(unsigned netid, const char** servers, int numservers, - const char* domains, const res_params* params); +int resolv_set_nameservers(unsigned netid, const char** servers, int numservers, + const std::vector<std::string>& domains, const res_params* params); + +// TODO: remove it after updating all callers. +int resolv_set_nameservers(unsigned netid, const char** servers, int numservers, + const char* domains, const res_params* params); // Creates the cache associated with the given network. int resolv_create_cache_for_net(unsigned netid); diff --git a/resolv/resolv_private.h b/resolv/resolv_private.h index 45ae337b..1e10849b 100644 --- a/resolv/resolv_private.h +++ b/resolv/resolv_private.h @@ -59,6 +59,7 @@ #include <resolv.h> #include <time.h> #include <string> +#include <vector> #include "netd_resolv/params.h" #include "netd_resolv/resolv.h" @@ -76,9 +77,6 @@ /* * Global defines and variables for resolver stub. */ -#define MAXDFLSRCH 3 /* # default domain levels to try */ -#define LOCALDOMAINPARTS 2 /* min levels in name that is "local" */ - #define RES_TIMEOUT 5000 /* min. milliseconds between retries */ #define MAXRESOLVSORT 10 /* number of net to sort on */ #define RES_MAXNDOTS 15 /* should reflect bit field size */ @@ -88,17 +86,16 @@ struct res_state_ext; struct __res_state { - unsigned netid; /* NetId: cache key and socket mark */ - uid_t uid; /* uid of the app that sent the DNS lookup */ - u_long options; /* option flags - see below. */ - int nscount; /* number of name srvers */ - struct sockaddr_in nsaddr_list[MAXNS]; /* address of name server */ -#define nsaddr nsaddr_list[0] /* for backward compatibility */ - u_short id; /* current message id */ - char* dnsrch[MAXDNSRCH + 1]; /* components of domain to search */ - char defdname[256]; /* default domain (deprecated) */ - unsigned ndots : 4; /* threshold for initial abs. query */ - unsigned nsort : 4; /* number of elements in sort_list[] */ + unsigned netid; // NetId: cache key and socket mark + uid_t uid; // uid of the app that sent the DNS lookup + u_long options; // option flags - see below. + int nscount; // number of name srvers + struct sockaddr_in nsaddr_list[MAXNS]; // address of name server +#define nsaddr nsaddr_list[0] // for backward compatibility + u_short id; // current message id + std::vector<std::string> search_domains; // domains to search + unsigned ndots : 4; // threshold for initial abs. query + unsigned nsort : 4; // number of elements in sort_list[] char unused[3]; struct { struct in_addr addr; diff --git a/resolv/resolver_test.cpp b/resolv/resolver_test.cpp index 6ca5e1a5..f2765fee 100644 --- a/resolv/resolver_test.cpp +++ b/resolv/resolver_test.cpp @@ -968,6 +968,92 @@ TEST_F(ResolverTest, SearchPathChange) { EXPECT_EQ("2001:db8::1:13", ToString(result)); } +namespace { + +std::vector<std::string> getResolverDomains(android::net::IDnsResolver* dnsResolverService, + unsigned netId) { + std::vector<std::string> res_servers; + std::vector<std::string> res_domains; + std::vector<std::string> res_tls_servers; + res_params res_params; + std::vector<ResolverStats> res_stats; + int wait_for_pending_req_timeout_count; + GetResolverInfo(dnsResolverService, netId, &res_servers, &res_domains, &res_tls_servers, + &res_params, &res_stats, &wait_for_pending_req_timeout_count); + return res_domains; +} + +} // namespace + +TEST_F(ResolverTest, SearchPathPrune) { + constexpr size_t DUPLICATED_DOMAIN_NUM = 3; + constexpr char listen_addr[] = "127.0.0.13"; + constexpr char domian_name1[] = "domain13.org."; + constexpr char domian_name2[] = "domain14.org."; + constexpr char host_name1[] = "test13.domain13.org."; + constexpr char host_name2[] = "test14.domain14.org."; + std::vector<std::string> servers = {listen_addr}; + + std::vector<std::string> testDomains1; + std::vector<std::string> testDomains2; + // Domain length should be <= 255 + // Max number of domains in search path is 6 + for (size_t i = 0; i < MAXDNSRCH + 1; i++) { + // Fill up with invalid domain + testDomains1.push_back(std::string(300, i + '0')); + // Fill up with valid but duplicated domain + testDomains2.push_back(StringPrintf("domain%zu.org", i % DUPLICATED_DOMAIN_NUM)); + } + + // Add valid domain used for query. + testDomains1.push_back(domian_name1); + + // Add valid domain twice used for query. + testDomains2.push_back(domian_name2); + testDomains2.push_back(domian_name2); + + const std::vector<DnsRecord> records = { + {host_name1, ns_type::ns_t_aaaa, "2001:db8::13"}, + {host_name2, ns_type::ns_t_aaaa, "2001:db8::1:13"}, + }; + test::DNSResponder dns(listen_addr); + StartDns(dns, records); + ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers, testDomains1)); + + const addrinfo hints = {.ai_family = AF_INET6}; + ScopedAddrinfo result = safe_getaddrinfo("test13", nullptr, &hints); + + EXPECT_TRUE(result != nullptr); + + EXPECT_EQ(1U, dns.queries().size()); + EXPECT_EQ(1U, GetNumQueries(dns, host_name1)); + EXPECT_EQ("2001:db8::13", ToString(result)); + + const auto& res_domains1 = getResolverDomains(mDnsClient.resolvService(), TEST_NETID); + // Expect 1 valid domain, invalid domains are removed. + ASSERT_EQ(1U, res_domains1.size()); + EXPECT_EQ(domian_name1, res_domains1[0]); + + dns.clearQueries(); + + ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers, testDomains2)); + + result = safe_getaddrinfo("test14", nullptr, &hints); + EXPECT_TRUE(result != nullptr); + + // (3 domains * 2 retries) + 1 success query = 7 + EXPECT_EQ(7U, dns.queries().size()); + EXPECT_EQ(1U, GetNumQueries(dns, host_name2)); + EXPECT_EQ("2001:db8::1:13", ToString(result)); + + const auto& res_domains2 = getResolverDomains(mDnsClient.resolvService(), TEST_NETID); + // Expect 4 valid domain, duplicate domains are removed. + EXPECT_EQ(DUPLICATED_DOMAIN_NUM + 1U, res_domains2.size()); + EXPECT_THAT( + std::vector<std::string>({"domain0.org", "domain1.org", "domain2.org", domian_name2}), + testing::ElementsAreArray(res_domains2)); +} + static std::string base64Encode(const std::vector<uint8_t>& input) { size_t out_len; EXPECT_EQ(1, EVP_EncodedLength(&out_len, input.size())); |
