summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--resolv/ResolverController.cpp10
-rw-r--r--resolv/getaddrinfo.cpp6
-rw-r--r--resolv/libnetd_resolv_test.cpp66
-rw-r--r--resolv/res_cache.cpp94
-rw-r--r--resolv/res_init.cpp29
-rw-r--r--resolv/res_query.cpp11
-rw-r--r--resolv/resolv_cache.h8
-rw-r--r--resolv/resolv_private.h25
-rw-r--r--resolv/resolver_test.cpp86
9 files changed, 189 insertions, 146 deletions
diff --git a/resolv/ResolverController.cpp b/resolv/ResolverController.cpp
index ceb76393..7347a44a 100644
--- a/resolv/ResolverController.cpp
+++ b/resolv/ResolverController.cpp
@@ -228,9 +228,7 @@ int ResolverController::setResolverConfiguration(
server_ptrs.push_back(resolverParams.servers[i].c_str());
}
- std::string domains_str = android::base::Join(resolverParams.domains, " ");
-
- // TODO: Change resolv_set_nameservers_for_net() to use ResolverParamsParcel directly.
+ // TODO: Change resolv_set_nameservers() to use ResolverParamsParcel directly.
res_params res_params = {};
res_params.sample_validity = resolverParams.sampleValiditySeconds;
res_params.success_threshold = resolverParams.successThreshold;
@@ -240,10 +238,10 @@ int ResolverController::setResolverConfiguration(
res_params.retry_count = resolverParams.retryCount;
LOG(VERBOSE) << "setDnsServers netId = " << resolverParams.netId
- << ", numservers = " << resolverParams.domains.size();
+ << ", numservers = " << resolverParams.servers.size();
- return -resolv_set_nameservers_for_net(resolverParams.netId, server_ptrs.data(),
- server_ptrs.size(), domains_str.c_str(), &res_params);
+ return -resolv_set_nameservers(resolverParams.netId, server_ptrs.data(), server_ptrs.size(),
+ resolverParams.domains, &res_params);
}
int ResolverController::getResolverInfo(int32_t netId, std::vector<std::string>* servers,
diff --git a/resolv/getaddrinfo.cpp b/resolv/getaddrinfo.cpp
index 592b5a3c..9f43ec8b 100644
--- a/resolv/getaddrinfo.cpp
+++ b/resolv/getaddrinfo.cpp
@@ -1679,7 +1679,7 @@ static int res_queryN(const char* name, res_target* target, res_state res, int*
* is detected. Error code, if any, is left in *herrno.
*/
static int res_searchN(const char* name, res_target* target, res_state res, int* herrno) {
- const char *cp, *const *domain;
+ const char* cp;
HEADER* hp;
u_int dots;
int ret, saved_herrno;
@@ -1722,8 +1722,8 @@ static int res_searchN(const char* name, res_target* target, res_state res, int*
*/
_resolv_populate_res_for_net(res);
- for (domain = (const char* const*) res->dnsrch; *domain && !done; domain++) {
- ret = res_querydomainN(name, *domain, target, res, herrno);
+ for (const auto& domain : res->search_domains) {
+ ret = res_querydomainN(name, domain.c_str(), target, res, herrno);
if (ret > 0) return ret;
/*
diff --git a/resolv/libnetd_resolv_test.cpp b/resolv/libnetd_resolv_test.cpp
index 29984eb7..d6fd80d8 100644
--- a/resolv/libnetd_resolv_test.cpp
+++ b/resolv/libnetd_resolv_test.cpp
@@ -368,8 +368,8 @@ TEST_F(ResolvGetAddrInfoTest, AlphabeticalHostname_NoData) {
dns.addMapping(v4_host_name, ns_type::ns_t_a, "1.2.3.3");
ASSERT_TRUE(dns.startServer());
const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
+ mDefaultSearchDomains, &mDefaultParams_Binder));
dns.clearQueries();
// Want AAAA answer but DNS server has A answer only.
@@ -395,8 +395,8 @@ TEST_F(ResolvGetAddrInfoTest, AlphabeticalHostname) {
dns.addMapping(host_name, ns_type::ns_t_aaaa, v6addr);
ASSERT_TRUE(dns.startServer());
const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
+ mDefaultSearchDomains, &mDefaultParams_Binder));
static const struct TestConfig {
int ai_family;
@@ -430,8 +430,8 @@ TEST_F(ResolvGetAddrInfoTest, IllegalHostname) {
ASSERT_TRUE(dns.startServer());
const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
+ mDefaultSearchDomains, &mDefaultParams_Binder));
// Illegal hostname is verified by res_hnok() in system/netd/resolv/res_comp.cpp.
static constexpr char const* illegalHostnames[] = {
@@ -497,8 +497,8 @@ TEST_F(ResolvGetAddrInfoTest, ServerResponseError) {
dns.setResponseProbability(0.0); // always ignore requests and response preset rcode
ASSERT_TRUE(dns.startServer());
const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
+ mDefaultSearchDomains, &mDefaultParams_Binder));
addrinfo* result = nullptr;
const addrinfo hints = {.ai_family = AF_UNSPEC};
@@ -519,8 +519,8 @@ TEST_F(ResolvGetAddrInfoTest, ServerTimeout) {
dns.setResponseProbability(0.0); // always ignore requests and don't response
ASSERT_TRUE(dns.startServer());
const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
+ mDefaultSearchDomains, &mDefaultParams_Binder));
addrinfo* result = nullptr;
const addrinfo hints = {.ai_family = AF_UNSPEC};
@@ -543,8 +543,8 @@ TEST_F(ResolvGetAddrInfoTest, CnamesNoIpAddress) {
ASSERT_TRUE(dns.startServer());
const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
+ mDefaultSearchDomains, &mDefaultParams_Binder));
static const struct TestConfig {
const char* name;
@@ -582,8 +582,8 @@ TEST_F(ResolvGetAddrInfoTest, CnamesBrokenChainByIllegalCname) {
ASSERT_TRUE(dns.startServer());
const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
+ mDefaultSearchDomains, &mDefaultParams_Binder));
static const struct TestConfig {
const char* name;
@@ -642,8 +642,8 @@ TEST_F(ResolvGetAddrInfoTest, CnamesInfiniteLoop) {
ASSERT_TRUE(dns.startServer());
const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
+ mDefaultSearchDomains, &mDefaultParams_Binder));
for (const auto& family : {AF_INET, AF_INET6, AF_UNSPEC}) {
SCOPED_TRACE(StringPrintf("family: %d", family));
@@ -670,8 +670,8 @@ TEST_F(GetHostByNameForNetContextTest, AlphabeticalHostname) {
dns.addMapping(host_name, ns_type::ns_t_aaaa, v6addr);
ASSERT_TRUE(dns.startServer());
const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
+ mDefaultSearchDomains, &mDefaultParams_Binder));
static const struct TestConfig {
int ai_family;
@@ -704,8 +704,8 @@ TEST_F(GetHostByNameForNetContextTest, IllegalHostname) {
ASSERT_TRUE(dns.startServer());
const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
+ mDefaultSearchDomains, &mDefaultParams_Binder));
// Illegal hostname is verified by res_hnok() in system/netd/resolv/res_comp.cpp.
static constexpr char const* illegalHostnames[] = {
@@ -749,8 +749,8 @@ TEST_F(GetHostByNameForNetContextTest, NoData) {
dns.addMapping(v4_host_name, ns_type::ns_t_a, "1.2.3.3");
ASSERT_TRUE(dns.startServer());
const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
+ mDefaultSearchDomains, &mDefaultParams_Binder));
dns.clearQueries();
// Want AAAA answer but DNS server has A answer only.
@@ -793,8 +793,8 @@ TEST_F(GetHostByNameForNetContextTest, ServerResponseError) {
dns.setResponseProbability(0.0); // always ignore requests and response preset rcode
ASSERT_TRUE(dns.startServer());
const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
+ mDefaultSearchDomains, &mDefaultParams_Binder));
hostent* hp = nullptr;
NetworkDnsEventReported event;
@@ -814,8 +814,8 @@ TEST_F(GetHostByNameForNetContextTest, ServerTimeout) {
dns.setResponseProbability(0.0); // always ignore requests and don't response
ASSERT_TRUE(dns.startServer());
const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
+ mDefaultSearchDomains, &mDefaultParams_Binder));
hostent* hp = nullptr;
NetworkDnsEventReported event;
@@ -836,8 +836,8 @@ TEST_F(GetHostByNameForNetContextTest, CnamesNoIpAddress) {
ASSERT_TRUE(dns.startServer());
const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
+ mDefaultSearchDomains, &mDefaultParams_Binder));
static const struct TestConfig {
const char* name;
@@ -870,8 +870,8 @@ TEST_F(GetHostByNameForNetContextTest, CnamesBrokenChainByIllegalCname) {
ASSERT_TRUE(dns.startServer());
const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
+ mDefaultSearchDomains, &mDefaultParams_Binder));
static const struct TestConfig {
const char* name;
@@ -929,8 +929,8 @@ TEST_F(GetHostByNameForNetContextTest, CnamesInfiniteLoop) {
ASSERT_TRUE(dns.startServer());
const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers_for_net(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
+ mDefaultSearchDomains, &mDefaultParams_Binder));
for (const auto& family : {AF_INET, AF_INET6}) {
SCOPED_TRACE(StringPrintf("family: %d", family));
@@ -946,7 +946,7 @@ TEST_F(GetHostByNameForNetContextTest, CnamesInfiniteLoop) {
// Note that local host file function, files_getaddrinfo(), of resolv_getaddrinfo()
// is not tested because it only returns a boolean (success or failure) without any error number.
-// TODO: Simplify the DNS server configuration, DNSResponder and resolv_set_nameservers_for_net, as
+// TODO: Simplify the DNS server configuration, DNSResponder and resolv_set_nameservers, as
// ResolverTest does.
// TODO: Add test for resolv_getaddrinfo().
// - DNS response message parsing.
diff --git a/resolv/res_cache.cpp b/resolv/res_cache.cpp
index abf90110..befcbed8 100644
--- a/resolv/res_cache.cpp
+++ b/resolv/res_cache.cpp
@@ -35,7 +35,10 @@
#include <stdlib.h>
#include <string.h>
#include <time.h>
+#include <algorithm>
#include <mutex>
+#include <set>
+#include <vector>
#include <arpa/inet.h>
#include <arpa/nameser.h>
@@ -46,6 +49,7 @@
#include <android-base/logging.h>
#include <android-base/parseint.h>
+#include <android-base/strings.h>
#include <android-base/thread_annotations.h>
#include <android/multinetwork.h> // ResNsendFlags
@@ -1142,8 +1146,7 @@ struct resolv_cache_info {
int revision_id; // # times the nameservers have been replaced
res_params params;
struct res_stats nsstats[MAXNS];
- char defdname[MAXDNSRCHPATH];
- int dnsrch_offset[MAXDNSRCH + 1]; // offsets into defdname
+ std::vector<std::string> search_domains;
int wait_for_pending_req_timeout_count;
};
@@ -1719,12 +1722,35 @@ static void resolv_set_experiment_params(res_params* params) {
}
}
-int resolv_set_nameservers_for_net(unsigned netid, const char** servers, const int numservers,
- const char* domains, const res_params* params) {
- char* cp;
- int* offset;
- struct addrinfo* nsaddrinfo[MAXNS];
+int resolv_set_nameservers(unsigned netid, const char** servers, int numservers,
+ const char* domains, const res_params* params) {
+ return resolv_set_nameservers(netid, servers, numservers, android::base::Split(domains, " "),
+ params);
+}
+
+namespace {
+// Returns valid domains without duplicates which are limited to max size |MAXDNSRCH|.
+std::vector<std::string> filter_domains(const std::vector<std::string>& domains) {
+ std::set<std::string> tmp_set;
+ std::vector<std::string> res;
+
+ std::copy_if(domains.begin(), domains.end(), std::back_inserter(res),
+ [&tmp_set](const std::string& str) {
+ return !(str.size() > MAXDNSRCHPATH - 1) && (tmp_set.insert(str).second);
+ });
+ if (res.size() > MAXDNSRCH) {
+ LOG(WARNING) << __func__ << ": valid domains=" << res.size()
+ << ", but MAXDNSRCH=" << MAXDNSRCH;
+ res.resize(MAXDNSRCH);
+ }
+ return res;
+}
+
+} // namespace
+
+int resolv_set_nameservers(unsigned netid, const char** servers, int numservers,
+ const std::vector<std::string>& domains, const res_params* params) {
if (numservers > MAXNS) {
LOG(ERROR) << __func__ << ": numservers=" << numservers << ", MAXNS=" << MAXNS;
return E2BIG;
@@ -1732,6 +1758,8 @@ int resolv_set_nameservers_for_net(unsigned netid, const char** servers, const i
// Parse the addresses before actually locking or changing any state, in case there is an error.
// As a side effect this also reduces the time the lock is kept.
+ // TODO: find a better way to replace addrinfo*, something like std::vector<SafeAddrinfo>
+ addrinfo* nsaddrinfo[MAXNS];
char sbuf[NI_MAXSERV];
snprintf(sbuf, sizeof(sbuf), "%u", NAMESERVER_PORT);
for (int i = 0; i < numservers; i++) {
@@ -1795,31 +1823,9 @@ int resolv_set_nameservers_for_net(unsigned netid, const char** servers, const i
}
}
- // Always update the search paths, since determining whether they actually changed is
- // complex due to the zero-padding, and probably not worth the effort. Cache-flushing
- // however is not necessary, since the stored cache entries do contain the domain, not
- // just the host name.
- strlcpy(cache_info->defdname, domains, sizeof(cache_info->defdname));
- if ((cp = strchr(cache_info->defdname, '\n')) != NULL) *cp = '\0';
- LOG(INFO) << __func__ << ": domains=\"" << cache_info->defdname << "\"";
-
- cp = cache_info->defdname;
- offset = cache_info->dnsrch_offset;
- while (offset < cache_info->dnsrch_offset + MAXDNSRCH) {
- while (*cp == ' ' || *cp == '\t') /* skip leading white space */
- cp++;
- if (*cp == '\0') /* stop if nothing more to do */
- break;
- *offset++ = cp - cache_info->defdname; /* record this search domain */
- while (*cp) { /* zero-terminate it */
- if (*cp == ' ' || *cp == '\t') {
- *cp++ = '\0';
- break;
- }
- cp++;
- }
- }
- *offset = -1; /* cache_info->dnsrch_offset has MAXDNSRCH+1 items */
+ // Always update the search paths. Cache-flushing however is not necessary,
+ // since the stored cache entries do contain the domain, not just the host name.
+ cache_info->search_domains = filter_domains(domains);
return 0;
}
@@ -1899,15 +1905,7 @@ void _resolv_populate_res_for_net(res_state statp) {
}
}
statp->nscount = nserv;
- // now do search domains. Note that we cache the offsets as this code runs alot
- // but the setting/offset-computer only runs when set/changed
- // WARNING: Don't use str*cpy() here, this string contains zeroes.
- memcpy(statp->defdname, info->defdname, sizeof(statp->defdname));
- char** pp = statp->dnsrch;
- int* p = info->dnsrch_offset;
- while (pp < statp->dnsrch + MAXDNSRCH && *p != -1) {
- *pp++ = &statp->defdname[0] + *p++;
- }
+ statp->search_domains = info->search_domains;
}
}
@@ -1979,17 +1977,9 @@ int android_net_res_stats_get_info_for_net(unsigned netid, int* nscount,
memcpy(&servers[i], info->nsaddrinfo[i]->ai_addr, info->nsaddrinfo[i]->ai_addrlen);
stats[i] = info->nsstats[i];
}
- for (i = 0; i < MAXDNSRCH; i++) {
- const char* cur_domain = info->defdname + info->dnsrch_offset[i];
- // dnsrch_offset[i] can either be -1 or point to an empty string to indicate the end
- // of the search offsets. Checking for < 0 is not strictly necessary, but safer.
- // TODO: Pass in a search domain array instead of a string to
- // resolv_set_nameservers_for_net() and make this double check unnecessary.
- if (info->dnsrch_offset[i] < 0 ||
- ((size_t) info->dnsrch_offset[i]) >= sizeof(info->defdname) || !cur_domain[0]) {
- break;
- }
- strlcpy(domains[i], cur_domain, MAXDNSRCHPATH);
+
+ for (i = 0; i < static_cast<int>(info->search_domains.size()); i++) {
+ strlcpy(domains[i], info->search_domains[i].c_str(), MAXDNSRCHPATH);
}
*dcount = i;
*params = info->params;
diff --git a/resolv/res_init.cpp b/resolv/res_init.cpp
index 04ac3aaa..6ce51403 100644
--- a/resolv/res_init.cpp
+++ b/resolv/res_init.cpp
@@ -125,11 +125,7 @@ int res_ninit(res_state statp) {
/* This function has to be reachable by res_data.c but not publicly. */
int res_vinit(res_state statp, int preinit) {
- char *cp, **pp;
- char buf[BUFSIZ];
int nserv = 0; /* number of nameserver records read from file */
- int havesearch = 0;
- int dots;
sockaddr_union u[2];
if ((statp->options & RES_INIT) != 0U) res_ndestroy(statp);
@@ -162,31 +158,6 @@ int res_vinit(res_state statp, int preinit) {
statp->nsort = 0;
res_setservers(statp, u, nserv);
- if (statp->defdname[0] == 0 && gethostname(buf, sizeof(statp->defdname) - 1) == 0 &&
- (cp = strchr(buf, '.')) != NULL)
- strcpy(statp->defdname, cp + 1);
-
- /* find components of local domain that might be searched */
- if (havesearch == 0) {
- pp = statp->dnsrch;
- *pp++ = statp->defdname;
- *pp = NULL;
-
- dots = 0;
- for (cp = statp->defdname; *cp; cp++) dots += (*cp == '.');
-
- cp = statp->defdname;
- while (pp < statp->dnsrch + MAXDFLSRCH) {
- if (dots < LOCALDOMAINPARTS) break;
- cp = strchr(cp, '.') + 1; /* we know there is one */
- *pp++ = cp;
- dots--;
- }
- *pp = NULL;
- LOG(DEBUG) << __func__ << ": dnsrch list:";
- for (pp = statp->dnsrch; *pp; pp++) LOG(DEBUG) << "\t" << *pp;
- }
-
if (nserv > 0) {
statp->nscount = nserv;
statp->options |= RES_INIT;
diff --git a/resolv/res_query.cpp b/resolv/res_query.cpp
index 13696c48..37767725 100644
--- a/resolv/res_query.cpp
+++ b/resolv/res_query.cpp
@@ -207,7 +207,7 @@ int res_nsearch(res_state statp, const char* name, /* domain name */
int* herrno) /* legacy and extended
h_errno NETD_RESOLV_H_ERRNO_EXT_* */
{
- const char *cp, *const *domain;
+ const char* cp;
HEADER* hp = (HEADER*) (void*) answer;
u_int dots;
int ret, saved_herrno;
@@ -251,12 +251,10 @@ int res_nsearch(res_state statp, const char* name, /* domain name */
*/
_resolv_populate_res_for_net(statp);
- for (domain = (const char* const*) statp->dnsrch; *domain && !done; domain++) {
+ for (const auto& domain : statp->search_domains) {
+ if (domain == "." || domain == "") ++root_on_list;
- if (domain[0][0] == '\0' || (domain[0][0] == '.' && domain[0][1] == '\0'))
- root_on_list++;
-
- ret = res_nquerydomain(statp, name, *domain, cl, type, answer, anslen, herrno);
+ ret = res_nquerydomain(statp, name, domain.c_str(), cl, type, answer, anslen, herrno);
if (ret > 0) return ret;
/*
@@ -295,7 +293,6 @@ int res_nsearch(res_state statp, const char* name, /* domain name */
/* anything else implies that we're done */
done++;
}
-
}
}
diff --git a/resolv/resolv_cache.h b/resolv/resolv_cache.h
index 246aa910..1a27e4b0 100644
--- a/resolv/resolv_cache.h
+++ b/resolv/resolv_cache.h
@@ -64,8 +64,12 @@ int resolv_cache_add(unsigned netid, const void* query, int querylen, const void
void _resolv_cache_query_failed(unsigned netid, const void* query, int querylen, uint32_t flags);
// Sets name servers for a given network.
-int resolv_set_nameservers_for_net(unsigned netid, const char** servers, int numservers,
- const char* domains, const res_params* params);
+int resolv_set_nameservers(unsigned netid, const char** servers, int numservers,
+ const std::vector<std::string>& domains, const res_params* params);
+
+// TODO: remove it after updating all callers.
+int resolv_set_nameservers(unsigned netid, const char** servers, int numservers,
+ const char* domains, const res_params* params);
// Creates the cache associated with the given network.
int resolv_create_cache_for_net(unsigned netid);
diff --git a/resolv/resolv_private.h b/resolv/resolv_private.h
index 45ae337b..1e10849b 100644
--- a/resolv/resolv_private.h
+++ b/resolv/resolv_private.h
@@ -59,6 +59,7 @@
#include <resolv.h>
#include <time.h>
#include <string>
+#include <vector>
#include "netd_resolv/params.h"
#include "netd_resolv/resolv.h"
@@ -76,9 +77,6 @@
/*
* Global defines and variables for resolver stub.
*/
-#define MAXDFLSRCH 3 /* # default domain levels to try */
-#define LOCALDOMAINPARTS 2 /* min levels in name that is "local" */
-
#define RES_TIMEOUT 5000 /* min. milliseconds between retries */
#define MAXRESOLVSORT 10 /* number of net to sort on */
#define RES_MAXNDOTS 15 /* should reflect bit field size */
@@ -88,17 +86,16 @@
struct res_state_ext;
struct __res_state {
- unsigned netid; /* NetId: cache key and socket mark */
- uid_t uid; /* uid of the app that sent the DNS lookup */
- u_long options; /* option flags - see below. */
- int nscount; /* number of name srvers */
- struct sockaddr_in nsaddr_list[MAXNS]; /* address of name server */
-#define nsaddr nsaddr_list[0] /* for backward compatibility */
- u_short id; /* current message id */
- char* dnsrch[MAXDNSRCH + 1]; /* components of domain to search */
- char defdname[256]; /* default domain (deprecated) */
- unsigned ndots : 4; /* threshold for initial abs. query */
- unsigned nsort : 4; /* number of elements in sort_list[] */
+ unsigned netid; // NetId: cache key and socket mark
+ uid_t uid; // uid of the app that sent the DNS lookup
+ u_long options; // option flags - see below.
+ int nscount; // number of name srvers
+ struct sockaddr_in nsaddr_list[MAXNS]; // address of name server
+#define nsaddr nsaddr_list[0] // for backward compatibility
+ u_short id; // current message id
+ std::vector<std::string> search_domains; // domains to search
+ unsigned ndots : 4; // threshold for initial abs. query
+ unsigned nsort : 4; // number of elements in sort_list[]
char unused[3];
struct {
struct in_addr addr;
diff --git a/resolv/resolver_test.cpp b/resolv/resolver_test.cpp
index 6ca5e1a5..f2765fee 100644
--- a/resolv/resolver_test.cpp
+++ b/resolv/resolver_test.cpp
@@ -968,6 +968,92 @@ TEST_F(ResolverTest, SearchPathChange) {
EXPECT_EQ("2001:db8::1:13", ToString(result));
}
+namespace {
+
+std::vector<std::string> getResolverDomains(android::net::IDnsResolver* dnsResolverService,
+ unsigned netId) {
+ std::vector<std::string> res_servers;
+ std::vector<std::string> res_domains;
+ std::vector<std::string> res_tls_servers;
+ res_params res_params;
+ std::vector<ResolverStats> res_stats;
+ int wait_for_pending_req_timeout_count;
+ GetResolverInfo(dnsResolverService, netId, &res_servers, &res_domains, &res_tls_servers,
+ &res_params, &res_stats, &wait_for_pending_req_timeout_count);
+ return res_domains;
+}
+
+} // namespace
+
+TEST_F(ResolverTest, SearchPathPrune) {
+ constexpr size_t DUPLICATED_DOMAIN_NUM = 3;
+ constexpr char listen_addr[] = "127.0.0.13";
+ constexpr char domian_name1[] = "domain13.org.";
+ constexpr char domian_name2[] = "domain14.org.";
+ constexpr char host_name1[] = "test13.domain13.org.";
+ constexpr char host_name2[] = "test14.domain14.org.";
+ std::vector<std::string> servers = {listen_addr};
+
+ std::vector<std::string> testDomains1;
+ std::vector<std::string> testDomains2;
+ // Domain length should be <= 255
+ // Max number of domains in search path is 6
+ for (size_t i = 0; i < MAXDNSRCH + 1; i++) {
+ // Fill up with invalid domain
+ testDomains1.push_back(std::string(300, i + '0'));
+ // Fill up with valid but duplicated domain
+ testDomains2.push_back(StringPrintf("domain%zu.org", i % DUPLICATED_DOMAIN_NUM));
+ }
+
+ // Add valid domain used for query.
+ testDomains1.push_back(domian_name1);
+
+ // Add valid domain twice used for query.
+ testDomains2.push_back(domian_name2);
+ testDomains2.push_back(domian_name2);
+
+ const std::vector<DnsRecord> records = {
+ {host_name1, ns_type::ns_t_aaaa, "2001:db8::13"},
+ {host_name2, ns_type::ns_t_aaaa, "2001:db8::1:13"},
+ };
+ test::DNSResponder dns(listen_addr);
+ StartDns(dns, records);
+ ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers, testDomains1));
+
+ const addrinfo hints = {.ai_family = AF_INET6};
+ ScopedAddrinfo result = safe_getaddrinfo("test13", nullptr, &hints);
+
+ EXPECT_TRUE(result != nullptr);
+
+ EXPECT_EQ(1U, dns.queries().size());
+ EXPECT_EQ(1U, GetNumQueries(dns, host_name1));
+ EXPECT_EQ("2001:db8::13", ToString(result));
+
+ const auto& res_domains1 = getResolverDomains(mDnsClient.resolvService(), TEST_NETID);
+ // Expect 1 valid domain, invalid domains are removed.
+ ASSERT_EQ(1U, res_domains1.size());
+ EXPECT_EQ(domian_name1, res_domains1[0]);
+
+ dns.clearQueries();
+
+ ASSERT_TRUE(mDnsClient.SetResolversForNetwork(servers, testDomains2));
+
+ result = safe_getaddrinfo("test14", nullptr, &hints);
+ EXPECT_TRUE(result != nullptr);
+
+ // (3 domains * 2 retries) + 1 success query = 7
+ EXPECT_EQ(7U, dns.queries().size());
+ EXPECT_EQ(1U, GetNumQueries(dns, host_name2));
+ EXPECT_EQ("2001:db8::1:13", ToString(result));
+
+ const auto& res_domains2 = getResolverDomains(mDnsClient.resolvService(), TEST_NETID);
+ // Expect 4 valid domain, duplicate domains are removed.
+ EXPECT_EQ(DUPLICATED_DOMAIN_NUM + 1U, res_domains2.size());
+ EXPECT_THAT(
+ std::vector<std::string>({"domain0.org", "domain1.org", "domain2.org", domian_name2}),
+ testing::ElementsAreArray(res_domains2));
+}
+
static std::string base64Encode(const std::vector<uint8_t>& input) {
size_t out_len;
EXPECT_EQ(1, EVP_EncodedLength(&out_len, input.size()));