diff options
| -rw-r--r-- | server/dns/DnsTlsSocket.cpp | 51 | ||||
| -rw-r--r-- | server/dns/DnsTlsSocket.h | 17 | ||||
| -rw-r--r-- | tests/dns_tls_test.cpp | 41 |
3 files changed, 81 insertions, 28 deletions
diff --git a/server/dns/DnsTlsSocket.cpp b/server/dns/DnsTlsSocket.cpp index 8b2d2001..237d6bf2 100644 --- a/server/dns/DnsTlsSocket.cpp +++ b/server/dns/DnsTlsSocket.cpp @@ -350,6 +350,8 @@ void DnsTlsSocket::loop() { // If we have pending queries, wait for space to write one. // Otherwise, listen for new queries. + // Note: This blocks the destructor until q is empty, i.e. until all pending + // queries are sent or have failed to send. if (!q.empty()) { fds[SSLFD].events |= POLLOUT; } else { @@ -366,7 +368,7 @@ void DnsTlsSocket::loop() { ALOGV("Poll failed: %d", errno); break; } - if (fds[SSLFD].revents & (POLLIN | POLLERR)) { + if (fds[SSLFD].revents & (POLLIN | POLLERR | POLLHUP)) { if (!readResponse()) { ALOGV("SSL remote close or read error."); break; @@ -379,23 +381,17 @@ void DnsTlsSocket::loop() { ALOGW("Error during eventfd read"); break; } else if (res == 0) { - ALOGV("eventfd closed; disconnecting"); + ALOGW("eventfd closed; disconnecting"); break; } else if (res != sizeof(num_queries)) { ALOGE("Int size mismatch: %zd != %zu", res, sizeof(num_queries)); break; - } else if (num_queries <= 0) { - ALOGE("eventfd reads should always be positive"); + } else if (num_queries < 0) { + ALOGV("Negative eventfd read indicates destructor-initiated shutdown"); break; } // Take ownership of all pending queries. (q is always empty here.) mQueue.swap(q); - // The writing thread writes to mQueue and then increments mEventFd, so - // there should be at least num_queries entries in mQueue. - if (q.size() < (uint64_t) num_queries) { - ALOGE("Synchronization error"); - break; - } } else if (fds[SSLFD].revents & POLLOUT) { // q cannot be empty here. 
// Sending the entire queue here would risk a TCP flow control deadlock, so @@ -408,8 +404,6 @@ void DnsTlsSocket::loop() { q.pop_front(); } } - ALOGV("Closing event FD"); - mEventFd.reset(); ALOGV("Disconnecting"); sslDisconnect(); ALOGV("Calling onClosed"); @@ -420,12 +414,7 @@ void DnsTlsSocket::loop() { DnsTlsSocket::~DnsTlsSocket() { ALOGV("Destructor"); // This will trigger an orderly shutdown in loop(). - // In principle there is a data race here: If there is an I/O error in the network thread - // simultaneous with a call to the destructor in a different thread, both threads could - // attempt to call mEventFd.reset() at the same time. However, the implementation of - // UniqueFd::reset appears to be thread-safe, and neither thread reads or writes mEventFd - // after this point, so we don't expect an issue in practice. - mEventFd.reset(); + requestLoopShutdown(); { // Wait for the orderly shutdown to complete. std::lock_guard<std::mutex> guard(mLock); @@ -443,10 +432,6 @@ DnsTlsSocket::~DnsTlsSocket() { } bool DnsTlsSocket::query(uint16_t id, const Slice query) { - if (!mEventFd) { - return false; - } - // Compose the entire message in a single buffer, so that it can be // sent as a single TLS record. std::vector<uint8_t> buf(query.size() + 4); @@ -462,9 +447,25 @@ bool DnsTlsSocket::query(uint16_t id, const Slice query) { mQueue.push(std::move(buf)); // Increment the mEventFd counter by 1. - constexpr int64_t num_queries = 1; - int written = write(mEventFd.get(), &num_queries, sizeof(num_queries)); - return written == sizeof(num_queries); + return incrementEventFd(1); +} + +void DnsTlsSocket::requestLoopShutdown() { + // Write a negative number to the eventfd. This triggers an immediate shutdown. 
+ incrementEventFd(INT64_MIN); +} + +bool DnsTlsSocket::incrementEventFd(const int64_t count) { + if (!mEventFd) { + ALOGV("eventfd is not initialized"); + return false; + } + int written = write(mEventFd.get(), &count, sizeof(count)); + if (written != sizeof(count)) { + ALOGE("Failed to increment eventfd by %" PRId64, count); + return false; + } + return true; } // Read exactly len bytes into buffer or fail with an SSL error code diff --git a/server/dns/DnsTlsSocket.h b/server/dns/DnsTlsSocket.h index 2593bcf2..57e1acc7 100644 --- a/server/dns/DnsTlsSocket.h +++ b/server/dns/DnsTlsSocket.h @@ -65,7 +65,7 @@ public: // notified that the socket is closed. // Note that success here indicates successful sending, not receipt of a response. // Thread-safe. - bool query(uint16_t id, const Slice query) override; + bool query(uint16_t id, const Slice query) override EXCLUDES(mLock); private: // Lock to be held by the SSL event loop thread. This is not normally in contention. @@ -99,6 +99,15 @@ private: bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock); bool readResponse() REQUIRES(mLock); + // Similar to query(), this function uses incrementEventFd to send a message to the + // loop thread. However, instead of incrementing the counter by one (indicating a + // new query), it wraps the counter to negative, which we use to indicate a shutdown + // request. + void requestLoopShutdown() EXCLUDES(mLock); + + // This function sends a message to the loop thread by incrementing mEventFd. + bool incrementEventFd(int64_t count) EXCLUDES(mLock); + // Queue of pending queries. query() pushes items onto the queue and notifies // the loop thread by incrementing mEventFd. loop() reads items off the queue. LockedQueue<std::vector<uint8_t>> mQueue; @@ -106,8 +115,10 @@ private: // eventfd socket used for notifying the SSL thread when queries are ready to send. // This socket acts similarly to an atomic counter, incremented by query() and cleared // by loop(). 
We have to use a socket because the SSL thread needs to wait in poll() - // for input from either a remote server or a query thread. - // EOF indicates a close request. + // for input from either a remote server or a query thread. Since eventfd does not have + // EOF, we indicate a close request by setting the counter to a negative number. + // This file descriptor is opened by initialize(), and closed implicitly after + // destruction. base::unique_fd mEventFd; // SSL Socket fields. diff --git a/tests/dns_tls_test.cpp b/tests/dns_tls_test.cpp index bb5bfe56..b7fb3a45 100644 --- a/tests/dns_tls_test.cpp +++ b/tests/dns_tls_test.cpp @@ -28,6 +28,8 @@ #include "dns/IDnsTlsSocketFactory.h" #include "dns/IDnsTlsSocketObserver.h" +#include "dns_responder/dns_tls_frontend.h" + #include <chrono> #include <arpa/inet.h> #include <android-base/macros.h> @@ -871,5 +873,44 @@ TEST(QueryMapTest, FillHole) { EXPECT_FALSE(map.recordQuery(makeSlice(QUERY))); } +class StubObserver : public IDnsTlsSocketObserver { + public: + bool closed = false; + void onResponse(std::vector<uint8_t>) override {} + + void onClosed() override { closed = true; } +}; + +TEST(DnsTlsSocketTest, SlowDestructor) { + constexpr char tls_addr[] = "127.0.0.3"; + constexpr char tls_port[] = "8530"; // High-numbered port so root isn't required. + // This test doesn't perform any queries, so the backend address can be invalid. + constexpr char backend_addr[] = "192.0.2.1"; + constexpr char backend_port[] = "1"; + + test::DnsTlsFrontend tls(tls_addr, tls_port, backend_addr, backend_port); + ASSERT_TRUE(tls.startServer()); + + DnsTlsServer server; + parseServer(tls_addr, 8530, &server.ss); + + StubObserver observer; + ASSERT_FALSE(observer.closed); + DnsTlsSessionCache cache; + auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache); + ASSERT_TRUE(socket->initialize()); + + // Test: Time the socket destructor. This should be fast. 
+ auto before = std::chrono::steady_clock::now(); + socket.reset(); + auto after = std::chrono::steady_clock::now(); + auto delay = after - before; + ALOGV("Shutdown took %lld ns", delay / std::chrono::nanoseconds{1}); + EXPECT_TRUE(observer.closed); + // Shutdown should complete in milliseconds, but if the shutdown signal is lost + // it will wait for the timeout, which is expected to take 20 seconds. + EXPECT_LT(delay, std::chrono::seconds{5}); +} + } // end of namespace net } // end of namespace android |
