/*
* Copyright (C) 2016 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "UnixSocket.h"
#include <gtest/gtest.h>
#include <string>
#include <thread>
// Smoke test for UnixSocketMessageBuffer: pushes 1024 patterned bytes through
// a buffer constructed with size argument 100, split into messages carrying
// 0..9 payload bytes each, and checks that the payloads read back concatenate
// to exactly the bytes that were sent.
TEST(UnixSocket, message_buffer_smoke) {
  // Local message layout: len and type fields followed by up to 10 payload
  // bytes. It is reinterpreted as a UnixSocketMessage when stored.
  struct Message {
    uint32_t len;
    uint32_t type;
    char data[10];
  };
  Message outgoing;

  constexpr size_t kTotalBytes = 1024;
  std::vector<char> send_data(kTotalBytes);
  for (size_t pos = 0; pos < kTotalBytes; ++pos) {
    send_data[pos] = pos & 0xff;
  }
  std::vector<char> read_data;

  UnixSocketMessageBuffer buffer(100);
  size_t next_payload_size = 0;  // Cycles through 0, 1, ..., 9.
  size_t bytes_stored = 0;
  for (;;) {
    // Fill phase: keep storing messages until the buffer rejects one or all
    // data has been handed over.
    while (bytes_stored < kTotalBytes) {
      const size_t payload =
          std::min(next_payload_size, kTotalBytes - bytes_stored);
      // The payload size advances even when the store below is rejected.
      next_payload_size = (next_payload_size + 1) % 10;
      outgoing.len = sizeof(UnixSocketMessage) + payload;
      // The type field carries the payload size so the reader can verify len.
      outgoing.type = payload;
      memcpy(outgoing.data, &send_data[bytes_stored], payload);
      UnixSocketMessage& as_socket_msg =
          *reinterpret_cast<UnixSocketMessage*>(&outgoing);
      if (!buffer.StoreMessage(as_socket_msg)) {
        break;
      }
      bytes_stored += payload;
    }
    // Nothing left in the buffer means everything has been sent and read.
    if (buffer.Empty()) {
      break;
    }

    // Drain phase: pull exactly one message out of the buffer.
    std::vector<char> raw;
    // Appends bytes from the buffer to raw until it holds `want` bytes.
    auto fill_raw_to = [&](size_t want) {
      while (raw.size() < want) {
        const char* chunk;
        const size_t available = buffer.PeekData(&chunk);
        const size_t take = std::min(available, want - raw.size());
        raw.insert(raw.end(), chunk, chunk + take);
        buffer.CommitData(take);
      }
    };
    // Read the header first to learn the message length, then the rest of
    // the message up to its aligned length.
    fill_raw_to(sizeof(UnixSocketMessage));
    Message* msg = reinterpret_cast<Message*>(raw.data());
    const size_t padded_len = Align(msg->len, UnixSocketMessageAlignment);
    fill_raw_to(padded_len);
    // raw may have reallocated while growing, so refetch the pointer.
    msg = reinterpret_cast<Message*>(raw.data());
    ASSERT_EQ(msg->len, msg->type + sizeof(UnixSocketMessage));
    read_data.insert(read_data.end(), msg->data, msg->data + msg->type);
  }
  ASSERT_EQ(send_data, read_data);
}
// Client half of the undelayed_message test, run on its own thread.
//
// Connects to the socket at `path`, answers every header-only message from the
// server with a message whose type is one greater, and runs an event loop
// until the connection-closed callback exits it. `client_success` is set to
// true only if every step succeeded.
static void ClientToTestUndelayedMessage(const std::string& path,
                                         bool& client_success) {
  std::unique_ptr<UnixSocketConnection> client =
      UnixSocketConnection::Connect(path, true);
  ASSERT_TRUE(client != nullptr);
  IOEventLoop loop;

  // Echo back type + 1; anything that is not a bare header is rejected.
  auto receive_message_callback = [&](const UnixSocketMessage& msg) {
    if (msg.len == sizeof(UnixSocketMessage)) {
      UnixSocketMessage reply;
      reply.len = sizeof(UnixSocketMessage);
      reply.type = msg.type + 1;
      return client->SendMessage(reply, true);
    }
    return false;
  };
  // Leave the event loop once the connection is reported closed.
  auto close_connection_callback = [&]() { return loop.ExitLoop(); };

  ASSERT_TRUE(client->PrepareForIO(loop, receive_message_callback,
                                   close_connection_callback));
  ASSERT_TRUE(loop.RunLoop());
  client_success = true;
}
// Ping-pong test between a server (this thread) and a client thread running
// ClientToTestUndelayedMessage. The server sends type 0, the client answers
// each message with type + 1, and the server answers each reply with type + 1
// again, so the server expects replies of type 1, 3, 5, ...
TEST(UnixSocket, undelayed_message) {
  std::string path = "unix_socket_test_" + std::to_string(getpid());
  std::unique_ptr<UnixSocketServer> server =
      UnixSocketServer::Create(path, true);
  ASSERT_TRUE(server != nullptr);
  bool client_success = false;
  std::thread thread(
      [&]() { ClientToTestUndelayedMessage(path, client_success); });
  // Every ASSERT_* below returns from the test body on failure. Destroying a
  // std::thread that is still joinable calls std::terminate(), which would
  // abort the whole test binary, so guarantee the client thread is joined on
  // every exit path. The guard is declared before conn and loop, so on an
  // early return those are destroyed first and only then is the thread joined.
  // NOTE(review): this relies on destroying conn closing the connection so the
  // client's loop exits; if AcceptConnection() itself fails while the client
  // is connected, the join could block -- confirm against UnixSocket.h.
  struct ThreadJoiner {
    explicit ThreadJoiner(std::thread& t) : thread_(t) {}
    ~ThreadJoiner() {
      if (thread_.joinable()) {
        thread_.join();
      }
    }
    std::thread& thread_;
  };
  ThreadJoiner joiner(thread);
  std::unique_ptr<UnixSocketConnection> conn = server->AcceptConnection();
  ASSERT_TRUE(conn != nullptr);
  IOEventLoop loop;
  uint32_t need_reply_type = 1;
  // For each message received from the client, the server replies a msg
  // with type + 1. Once the expected reply type has reached 10 or more
  // (i.e. on receiving type 11), it calls NoMoreMessage() instead of replying.
  auto receive_message_callback = [&](const UnixSocketMessage& msg) {
    if (msg.len != sizeof(UnixSocketMessage) || msg.type != need_reply_type) {
      return false;
    }
    if (need_reply_type >= 10) {
      return conn->NoMoreMessage();
    }
    UnixSocketMessage new_msg;
    new_msg.len = sizeof(UnixSocketMessage);
    new_msg.type = msg.type + 1;
    // The client adds one to what we send, so its next reply is type + 2.
    need_reply_type = msg.type + 2;
    return conn->SendMessage(new_msg, true);
  };
  // Leave the event loop once the connection is reported closed.
  auto close_connection_callback = [&]() { return loop.ExitLoop(); };
  ASSERT_TRUE(conn->PrepareForIO(loop, receive_message_callback,
                                 close_connection_callback));
  // Kick off the exchange with a type-0 message.
  UnixSocketMessage msg;
  msg.len = sizeof(UnixSocketMessage);
  msg.type = 0;
  ASSERT_TRUE(conn->SendMessage(msg, true));
  ASSERT_TRUE(loop.RunLoop());
  // Join before reading client_success so the client's write is visible.
  thread.join();
  ASSERT_TRUE(client_success);
}
// Client half of the buffered_message test, run on its own thread.
//
// Connects to the socket at `path`, queues messages with increasing type
// (1, 2, 3, ...) until SendMessage() reports failure, then runs an event loop
// until the connection-closed callback exits it. `client_success` is set to
// true only if every step succeeded.
static void ClientToTestBufferedMessage(const std::string& path,
                                        bool& client_success) {
  std::unique_ptr<UnixSocketConnection> client =
      UnixSocketConnection::Connect(path, true);
  ASSERT_TRUE(client != nullptr);
  IOEventLoop loop;

  // The only message expected from the server is a bare header of type 0;
  // on receiving it the client calls NoMoreMessage().
  auto receive_message_callback = [&](const UnixSocketMessage& msg) {
    if (msg.len == sizeof(UnixSocketMessage) && msg.type == 0) {
      return client->NoMoreMessage();
    }
    return false;
  };
  // Leave the event loop once the connection is reported closed.
  auto close_connection_callback = [&]() { return loop.ExitLoop(); };
  ASSERT_TRUE(client->PrepareForIO(loop, receive_message_callback,
                                   close_connection_callback));

  // Keep sending buffered messages until the send buffer is full, i.e. until
  // SendMessage() returns false. The first message sent has type 1.
  UnixSocketMessage msg;
  msg.len = sizeof(UnixSocketMessage);
  msg.type = 0;
  do {
    msg.type++;
  } while (client->SendMessage(msg, false));

  ASSERT_TRUE(loop.RunLoop());
  client_success = true;
}
// Server half of the buffered-message test. A client thread running
// ClientToTestBufferedMessage sends messages of type 1, 2, 3, ...; the server
// checks that the types arrive in order and, on the first one, tells the
// client to stop by sending a type-0 message.
TEST(UnixSocket, buffered_message) {
  std::string path = "unix_socket_test_" + std::to_string(getpid());
  std::unique_ptr<UnixSocketServer> server =
      UnixSocketServer::Create(path, true);
  ASSERT_TRUE(server != nullptr);
  bool client_success = false;
  std::thread thread(
      [&]() { ClientToTestBufferedMessage(path, client_success); });
  // Every ASSERT_* below returns from the test body on failure. Destroying a
  // std::thread that is still joinable calls std::terminate(), which would
  // abort the whole test binary, so guarantee the client thread is joined on
  // every exit path. The guard is declared before conn and loop, so on an
  // early return those are destroyed first and only then is the thread joined.
  // NOTE(review): this relies on destroying conn closing the connection so the
  // client's loop exits; if AcceptConnection() itself fails while the client
  // is connected, the join could block -- confirm against UnixSocket.h.
  struct ThreadJoiner {
    explicit ThreadJoiner(std::thread& t) : thread_(t) {}
    ~ThreadJoiner() {
      if (thread_.joinable()) {
        thread_.join();
      }
    }
    std::thread& thread_;
  };
  ThreadJoiner joiner(thread);
  std::unique_ptr<UnixSocketConnection> conn = server->AcceptConnection();
  ASSERT_TRUE(conn != nullptr);
  IOEventLoop loop;
  uint32_t need_reply_type = 1;
  auto receive_message_callback = [&](const UnixSocketMessage& msg) {
    // The server checks if the type of received message is increased by one
    // each time.
    if (msg.len != sizeof(UnixSocketMessage) || msg.type != need_reply_type) {
      return false;
    }
    if (need_reply_type == 1) {
      // Notify the client to exit.
      UnixSocketMessage new_msg;
      new_msg.len = sizeof(UnixSocketMessage);
      new_msg.type = 0;
      if (!conn->SendMessage(new_msg, true)) {
        return false;
      }
    }
    need_reply_type++;
    return true;
  };
  // Leave the event loop once the connection is reported closed.
  auto close_connection_callback = [&]() { return loop.ExitLoop(); };
  ASSERT_TRUE(conn->PrepareForIO(loop, receive_message_callback,
                                 close_connection_callback));
  ASSERT_TRUE(loop.RunLoop());
  // Join before reading client_success so the client's write is visible.
  thread.join();
  ASSERT_TRUE(client_success);
}