C++程序  |  164行  |  4.62 KB

#include <uds/client_channel.h>

#include <sys/socket.h>

#include <algorithm>
#include <limits>
#include <random>
#include <thread>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <pdx/client.h>
#include <pdx/rpc/remote_method.h>
#include <pdx/service.h>
#include <pdx/service_dispatcher.h>

#include <uds/client_channel_factory.h>
#include <uds/service_endpoint.h>

using testing::Return;
using testing::_;

using android::pdx::ClientBase;
using android::pdx::LocalChannelHandle;
using android::pdx::LocalHandle;
using android::pdx::Message;
using android::pdx::ServiceBase;
using android::pdx::ServiceDispatcher;
using android::pdx::Status;
using android::pdx::rpc::DispatchRemoteMethod;
using android::pdx::uds::ClientChannel;
using android::pdx::uds::ClientChannelFactory;
using android::pdx::uds::Endpoint;

namespace {

struct TestProtocol {
  using DataType = int8_t;
  enum {
    kOpSum = 0,
  };
  PDX_REMOTE_METHOD(Sum, kOpSum, int64_t(const std::vector<DataType>&));
};

class TestService : public ServiceBase<TestService> {
 public:
  explicit TestService(std::unique_ptr<Endpoint> endpoint)
      : ServiceBase{"TestService", std::move(endpoint)} {}

  Status<void> HandleMessage(Message& message) override {
    switch (message.GetOp()) {
      case TestProtocol::kOpSum:
        DispatchRemoteMethod<TestProtocol::Sum>(*this, &TestService::OnSum,
                                                message);
        return {};

      default:
        return Service::HandleMessage(message);
    }
  }

  int64_t OnSum(Message& /*message*/,
                const std::vector<TestProtocol::DataType>& data) {
    return std::accumulate(data.begin(), data.end(), int64_t{0});
  }
};

class TestClient : public ClientBase<TestClient> {
 public:
  using ClientBase::ClientBase;

  int64_t Sum(const std::vector<TestProtocol::DataType>& data) {
    auto status = InvokeRemoteMethod<TestProtocol::Sum>(data);
    return status ? status.get() : -1;
  }
};

class TestServiceRunner {
 public:
  explicit TestServiceRunner(LocalHandle channel_socket) {
    auto endpoint = Endpoint::CreateFromSocketFd(LocalHandle{});
    endpoint->RegisterNewChannelForTests(std::move(channel_socket));
    service_ = TestService::Create(std::move(endpoint));
    dispatcher_ = ServiceDispatcher::Create();
    dispatcher_->AddService(service_);
    dispatch_thread_ = std::thread(
        std::bind(&ServiceDispatcher::EnterDispatchLoop, dispatcher_.get()));
  }

  ~TestServiceRunner() {
    dispatcher_->SetCanceled(true);
    dispatch_thread_.join();
    dispatcher_->RemoveService(service_);
  }

 private:
  std::shared_ptr<TestService> service_;
  std::unique_ptr<ServiceDispatcher> dispatcher_;
  std::thread dispatch_thread_;
};

class ClientChannelTest : public testing::Test {
 public:
  void SetUp() override {
    int channel_sockets[2] = {};
    ASSERT_EQ(
        0, socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_sockets));
    LocalHandle service_channel{channel_sockets[0]};
    LocalHandle client_channel{channel_sockets[1]};

    service_runner_.reset(new TestServiceRunner{std::move(service_channel)});
    auto factory = ClientChannelFactory::Create(std::move(client_channel));
    auto status = factory->Connect(android::pdx::Client::kInfiniteTimeout);
    ASSERT_TRUE(status);
    client_ = TestClient::Create(status.take());
  }

  void TearDown() override {
    service_runner_.reset();
    client_.reset();
  }

 protected:
  std::unique_ptr<TestServiceRunner> service_runner_;
  std::shared_ptr<TestClient> client_;
};

TEST_F(ClientChannelTest, MultithreadedClient) {
  constexpr int kNumTestThreads = 8;
  constexpr size_t kDataSize = 1000;  // Try to keep RPC buffer size below 4K.

  std::random_device rd;
  std::mt19937 gen{rd()};
  std::uniform_int_distribution<TestProtocol::DataType> dist{
      std::numeric_limits<TestProtocol::DataType>::min(),
      std::numeric_limits<TestProtocol::DataType>::max()};

  auto worker = [](std::shared_ptr<TestClient> client,
                   std::vector<TestProtocol::DataType> data) {
    constexpr int kMaxIterations = 500;
    int64_t expected = std::accumulate(data.begin(), data.end(), int64_t{0});
    for (int i = 0; i < kMaxIterations; i++) {
      ASSERT_EQ(expected, client->Sum(data));
    }
  };

  // Start client threads.
  std::vector<TestProtocol::DataType> data;
  data.resize(kDataSize);
  std::vector<std::thread> threads;
  for (int i = 0; i < kNumTestThreads; i++) {
    std::generate(data.begin(), data.end(),
                  [&dist, &gen]() { return dist(gen); });
    threads.emplace_back(worker, client_, data);
  }

  // Wait for threads to finish.
  for (auto& thread : threads)
    thread.join();
}

}  // namespace