#include "uds/client_channel.h"
#include <errno.h>
#include <log/log.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <pdx/client.h>
#include <pdx/service_endpoint.h>
#include <uds/ipc_helper.h>
namespace android {
namespace pdx {
namespace uds {
namespace {
struct TransactionState {
bool GetLocalFileHandle(int index, LocalHandle* handle) {
if (index < 0) {
handle->Reset(index);
} else if (static_cast<size_t>(index) < response.file_descriptors.size()) {
*handle = std::move(response.file_descriptors[index]);
} else {
return false;
}
return true;
}
bool GetLocalChannelHandle(int index, LocalChannelHandle* handle) {
if (index < 0) {
*handle = LocalChannelHandle{nullptr, index};
} else if (static_cast<size_t>(index) < response.channels.size()) {
auto& channel_info = response.channels[index];
*handle = ChannelManager::Get().CreateHandle(
std::move(channel_info.data_fd), std::move(channel_info.event_fd));
} else {
return false;
}
return true;
}
FileReference PushFileHandle(BorrowedHandle handle) {
if (!handle)
return handle.Get();
request.file_descriptors.push_back(std::move(handle));
return request.file_descriptors.size() - 1;
}
ChannelReference PushChannelHandle(BorrowedChannelHandle handle) {
if (!handle)
return handle.value();
if (auto* channel_data =
ChannelManager::Get().GetChannelData(handle.value())) {
ChannelInfo<BorrowedHandle> channel_info;
channel_info.data_fd.Reset(handle.value());
channel_info.event_fd = channel_data->event_receiver.event_fd();
request.channels.push_back(std::move(channel_info));
return request.channels.size() - 1;
} else {
return -1;
}
}
RequestHeader<BorrowedHandle> request;
ResponseHeader<LocalHandle> response;
};
Status<void> ReadAndDiscardData(const BorrowedHandle& socket_fd, size_t size) {
while (size > 0) {
// If there is more data to read in the message than the buffers provided
// by the caller, read and discard the extra data from the socket.
char buffer[1024];
size_t size_to_read = std::min(sizeof(buffer), size);
auto status = ReceiveData(socket_fd, buffer, size_to_read);
if (!status)
return status;
size -= size_to_read;
}
// We still want to return EIO error to the caller in case we had unexpected
// data in the socket stream.
return ErrorStatus(EIO);
}
Status<void> SendRequest(const BorrowedHandle& socket_fd,
TransactionState* transaction_state, int opcode,
const iovec* send_vector, size_t send_count,
size_t max_recv_len) {
size_t send_len = CountVectorSize(send_vector, send_count);
InitRequest(&transaction_state->request, opcode, send_len, max_recv_len,
false);
auto status = SendData(socket_fd, transaction_state->request);
if (status && send_len > 0)
status = SendDataVector(socket_fd, send_vector, send_count);
return status;
}
Status<void> ReceiveResponse(const BorrowedHandle& socket_fd,
TransactionState* transaction_state,
const iovec* receive_vector, size_t receive_count,
size_t max_recv_len) {
auto status = ReceiveData(socket_fd, &transaction_state->response);
if (!status)
return status;
if (transaction_state->response.recv_len > 0) {
std::vector<iovec> read_buffers;
size_t size_remaining = 0;
if (transaction_state->response.recv_len != max_recv_len) {
// If the receive buffer not exactly the size of data available, recreate
// the vector list to consume the data exactly since ReceiveDataVector()
// validates that the number of bytes received equals the number of bytes
// requested.
size_remaining = transaction_state->response.recv_len;
for (size_t i = 0; i < receive_count && size_remaining > 0; i++) {
read_buffers.push_back(receive_vector[i]);
iovec& last_vec = read_buffers.back();
if (last_vec.iov_len > size_remaining)
last_vec.iov_len = size_remaining;
size_remaining -= last_vec.iov_len;
}
receive_vector = read_buffers.data();
receive_count = read_buffers.size();
}
status = ReceiveDataVector(socket_fd, receive_vector, receive_count);
if (status && size_remaining > 0)
status = ReadAndDiscardData(socket_fd, size_remaining);
}
return status;
}
} // anonymous namespace
ClientChannel::ClientChannel(LocalChannelHandle channel_handle)
: channel_handle_{std::move(channel_handle)} {
channel_data_ = ChannelManager::Get().GetChannelData(channel_handle_.value());
}
std::unique_ptr<pdx::ClientChannel> ClientChannel::Create(
LocalChannelHandle channel_handle) {
return std::unique_ptr<pdx::ClientChannel>{
new ClientChannel{std::move(channel_handle)}};
}
ClientChannel::~ClientChannel() {
if (channel_handle_)
shutdown(channel_handle_.value(), SHUT_WR);
}
void* ClientChannel::AllocateTransactionState() { return new TransactionState; }
void ClientChannel::FreeTransactionState(void* state) {
delete static_cast<TransactionState*>(state);
}
Status<void> ClientChannel::SendImpulse(int opcode, const void* buffer,
size_t length) {
std::unique_lock<std::mutex> lock(socket_mutex_);
Status<void> status;
android::pdx::uds::RequestHeader<BorrowedHandle> request;
if (length > request.impulse_payload.size() ||
(buffer == nullptr && length != 0)) {
status.SetError(EINVAL);
return status;
}
InitRequest(&request, opcode, length, 0, true);
memcpy(request.impulse_payload.data(), buffer, length);
return SendData(BorrowedHandle{channel_handle_.value()}, request);
}
Status<int> ClientChannel::SendAndReceive(void* transaction_state, int opcode,
const iovec* send_vector,
size_t send_count,
const iovec* receive_vector,
size_t receive_count) {
std::unique_lock<std::mutex> lock(socket_mutex_);
Status<int> result;
if ((send_vector == nullptr && send_count != 0) ||
(receive_vector == nullptr && receive_count != 0)) {
result.SetError(EINVAL);
return result;
}
auto* state = static_cast<TransactionState*>(transaction_state);
size_t max_recv_len = CountVectorSize(receive_vector, receive_count);
auto status = SendRequest(BorrowedHandle{channel_handle_.value()}, state,
opcode, send_vector, send_count, max_recv_len);
if (status) {
status = ReceiveResponse(BorrowedHandle{channel_handle_.value()}, state,
receive_vector, receive_count, max_recv_len);
}
if (!result.PropagateError(status)) {
const int return_code = state->response.ret_code;
if (return_code >= 0)
result.SetValue(return_code);
else
result.SetError(-return_code);
}
return result;
}
Status<int> ClientChannel::SendWithInt(void* transaction_state, int opcode,
const iovec* send_vector,
size_t send_count,
const iovec* receive_vector,
size_t receive_count) {
return SendAndReceive(transaction_state, opcode, send_vector, send_count,
receive_vector, receive_count);
}
Status<LocalHandle> ClientChannel::SendWithFileHandle(
void* transaction_state, int opcode, const iovec* send_vector,
size_t send_count, const iovec* receive_vector, size_t receive_count) {
Status<int> int_status =
SendAndReceive(transaction_state, opcode, send_vector, send_count,
receive_vector, receive_count);
Status<LocalHandle> status;
if (status.PropagateError(int_status))
return status;
auto* state = static_cast<TransactionState*>(transaction_state);
LocalHandle handle;
if (state->GetLocalFileHandle(int_status.get(), &handle)) {
status.SetValue(std::move(handle));
} else {
status.SetError(EINVAL);
}
return status;
}
Status<LocalChannelHandle> ClientChannel::SendWithChannelHandle(
void* transaction_state, int opcode, const iovec* send_vector,
size_t send_count, const iovec* receive_vector, size_t receive_count) {
Status<int> int_status =
SendAndReceive(transaction_state, opcode, send_vector, send_count,
receive_vector, receive_count);
Status<LocalChannelHandle> status;
if (status.PropagateError(int_status))
return status;
auto* state = static_cast<TransactionState*>(transaction_state);
LocalChannelHandle handle;
if (state->GetLocalChannelHandle(int_status.get(), &handle)) {
status.SetValue(std::move(handle));
} else {
status.SetError(EINVAL);
}
return status;
}
FileReference ClientChannel::PushFileHandle(void* transaction_state,
const LocalHandle& handle) {
auto* state = static_cast<TransactionState*>(transaction_state);
return state->PushFileHandle(handle.Borrow());
}
FileReference ClientChannel::PushFileHandle(void* transaction_state,
const BorrowedHandle& handle) {
auto* state = static_cast<TransactionState*>(transaction_state);
return state->PushFileHandle(handle.Duplicate());
}
ChannelReference ClientChannel::PushChannelHandle(
void* transaction_state, const LocalChannelHandle& handle) {
auto* state = static_cast<TransactionState*>(transaction_state);
return state->PushChannelHandle(handle.Borrow());
}
ChannelReference ClientChannel::PushChannelHandle(
void* transaction_state, const BorrowedChannelHandle& handle) {
auto* state = static_cast<TransactionState*>(transaction_state);
return state->PushChannelHandle(handle.Duplicate());
}
bool ClientChannel::GetFileHandle(void* transaction_state, FileReference ref,
LocalHandle* handle) const {
auto* state = static_cast<TransactionState*>(transaction_state);
return state->GetLocalFileHandle(ref, handle);
}
bool ClientChannel::GetChannelHandle(void* transaction_state,
ChannelReference ref,
LocalChannelHandle* handle) const {
auto* state = static_cast<TransactionState*>(transaction_state);
return state->GetLocalChannelHandle(ref, handle);
}
} // namespace uds
} // namespace pdx
} // namespace android