/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "chre_host/socket_server.h"
#include <poll.h>
#include <cassert>
#include <cinttypes>
#include <csignal>
#include <cstdlib>
#include <map>
#include <mutex>
#include <cutils/sockets.h>
#include "chre_host/log.h"
namespace android {
namespace chre {
// Set to true by signalHandler() and checked by the serviceSocket() poll loop;
// starts out false so the loop keeps running until SIGINT/SIGTERM is caught.
std::atomic<bool> SocketServer::sSignalReceived{false};
namespace {
void maskAllSignals() {
sigset_t signalMask;
sigfillset(&signalMask);
if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) {
LOG_ERROR("Couldn't mask all signals", errno);
}
}
void maskAllSignalsExceptIntAndTerm() {
sigset_t signalMask;
sigfillset(&signalMask);
sigdelset(&signalMask, SIGINT);
sigdelset(&signalMask, SIGTERM);
if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) {
LOG_ERROR("Couldn't mask all signals except INT/TERM", errno);
}
}
} // anonymous namespace
SocketServer::SocketServer() {
  // Slot 0 of mPollFds is reserved for the listen socket (filled in by
  // serviceSocket()). Every client slot starts out free: a negative fd makes
  // poll() skip it, and POLLIN is pre-armed for when a client is assigned.
  for (size_t slot = 1; slot < 1 + kMaxActiveClients; ++slot) {
    auto& pollEntry = mPollFds[slot];
    pollEntry.fd = -1;
    pollEntry.events = POLLIN;
  }
}
/**
 * Obtains the server socket, services clients until the poll loop exits, then
 * closes every client connection and the server socket.
 *
 * @param socketName Name of the control socket to use.
 * @param allowSocketCreation If true, create the socket when none was
 *        inherited; otherwise fail.
 * @param clientMessageCallback Invoked for every packet received from a client.
 */
void SocketServer::run(const char *socketName, bool allowSocketCreation,
                       ClientMessageCallback clientMessageCallback) {
  mClientMessageCallback = clientMessageCallback;

  // Prefer an inherited control socket; only create one if permitted.
  mSockFd = android_get_control_socket(socketName);
  if (mSockFd == INVALID_SOCKET && allowSocketCreation) {
    LOGI("Didn't inherit socket, creating...");
    mSockFd = socket_local_server(socketName, ANDROID_SOCKET_NAMESPACE_RESERVED,
                                  SOCK_SEQPACKET);
  }

  if (mSockFd == INVALID_SOCKET) {
    LOGE("Couldn't get/create socket");
    return;
  }

  if (listen(mSockFd, kMaxPendingConnectionRequests) < 0) {
    LOG_ERROR("Couldn't listen on socket", errno);
  } else {
    // Returns once the poll loop stops (signal received or ppoll() failure).
    serviceSocket();
  }

  // Teardown: close all remaining client sockets under the lock, since the
  // send paths read mClients while holding mClientsMutex.
  {
    std::lock_guard<std::mutex> lock(mClientsMutex);
    for (auto it = mClients.begin(); it != mClients.end(); ++it) {
      if (close(it->first) != 0) {
        LOGI("Couldn't close client %" PRIu16 "'s socket: %s",
             it->second.clientId, strerror(errno));
      }
    }
    mClients.clear();
  }
  close(mSockFd);
}
void SocketServer::sendToAllClients(const void *data, size_t length) {
std::lock_guard<std::mutex> lock(mClientsMutex);
int deliveredCount = 0;
for (const auto& pair : mClients) {
int clientSocket = pair.first;
uint16_t clientId = pair.second.clientId;
if (sendToClientSocket(data, length, clientSocket, clientId)) {
deliveredCount++;
} else if (errno == EINTR) {
// Exit early if we were interrupted - we should only get this for
// SIGINT/SIGTERM, so we should exit quickly
break;
}
}
if (deliveredCount == 0) {
LOGW("Got message but didn't deliver to any clients");
}
}
/**
 * Sends one packet to the client with the given ID.
 *
 * @param data Packet payload.
 * @param length Payload size in bytes.
 * @param clientId ID assigned to the client when it connected.
 * @return true if the client was found and the send succeeded.
 */
