summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--server/dns/DnsTlsSocket.cpp51
-rw-r--r--server/dns/DnsTlsSocket.h17
-rw-r--r--tests/dns_tls_test.cpp41
3 files changed, 81 insertions, 28 deletions
diff --git a/server/dns/DnsTlsSocket.cpp b/server/dns/DnsTlsSocket.cpp
index 8b2d2001..237d6bf2 100644
--- a/server/dns/DnsTlsSocket.cpp
+++ b/server/dns/DnsTlsSocket.cpp
@@ -350,6 +350,8 @@ void DnsTlsSocket::loop() {
// If we have pending queries, wait for space to write one.
// Otherwise, listen for new queries.
+ // Note: This blocks the destructor until q is empty, i.e. until all pending
+ // queries are sent or have failed to send.
if (!q.empty()) {
fds[SSLFD].events |= POLLOUT;
} else {
@@ -366,7 +368,7 @@ void DnsTlsSocket::loop() {
ALOGV("Poll failed: %d", errno);
break;
}
- if (fds[SSLFD].revents & (POLLIN | POLLERR)) {
+ if (fds[SSLFD].revents & (POLLIN | POLLERR | POLLHUP)) {
if (!readResponse()) {
ALOGV("SSL remote close or read error.");
break;
@@ -379,23 +381,17 @@ void DnsTlsSocket::loop() {
ALOGW("Error during eventfd read");
break;
} else if (res == 0) {
- ALOGV("eventfd closed; disconnecting");
+ ALOGW("eventfd closed; disconnecting");
break;
} else if (res != sizeof(num_queries)) {
ALOGE("Int size mismatch: %zd != %zu", res, sizeof(num_queries));
break;
- } else if (num_queries <= 0) {
- ALOGE("eventfd reads should always be positive");
+ } else if (num_queries < 0) {
+ ALOGV("Negative eventfd read indicates destructor-initiated shutdown");
break;
}
// Take ownership of all pending queries. (q is always empty here.)
mQueue.swap(q);
- // The writing thread writes to mQueue and then increments mEventFd, so
- // there should be at least num_queries entries in mQueue.
- if (q.size() < (uint64_t) num_queries) {
- ALOGE("Synchronization error");
- break;
- }
} else if (fds[SSLFD].revents & POLLOUT) {
// q cannot be empty here.
// Sending the entire queue here would risk a TCP flow control deadlock, so
@@ -408,8 +404,6 @@ void DnsTlsSocket::loop() {
q.pop_front();
}
}
- ALOGV("Closing event FD");
- mEventFd.reset();
ALOGV("Disconnecting");
sslDisconnect();
ALOGV("Calling onClosed");
@@ -420,12 +414,7 @@ void DnsTlsSocket::loop() {
DnsTlsSocket::~DnsTlsSocket() {
ALOGV("Destructor");
// This will trigger an orderly shutdown in loop().
- // In principle there is a data race here: If there is an I/O error in the network thread
- // simultaneous with a call to the destructor in a different thread, both threads could
- // attempt to call mEventFd.reset() at the same time. However, the implementation of
- // UniqueFd::reset appears to be thread-safe, and neither thread reads or writes mEventFd
- // after this point, so we don't expect an issue in practice.
- mEventFd.reset();
+ requestLoopShutdown();
{
// Wait for the orderly shutdown to complete.
std::lock_guard<std::mutex> guard(mLock);
@@ -443,10 +432,6 @@ DnsTlsSocket::~DnsTlsSocket() {
}
bool DnsTlsSocket::query(uint16_t id, const Slice query) {
- if (!mEventFd) {
- return false;
- }
-
// Compose the entire message in a single buffer, so that it can be
// sent as a single TLS record.
std::vector<uint8_t> buf(query.size() + 4);
@@ -462,9 +447,25 @@ bool DnsTlsSocket::query(uint16_t id, const Slice query) {
mQueue.push(std::move(buf));
// Increment the mEventFd counter by 1.
- constexpr int64_t num_queries = 1;
- int written = write(mEventFd.get(), &num_queries, sizeof(num_queries));
- return written == sizeof(num_queries);
+ return incrementEventFd(1);
+}
+
+void DnsTlsSocket::requestLoopShutdown() {
+ // Write a negative number to the eventfd. This triggers an immediate shutdown.
+ incrementEventFd(INT64_MIN);
+}
+
+bool DnsTlsSocket::incrementEventFd(const int64_t count) {
+ if (!mEventFd) {
+ ALOGV("eventfd is not initialized");
+ return false;
+ }
+ int written = write(mEventFd.get(), &count, sizeof(count));
+ if (written != sizeof(count)) {
+ ALOGE("Failed to increment eventfd by %" PRId64, count);
+ return false;
+ }
+ return true;
}
// Read exactly len bytes into buffer or fail with an SSL error code
diff --git a/server/dns/DnsTlsSocket.h b/server/dns/DnsTlsSocket.h
index 2593bcf2..57e1acc7 100644
--- a/server/dns/DnsTlsSocket.h
+++ b/server/dns/DnsTlsSocket.h
@@ -65,7 +65,7 @@ public:
// notified that the socket is closed.
// Note that success here indicates successful sending, not receipt of a response.
// Thread-safe.
- bool query(uint16_t id, const Slice query) override;
+ bool query(uint16_t id, const Slice query) override EXCLUDES(mLock);
private:
// Lock to be held by the SSL event loop thread. This is not normally in contention.
@@ -99,6 +99,15 @@ private:
bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock);
bool readResponse() REQUIRES(mLock);
+ // Similar to query(), this function uses incrementEventFd to send a message to the
+ // loop thread. However, instead of incrementing the counter by one (indicating a
+ // new query), it wraps the counter to negative, which we use to indicate a shutdown
+ // request.
+ void requestLoopShutdown() EXCLUDES(mLock);
+
+ // This function sends a message to the loop thread by incrementing mEventFd.
+ bool incrementEventFd(int64_t count) EXCLUDES(mLock);
+
// Queue of pending queries. query() pushes items onto the queue and notifies
// the loop thread by incrementing mEventFd. loop() reads items off the queue.
LockedQueue<std::vector<uint8_t>> mQueue;
@@ -106,8 +115,10 @@ private:
// eventfd socket used for notifying the SSL thread when queries are ready to send.
// This socket acts similarly to an atomic counter, incremented by query() and cleared
// by loop(). We have to use a socket because the SSL thread needs to wait in poll()
- // for input from either a remote server or a query thread.
- // EOF indicates a close request.
+ // for input from either a remote server or a query thread. Since eventfd does not have
+ // EOF, we indicate a close request by setting the counter to a negative number.
+ // This file descriptor is opened by initialize(), and closed implicitly after
+ // destruction.
base::unique_fd mEventFd;
// SSL Socket fields.
diff --git a/tests/dns_tls_test.cpp b/tests/dns_tls_test.cpp
index bb5bfe56..b7fb3a45 100644
--- a/tests/dns_tls_test.cpp
+++ b/tests/dns_tls_test.cpp
@@ -28,6 +28,8 @@
#include "dns/IDnsTlsSocketFactory.h"
#include "dns/IDnsTlsSocketObserver.h"
+#include "dns_responder/dns_tls_frontend.h"
+
#include <chrono>
#include <arpa/inet.h>
#include <android-base/macros.h>
@@ -871,5 +873,44 @@ TEST(QueryMapTest, FillHole) {
EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
}
+class StubObserver : public IDnsTlsSocketObserver {
+ public:
+ bool closed = false;
+ void onResponse(std::vector<uint8_t>) override {}
+
+ void onClosed() override { closed = true; }
+};
+
+TEST(DnsTlsSocketTest, SlowDestructor) {
+ constexpr char tls_addr[] = "127.0.0.3";
+ constexpr char tls_port[] = "8530"; // High-numbered port so root isn't required.
+ // This test doesn't perform any queries, so the backend address can be invalid.
+ constexpr char backend_addr[] = "192.0.2.1";
+ constexpr char backend_port[] = "1";
+
+ test::DnsTlsFrontend tls(tls_addr, tls_port, backend_addr, backend_port);
+ ASSERT_TRUE(tls.startServer());
+
+ DnsTlsServer server;
+ parseServer(tls_addr, 8530, &server.ss);
+
+ StubObserver observer;
+ ASSERT_FALSE(observer.closed);
+ DnsTlsSessionCache cache;
+ auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache);
+ ASSERT_TRUE(socket->initialize());
+
+ // Test: Time the socket destructor. This should be fast.
+ auto before = std::chrono::steady_clock::now();
+ socket.reset();
+ auto after = std::chrono::steady_clock::now();
+ auto delay = after - before;
+ ALOGV("Shutdown took %lld ns", delay / std::chrono::nanoseconds{1});
+ EXPECT_TRUE(observer.closed);
+ // Shutdown should complete in milliseconds, but if the shutdown signal is lost
+ // it will wait for the timeout, which is expected to take 20seconds.
+ EXPECT_LT(delay, std::chrono::seconds{5});
+}
+
} // end of namespace net
} // end of namespace android