/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "dns_tls_frontend.h"

#include <arpa/inet.h>
#include <netdb.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/ssl.h>
#include <sys/eventfd.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <iterator>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

#define LOG_TAG "DnsTlsFrontend"
#include <log/log.h>
#include <netdutils/SocketOption.h>

#include "NetdConstants.h"  // SHA256_SIZE

using android::netdutils::enableSockopt;

namespace {

// Computes the SHA-256 digest of |cert|'s DER-encoded SubjectPublicKeyInfo and
// stores it in |out| (resized to SHA256_SIZE). Returns false and logs on failure.
// Adapted from DnsTlsTransport. It uses a heap buffer instead of a variable-length
// array, and rejects a failed (non-positive) encoding length before allocating.
bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
    const int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), nullptr);
    if (spki_len <= 0) {
        ALOGE("SPKI encoding failed");
        return false;
    }
    std::vector<unsigned char> spki(spki_len);
    unsigned char* temp = spki.data();
    if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
        ALOGE("SPKI length mismatch");
        return false;
    }
    out->resize(SHA256_SIZE);
    unsigned int digest_len = 0;
    const int ret = EVP_Digest(spki.data(), spki.size(), out->data(), &digest_len,
                               EVP_sha256(), nullptr);
    if (ret != 1) {
        ALOGE("Server cert digest extraction failed");
        return false;
    }
    if (digest_len != out->size()) {
        ALOGE("Wrong digest length: %u", digest_len);
        return false;
    }
    return true;
}

// Returns a human-readable description of the current errno value.
// Relies on the GNU strerror_r variant, which returns a char*. That pointer may
// point at a static string rather than at |error_msg|.
std::string errno2str() {
    char error_msg[512] = { 0 };
    return strerror_r(errno, error_msg, sizeof(error_msg));
}

// Info-level log that appends the current errno and its description.
#define APLOGI(fmt, ...) \
    ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())

// Returns the numeric host string (e.g. "127.0.0.1") for |sa|, or an empty
// string if getnameinfo() fails.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char host_str[NI_MAXHOST] = { 0 };
    int rv = getnameinfo(sa, sa_len, host_str, sizeof(host_str), nullptr, 0,
                         NI_NUMERICHOST);
    if (rv == 0) return std::string(host_str);
    return std::string();
}

// Generates a new 2048-bit RSA key with public exponent RSA_F4, wrapped in an
// EVP_PKEY. Returns nullptr and logs on failure.
bssl::UniquePtr<EVP_PKEY> make_private_key() {
    bssl::UniquePtr<BIGNUM> e(BN_new());
    if (!e) {
        ALOGE("BN_new failed");
        return nullptr;
    }
    if (!BN_set_word(e.get(), RSA_F4)) {
        ALOGE("BN_set_word failed");
        return nullptr;
    }

    bssl::UniquePtr<RSA> rsa(RSA_new());
    if (!rsa) {
        ALOGE("RSA_new failed");
        return nullptr;
    }
    if (!RSA_generate_key_ex(rsa.get(), 2048, e.get(), nullptr)) {
        ALOGE("RSA_generate_key_ex failed");
        return nullptr;
    }

    bssl::UniquePtr<EVP_PKEY> privkey(EVP_PKEY_new());
    if (!privkey) {
        ALOGE("EVP_PKEY_new failed");
        return nullptr;
    }
    if (!EVP_PKEY_assign_RSA(privkey.get(), rsa.get())) {
        ALOGE("EVP_PKEY_assign_RSA failed");
        return nullptr;
    }

    // |rsa| is now owned by |privkey|, so no need to free it.
    rsa.release();
    return privkey;
}

// Creates an X.509 certificate with serial number 1 that carries |privkey|'s
// public key. It is valid from now for one hour and is signed with |parent_key|
// using SHA-256. Subject and issuer names are not set. Returns nullptr and logs
// on failure.
bssl::UniquePtr<X509> make_cert(EVP_PKEY* privkey, EVP_PKEY* parent_key) {
    bssl::UniquePtr<X509> cert(X509_new());
    if (!cert) {
        ALOGE("X509_new failed");
        return nullptr;
    }

    ASN1_INTEGER_set(X509_get_serialNumber(cert.get()), 1);

    // Set one hour expiration.
    X509_gmtime_adj(X509_get_notBefore(cert.get()), 0);
    X509_gmtime_adj(X509_get_notAfter(cert.get()), 60 * 60);

    if (!X509_set_pubkey(cert.get(), privkey)) {
        ALOGE("X509_set_pubkey failed");
        return nullptr;
    }

    if (!X509_sign(cert.get(), parent_key, EVP_sha256())) {
        ALOGE("X509_sign failed");
        return nullptr;
    }

    return cert;
}