bool SocketServer::sendToClientById(const void *data, size_t length,
                                    uint16_t clientId) {
  std::lock_guard<std::mutex> lock(mClientsMutex);
  // IDs come from an incrementing counter, so at most one entry can match.
  for (auto it = mClients.begin(); it != mClients.end(); ++it) {
    if (it->second.clientId == clientId) {
      return sendToClientSocket(data, length, it->first, clientId);
    }
  }
  return false;
}
/**
 * Accepts a pending connection on the listen socket and registers it as a new
 * client, or rejects it when the client limit is reached.
 *
 * Each accepted client gets the next sequential ID and the first free poll
 * slot. The process exits if the ID counter wraps around to 0.
 */
void SocketServer::acceptClientConnection() {
  int clientSocket = accept(mSockFd, NULL, NULL);
  if (clientSocket < 0) {
    LOG_ERROR("Couldn't accept client connection", errno);
    return;
  }

  // NOTE: read without mClientsMutex. This relies on mClients only being
  // modified from the thread running the poll loop, which holds for every
  // writer in this file.
  if (mClients.size() >= kMaxActiveClients) {
    LOGW("Rejecting client request - maximum number of clients reached");
    close(clientSocket);
    return;
  }

  ClientData clientData;
  clientData.clientId = mNextClientId++;
  // We currently don't handle wraparound - if we're getting this many
  // connects/disconnects, then something is wrong.
  // TODO: can handle this properly by iterating over the existing clients to
  // avoid a conflict.
  if (clientData.clientId == 0) {
    LOGE("Couldn't allocate client ID");
    std::exit(-1);
  }

  // Claim the first free client slot (slot 0 is the listen socket).
  size_t slot = 1;
  while (slot <= kMaxActiveClients && mPollFds[slot].fd >= 0) {
    ++slot;
  }
  bool slotFound = (slot <= kMaxActiveClients);
  if (!slotFound) {
    LOGE("Couldn't find slot for client!");
    assert(slotFound);
    close(clientSocket);
    return;
  }
  mPollFds[slot].fd = clientSocket;

  {
    std::lock_guard<std::mutex> lock(mClientsMutex);
    mClients[clientSocket] = clientData;
  }
  LOGI("Accepted new client connection (count %zu), assigned client ID %"
       PRIu16, mClients.size(), clientData.clientId);
}
void SocketServer::handleClientData(int clientSocket) {
const ClientData& clientData = mClients[clientSocket];
uint16_t clientId = clientData.clientId;
ssize_t packetSize = recv(
clientSocket, mRecvBuffer.data(), mRecvBuffer.size(), MSG_DONTWAIT);
if (packetSize < 0) {
LOGE("Couldn't get packet from client %" PRIu16 ": %s", clientId,
strerror(errno));
} else if (packetSize == 0) {
LOGI("Client %" PRIu16 " disconnected", clientId);
disconnectClient(clientSocket);
} else {
LOGV("Got %zd byte packet from client %" PRIu16, packetSize, clientId);
mClientMessageCallback(clientId, mRecvBuffer.data(), packetSize);
}
}
/**
 * Forgets a client, closes its socket, and frees its poll slot.
 *
 * Logs (and asserts in debug builds) if the socket was not found in any
 * client slot.
 *
 * @param clientSocket Socket of the client to disconnect.
 */
void SocketServer::disconnectClient(int clientSocket) {
  // Erase under the lock, since the send paths read mClients while holding
  // mClientsMutex.
  {
    std::lock_guard<std::mutex> lock(mClientsMutex);
    mClients.erase(clientSocket);
  }
  close(clientSocket);

  size_t slot = 1;
  while (slot <= kMaxActiveClients && mPollFds[slot].fd != clientSocket) {
    ++slot;
  }
  bool removed = (slot <= kMaxActiveClients);
  if (removed) {
    mPollFds[slot].fd = -1;
  } else {
    LOGE("Out of sync");
    assert(removed);
  }
}
/**
 * Sends one packet on a client socket and logs the outcome.
 *
 * @param data Packet payload.
 * @param length Payload size in bytes.
 * @param clientSocket Connected client socket to send on.
 * @param clientId ID of the client (used for logging only).
 * @return true if send() reported that bytes were sent.
 *
 * On return, errno holds the value it had right after send(). Callers such as
 * sendToAllClients() inspect it after a failure (e.g. to detect EINTR).
 */