// Reads exactly |len| bytes from |ssl| into |buf|, looping over short reads,
// because SSL_read may return fewer bytes than requested. Returns false if
// SSL_read reports EOF or an error first. |len| must fit in an int; callers pass
// at most 65535.
bool sslReadFully(SSL* ssl, uint8_t* buf, size_t len) {
    size_t done = 0;
    while (done < len) {
        const int ret = SSL_read(ssl, buf + done, static_cast<int>(len - done));
        if (ret <= 0) return false;
        done += static_cast<size_t>(ret);
    }
    return true;
}

}  // namespace

namespace test {

// Builds a TLS context with a freshly generated certificate chain of
// |chain_length_| certs, then binds and listens on the first usable
// listen_address_:listen_service_ address. It also opens a UDP socket to
// backend_address_:backend_service_, creates the stop eventfd and starts
// requestHandler() on handler_thread_. Returns false and logs on any failure.
bool DnsTlsFrontend::startServer() {
    SSL_load_error_strings();
    OpenSSL_add_ssl_algorithms();

    // Reset queries_ to 0 on every start so that waitForQueries() counts only
    // queries handled since this call.
    queries_ = 0;

    ctx_.reset(SSL_CTX_new(TLS_server_method()));
    if (!ctx_) {
        ALOGE("SSL context creation failed");
        return false;
    }

    SSL_CTX_set_ecdh_auto(ctx_.get(), 1);

    // Make certificate chain: certs[i] carries keys[i] and is signed by
    // keys[i + 1]. The last cert is signed by its own key.
    if (chain_length_ < 1) {
        ALOGE("Invalid certificate chain length %d", chain_length_);
        return false;
    }
    std::vector<bssl::UniquePtr<EVP_PKEY>> keys(chain_length_);
    for (int i = 0; i < chain_length_; ++i) {
        keys[i] = make_private_key();
        if (!keys[i]) {
            ALOGE("Failed to create private key %d", i);
            return false;
        }
    }
    std::vector<bssl::UniquePtr<X509>> certs(chain_length_);
    for (int i = 0; i < chain_length_; ++i) {
        const int next = std::min(i + 1, chain_length_ - 1);
        certs[i] = make_cert(keys[i].get(), keys[next].get());
        if (!certs[i]) {
            ALOGE("Failed to create certificate %d", i);
            return false;
        }
    }

    // Install certificate chain.
    if (SSL_CTX_use_certificate(ctx_.get(), certs[0].get()) <= 0) {
        ALOGE("SSL_CTX_use_certificate failed");
        return false;
    }
    if (SSL_CTX_use_PrivateKey(ctx_.get(), keys[0].get()) <= 0) {
        ALOGE("SSL_CTX_use_PrivateKey failed");
        return false;
    }
    for (int i = 1; i < chain_length_; ++i) {
        if (SSL_CTX_add1_chain_cert(ctx_.get(), certs[i].get()) != 1) {
            ALOGE("SSL_CTX_add1_chain_cert failed");
            return false;
        }
    }

    // Report the fingerprint of the "middle" cert. For N = 2, this is the root.
    const int fp_index = chain_length_ / 2;
    if (!getSPKIDigest(certs[fp_index].get(), &fingerprint_)) {
        ALOGE("getSPKIDigest failed");
        return false;
    }

    // Set up TCP server socket for clients.
    // Designators follow struct addrinfo's declaration order (ai_flags first),
    // as C++20 requires.
    addrinfo frontend_ai_hints{
        .ai_flags = AI_PASSIVE,
        .ai_family = AF_UNSPEC,
        .ai_socktype = SOCK_STREAM,
    };
    addrinfo* frontend_ai_res = nullptr;
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
                         &frontend_ai_hints, &frontend_ai_res);
    ScopedAddrinfo frontend_ai_res_cleanup(frontend_ai_res);
    if (rv) {
        ALOGE("frontend getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
              listen_service_.c_str(), gai_strerror(rv));
        return false;
    }