bool SocketServer::sendToClientSocket(const void *data, size_t length,
                                      int clientSocket, uint16_t clientId) {
  errno = 0;
  // MSG_NOSIGNAL: if the peer has already closed its end, report EPIPE through
  // the return value instead of raising SIGPIPE. SIGPIPE's default action
  // would terminate the process if the calling thread does not block it.
  ssize_t bytesSent = send(clientSocket, data, length, MSG_NOSIGNAL);
  // Save errno before logging: the log calls are not guaranteed to preserve
  // it, and callers rely on it after a failed send.
  const int sendErrno = errno;
  if (bytesSent < 0) {
    LOGE("Error sending packet of size %zu to client %" PRIu16 ": %s",
         length, clientId, strerror(sendErrno));
  } else if (bytesSent == 0) {
    LOGW("Client %" PRIu16 " disconnected before message could be delivered",
         clientId);
  } else {
    LOGV("Delivered message of size %zu bytes to client %" PRIu16, length,
         clientId);
  }
  errno = sendErrno;
  return (bytesSent > 0);
}
/**
 * Runs the poll loop over the listen socket and all client slots until
 * sSignalReceived becomes true or ppoll() fails.
 *
 * mPollFds layout: the listen socket is at index 0, and client sockets occupy
 * indices 1..kMaxActiveClients, with fd == -1 marking a free slot.
 *
 * Signal handling: every signal is blocked except while the thread waits in
 * ppoll() or handles a poll's results; in those windows only SIGINT and
 * SIGTERM are deliverable. All signals are blocked again just before
 * sSignalReceived is re-checked. A signal arriving after that check therefore
 * stays pending until ppoll() atomically installs its mask, so ppoll() returns
 * EINTR instead of sleeping with the flag already set.
 */
void SocketServer::serviceSocket() {
  constexpr size_t kListenIndex = 0;
  static_assert(kListenIndex == 0, "Code assumes that the first index is "
                "always the listen socket");
  mPollFds[kListenIndex].fd = mSockFd;
  mPollFds[kListenIndex].events = POLLIN;

  // Signal mask used with ppoll() so we gracefully handle SIGINT and SIGTERM,
  // while all other signals stay blocked
  sigset_t signalMask;
  sigfillset(&signalMask);
  sigdelset(&signalMask, SIGINT);
  sigdelset(&signalMask, SIGTERM);

  // Masking signals here ensures that after this point, we won't handle
  // INT/TERM until after we call into ppoll()
  maskAllSignals();
  std::signal(SIGINT, signalHandler);
  std::signal(SIGTERM, signalHandler);

  LOGI("Ready to accept connections");
  while (!sSignalReceived) {
    int ret = ppoll(mPollFds, 1 + kMaxActiveClients, nullptr, &signalMask);
    // Capture ppoll()'s errno before anything else runs. The mask update below
    // (a syscall that may also log on failure) can overwrite errno, which would
    // make the exit message report the wrong cause.
    int pollErrno = errno;
    maskAllSignalsExceptIntAndTerm();
    if (ret == -1) {
      // Includes EINTR when SIGINT/SIGTERM is delivered during ppoll().
      LOGI("Exiting poll loop: %s", strerror(pollErrno));
      break;
    }

    if (mPollFds[kListenIndex].revents & POLLIN) {
      acceptClientConnection();
    }

    for (size_t i = 1; i <= kMaxActiveClients; i++) {
      if (mPollFds[i].fd < 0) {
        continue;
      }
      if (mPollFds[i].revents & POLLIN) {
        handleClientData(mPollFds[i].fd);
      }
    }

    // Mask all signals to ensure that sSignalReceived can't become true between
    // checking it in the while condition and calling into ppoll()
    maskAllSignals();
  }
}
/**
 * Handler installed for SIGINT and SIGTERM; asks the poll loop to stop.
 *
 * Only async-signal-safe work may happen here. SIGINT/SIGTERM are unblocked
 * while the poll loop handles client activity (see
 * maskAllSignalsExceptIntAndTerm()). This handler can therefore interrupt the
 * thread inside a logging call or other non-reentrant code, and logging from
 * here could deadlock or corrupt the logger's state. Storing to
 * sSignalReceived is assumed safe because std::atomic<bool> is lock-free on
 * mainstream targets; the serviceSocket() loop checks the flag and exits.
 *
 * @param signal The signal number (unused).
 */
void SocketServer::signalHandler(int signal) {
  (void) signal;
  sSignalReceived = true;
}
} // namespace chre
} // namespace android