    // Take the first address that a socket can be created for and bound to.
    for (const addrinfo* ai = frontend_ai_res; ai; ai = ai->ai_next) {
        android::base::unique_fd s(socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol));
        if (s.get() < 0) {
            APLOGI("ignore creating socket failed %d", s.get());
            continue;
        }
        enableSockopt(s.get(), SOL_SOCKET, SO_REUSEPORT).ignoreError();
        enableSockopt(s.get(), SOL_SOCKET, SO_REUSEADDR).ignoreError();
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        if (bind(s.get(), ai->ai_addr, ai->ai_addrlen)) {
            APLOGI("failed to bind TCP %s:%s", host_str.c_str(), listen_service_.c_str());
            continue;
        }
        ALOGI("bound to TCP %s:%s", host_str.c_str(), listen_service_.c_str());
        socket_ = std::move(s);
        break;
    }

    if (socket_.get() < 0) {
        ALOGE("no usable TCP address for %s:%s", listen_address_.c_str(),
              listen_service_.c_str());
        return false;
    }

    if (listen(socket_.get(), 1) < 0) {
        APLOGI("failed to listen socket %d", socket_.get());
        return false;
    }

    // Set up UDP client socket to backend.
    addrinfo backend_ai_hints{
        .ai_family = AF_UNSPEC,
        .ai_socktype = SOCK_DGRAM,
    };
    addrinfo* backend_ai_res = nullptr;
    rv = getaddrinfo(backend_address_.c_str(), backend_service_.c_str(),
                     &backend_ai_hints, &backend_ai_res);
    ScopedAddrinfo backend_ai_res_cleanup(backend_ai_res);
    if (rv) {
        // Report the backend endpoint that actually failed to resolve.
        ALOGE("backend getaddrinfo(%s, %s) failed: %s", backend_address_.c_str(),
              backend_service_.c_str(), gai_strerror(rv));
        return false;
    }
    backend_socket_.reset(socket(backend_ai_res->ai_family, backend_ai_res->ai_socktype,
                                 backend_ai_res->ai_protocol));
    if (backend_socket_.get() < 0) {
        APLOGI("backend socket %d creation failed", backend_socket_.get());
        return false;
    }

    // connect() always fails in the test DnsTlsSocketTest.SlowDestructor because of
    // no backend server. Don't check it.
    connect(backend_socket_.get(), backend_ai_res->ai_addr, backend_ai_res->ai_addrlen);

    // Set up eventfd socket. stopServer() writes to it to end requestHandler().
    event_fd_.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
    if (event_fd_.get() == -1) {
        APLOGI("failed to create eventfd %d", event_fd_.get());
        return false;
    }

    {
        std::lock_guard lock(update_mutex_);
        handler_thread_ = std::thread(&DnsTlsFrontend::requestHandler, this);
    }
    ALOGI("server started successfully");
    return true;
}

// Body of handler_thread_. It polls event_fd_ and the listening socket. Each
// accepted client gets a TLS handshake and exactly one forwarded query (see
// handleOneRequest()). The connection is then closed. The loop ends when the
// eventfd fires, poll fails with anything other than EINTR, or accept fails.
void DnsTlsFrontend::requestHandler() {
    ALOGD("Request handler started");
    enum { EVENT_FD = 0, LISTEN_FD = 1 };
    pollfd fds[2] = {{.fd = event_fd_.get(), .events = POLLIN},
                     {.fd = socket_.get(), .events = POLLIN}};

    while (true) {
        int poll_code = poll(fds, std::size(fds), -1);
        if (poll_code < 0 && errno == EINTR) {
            // Interrupted by a signal; nothing happened, so just poll again.
            continue;
        }
        if (poll_code <= 0) {
            APLOGI("Poll failed with error %d", poll_code);
            break;
        }

        if (fds[EVENT_FD].revents & (POLLIN | POLLERR)) {
            handleEventFd();
            break;
        }
        if (fds[LISTEN_FD].revents & (POLLIN | POLLERR)) {
            sockaddr_storage addr;
            socklen_t len = sizeof(addr);

            ALOGD("Trying to accept a client");
            android::base::unique_fd client(accept4(
                    socket_.get(), reinterpret_cast<sockaddr*>(&addr), &len, SOCK_CLOEXEC));
            if (client.get() < 0) {
                // Stop
                APLOGI("failed to accept client socket %d", client.get());
                break;
            }

            // |ssl| is declared after |client|, so it is freed before the fd closes.
            bssl::UniquePtr<SSL> ssl(SSL_new(ctx_.get()));
            if (!ssl) {
                ALOGE("SSL_new failed");
                continue;
            }
            SSL_set_fd(ssl.get(), client.get());

            ALOGD("Doing SSL handshake");
            bool success = false;
            if (SSL_accept(ssl.get()) <= 0) {
                ALOGI("SSL negotiation failure");
            } else {
                ALOGD("SSL handshake complete");
                success = handleOneRequest(ssl.get());
            }

            if (success) {
                // Increment queries_ as late as possible, because it represents
                // a query that is fully processed, and the response returned to the
                // client, including cleanup actions.
                ++queries_;
            }
        }
    }
    ALOGD("Ending loop");
}

// Reads one length-prefixed (2-byte big-endian) query from |ssl|, forwards it as
// a single datagram on backend_socket_, and waits for one reply of up to 4096
// bytes. The reply is written back to |ssl| with the same 2-byte length prefix.
// Returns true only if every step succeeded.
bool DnsTlsFrontend::handleOneRequest(SSL* ssl) {
    uint8_t queryHeader[2];
    if (!sslReadFully(ssl, queryHeader, sizeof(queryHeader))) {
        ALOGI("Not enough header bytes");
        return false;
    }
    const uint16_t qlen = static_cast<uint16_t>((queryHeader[0] << 8) | queryHeader[1]);
    // Heap buffer: |qlen| comes from the peer and may be 0.
    std::vector<uint8_t> query(qlen);
    if (!sslReadFully(ssl, query.data(), query.size())) {
        ALOGI("Error while reading query");
        return false;
    }
    const ssize_t sent = send(backend_socket_.get(), query.data(), query.size(), 0);
    if (sent != static_cast<ssize_t>(query.size())) {
        ALOGI("Failed to send query");
        return false;
    }
    constexpr size_t max_size = 4096;
    uint8_t recv_buffer[max_size];
    const ssize_t rlen = recv(backend_socket_.get(), recv_buffer, max_size, 0);
    if (rlen <= 0) {
        ALOGI("Failed to receive response");
        return false;
    }
    const uint8_t responseHeader[2] = {static_cast<uint8_t>((rlen >> 8) & 0xff),
                                       static_cast<uint8_t>(rlen & 0xff)};
    if (SSL_write(ssl, responseHeader, 2) != 2) {
        ALOGI("Failed to write response header");
        return false;
    }
    if (SSL_write(ssl, recv_buffer, static_cast<int>(rlen)) != rlen) {
        ALOGI("Failed to write response body");
        return false;
    }
    return true;
}

// Signals requestHandler() through the eventfd, joins handler_thread_ and
// releases all sockets, the TLS context and the fingerprint. Returns false if
// the server was not running or the eventfd write failed.
bool DnsTlsFrontend::stopServer() {
    std::lock_guard lock(update_mutex_);
    if (!running()) {
        ALOGI("server not running");
        return false;
    }

    ALOGI("stopping frontend");
    if (!sendToEventFd()) {
        return false;
    }
    handler_thread_.join();
    socket_.reset();
    backend_socket_.reset();
    event_fd_.reset();
    ctx_.reset();
    fingerprint_.clear();
    ALOGI("frontend stopped successfully");
    return true;
}

// Polls every 20 ms, for roughly |timeoutMs|, until at least |number| queries
// have been fully handled. Returns true if that count was reached.
bool DnsTlsFrontend::waitForQueries(int number, int timeoutMs) const {
    constexpr int intervalMs = 20;
    int limit = timeoutMs / intervalMs;
    for (int count = 0; count <= limit; ++count) {
        bool done = queries_ >= number;
        // Always sleep at least one more interval after we are done, to wait for
        // any immediate post-query actions that the client may take (such as
        // marking this server as reachable during validation).
        usleep(intervalMs * 1000);
        if (done) {
            // For ensuring that calls have sufficient headroom for slow machines
            ALOGD("Query arrived in %d/%d of allotted time", count, limit);
            return true;
        }
    }
    return false;
}

// Writes a 1 to event_fd_ to wake requestHandler(). Returns false if the
// 8-byte write did not complete.
bool DnsTlsFrontend::sendToEventFd() {
    const uint64_t data = 1;
    if (const ssize_t rt = write(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
        APLOGI("failed to write eventfd, rt=%zd", rt);
        return false;
    }
    return true;
}

// Reads the 8-byte eventfd counter. A failed read is only logged.
void DnsTlsFrontend::handleEventFd() {
    int64_t data;
    if (const ssize_t rt = read(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
        APLOGI("ignore reading eventfd failed, rt=%zd", rt);
    }
}

}  // namespace test