/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "perfetto/base/unix_socket.h"

#include <signal.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include <list>
#include <thread>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "perfetto/base/build_config.h"
#include "perfetto/base/file_utils.h"
#include "perfetto/base/logging.h"
#include "perfetto/base/pipe.h"
#include "perfetto/base/temp_file.h"
#include "perfetto/base/utils.h"
#include "src/base/test/test_task_runner.h"
#include "src/ipc/test/test_socket.h"

namespace perfetto {
namespace base {
namespace {

using ::testing::_;
using ::testing::AtLeast;
using ::testing::Invoke;
using ::testing::InvokeWithoutArgs;
using ::testing::Mock;

// Name of the socket used by the tests in this file, built via the
// TEST_SOCK_NAME() macro (presumably provided by test_socket.h).
constexpr char kSocketName[] = TEST_SOCK_NAME("unix_socket_unittest");
// Shorthand for the blocking mode passed to the Send() calls below.
constexpr auto kBlocking = UnixSocket::BlockingMode::kBlocking;

// gmock-based UnixSocket::EventListener used by the tests below.
//
// The (UnixSocket*, UnixSocket*) overload of OnNewIncomingConnection() is the
// mocked one. The unique_ptr-taking override is written by hand: it keeps
// ownership of every accepted connection until a test claims it through
// GetIncomingConnection().
class MockEventListener : public UnixSocket::EventListener {
 public:
  MOCK_METHOD2(OnNewIncomingConnection, void(UnixSocket*, UnixSocket*));
  MOCK_METHOD2(OnConnect, void(UnixSocket*, bool));
  MOCK_METHOD1(OnDisconnect, void(UnixSocket*));
  MOCK_METHOD1(OnDataAvailable, void(UnixSocket*));

  // GMock cannot mock methods that take non-copyable arguments, so this
  // override stores the connection and forwards a raw pointer to the mocked
  // overload instead.
  void OnNewIncomingConnection(
      UnixSocket* self,
      std::unique_ptr<UnixSocket> new_connection) override {
    UnixSocket* const accepted = new_connection.get();
    incoming_connections_.push_back(std::move(new_connection));
    OnNewIncomingConnection(self, accepted);
  }

  // Hands over the oldest stored connection, or nullptr when none is pending.
  std::unique_ptr<UnixSocket> GetIncomingConnection() {
    std::unique_ptr<UnixSocket> oldest;
    if (!incoming_connections_.empty()) {
      oldest = std::move(incoming_connections_.front());
      incoming_connections_.pop_front();
    }
    return oldest;
  }

 private:
  // Accepted connections, in arrival order.
  std::list<std::unique_ptr<UnixSocket>> incoming_connections_;
};

// Fixture shared by all the tests in this file. It owns the task runner and
// the mock listener, and invokes DESTROY_TEST_SOCK() on the test socket both
// before and after each test (presumably to remove a stale socket file).
class UnixSocketTest : public ::testing::Test {
 protected:
  void SetUp() override {
    DESTROY_TEST_SOCK(kSocketName);
  }

  void TearDown() override {
    DESTROY_TEST_SOCK(kSocketName);
  }

  TestTaskRunner task_runner_;
  MockEventListener event_listener_;
};

// Connecting while nothing is listening on kSocketName must be reported
// through OnConnect(..., false).
TEST_F(UnixSocketTest, ConnectionFailureIfUnreachable) {
  auto client =
      UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_);
  ASSERT_FALSE(client->is_connected());

  auto on_failure = task_runner_.CreateCheckpoint("failure");
  EXPECT_CALL(event_listener_, OnConnect(client.get(), false))
      .WillOnce(InvokeWithoutArgs(on_failure));
  task_runner_.RunUntilCheckpoint("failure");
}

// If the server shuts an incoming connection down as soon as it is created,
// both the server-side connection and the client must see an OnDisconnect().
TEST_F(UnixSocketTest, ConnectionImmediatelyDroppedByServer) {
  auto server =
      UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_);
  ASSERT_TRUE(server->is_listening());

  // Server side: shut the connection down from within
  // OnNewIncomingConnection().
  auto srv_did_shutdown = task_runner_.CreateCheckpoint("srv_did_shutdown");
  auto drop_on_accept = [this, srv_did_shutdown](UnixSocket*,
                                                 UnixSocket* accepted) {
    EXPECT_CALL(event_listener_, OnDisconnect(accepted));
    accepted->Shutdown(true);
    srv_did_shutdown();
  };
  EXPECT_CALL(event_listener_, OnNewIncomingConnection(server.get(), _))
      .WillOnce(Invoke(drop_on_accept));

  auto cli_connected = task_runner_.CreateCheckpoint("cli_connected");
  auto client =
      UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_);
  EXPECT_CALL(event_listener_, OnConnect(client.get(), true))
      .WillOnce(InvokeWithoutArgs(cli_connected));
  task_runner_.RunUntilCheckpoint("cli_connected");
  task_runner_.RunUntilCheckpoint("srv_did_shutdown");

  // A send attempt is what makes the client notice the disconnection.
  auto cli_disconnected = task_runner_.CreateCheckpoint("cli_disconnected");
  EXPECT_CALL(event_listener_, OnDisconnect(client.get()))
      .WillOnce(InvokeWithoutArgs(cli_disconnected));
  EXPECT_FALSE(client->Send("whatever", kBlocking));
  task_runner_.RunUntilCheckpoint("cli_disconnected");
}

// Sends one string in each direction between a client and the server-side
// connection, then checks that Send()/Receive() fail gracefully once the
// client has been shut down.
TEST_F(UnixSocketTest, ClientAndServerExchangeData) {
  auto server =
      UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_);
  ASSERT_TRUE(server->is_listening());

  auto client =
      UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_);
  EXPECT_CALL(event_listener_, OnConnect(client.get(), true));
  auto cli_connected = task_runner_.CreateCheckpoint("cli_connected");
  auto srv_disconnected = task_runner_.CreateCheckpoint("srv_disconnected");
  // On accept, arm the expectation for the eventual server-side disconnect.
  auto on_accept = [this, cli_connected, srv_disconnected](
                       UnixSocket*, UnixSocket* accepted) {
    EXPECT_CALL(event_listener_, OnDisconnect(accepted))
        .WillOnce(InvokeWithoutArgs(srv_disconnected));
    cli_connected();
  };
  EXPECT_CALL(event_listener_, OnNewIncomingConnection(server.get(), _))
      .WillOnce(Invoke(on_accept));
  task_runner_.RunUntilCheckpoint("cli_connected");

  auto srv_conn = event_listener_.GetIncomingConnection();
  ASSERT_TRUE(srv_conn);
  ASSERT_TRUE(client->is_connected());

  // Client expects the server's message...
  auto cli_did_recv = task_runner_.CreateCheckpoint("cli_did_recv");
  auto check_srv_msg = [cli_did_recv](UnixSocket* sock) {
    ASSERT_EQ("srv>cli", sock->ReceiveString());
    cli_did_recv();
  };
  EXPECT_CALL(event_listener_, OnDataAvailable(client.get()))
      .WillOnce(Invoke(check_srv_msg));

  // ...and the server-side connection expects the client's one.
  auto srv_did_recv = task_runner_.CreateCheckpoint("srv_did_recv");
  auto check_cli_msg = [srv_did_recv](UnixSocket* sock) {
    ASSERT_EQ("cli>srv", sock->ReceiveString());
    srv_did_recv();
  };
  EXPECT_CALL(event_listener_, OnDataAvailable(srv_conn.get()))
      .WillOnce(Invoke(check_cli_msg));

  ASSERT_TRUE(client->Send("cli>srv", kBlocking));
  ASSERT_TRUE(srv_conn->Send("srv>cli", kBlocking));
  task_runner_.RunUntilCheckpoint("cli_did_recv");
  task_runner_.RunUntilCheckpoint("srv_did_recv");

  // Once the socket is closed, Send()/Receive() must fail gracefully on both
  // ends.
  auto cli_disconnected = task_runner_.CreateCheckpoint("cli_disconnected");
  EXPECT_CALL(event_listener_, OnDisconnect(client.get()))
      .WillOnce(InvokeWithoutArgs(cli_disconnected));
  client->Shutdown(true);

  char scratch[4];
  ASSERT_EQ(0u, client->Receive(&scratch, sizeof(scratch)));
  ASSERT_EQ("", client->ReceiveString());
  ASSERT_EQ(0u, srv_conn->Receive(&scratch, sizeof(scratch)));
  ASSERT_EQ("", srv_conn->ReceiveString());
  ASSERT_FALSE(client->Send("foo", kBlocking));
  ASSERT_FALSE(srv_conn->Send("bar", kBlocking));

  server->Shutdown(true);
  task_runner_.RunUntilCheckpoint("cli_disconnected");
  task_runner_.RunUntilCheckpoint("srv_disconnected");
}

// Payloads used by ClientAndServerExchangeFDs: cli_str is sent by the client
// and srv_str by the server-side connection. Both are sent with sizeof(), i.e.
// including the trailing NUL.
constexpr char cli_str[] = "cli>srv";
constexpr char srv_str[] = "srv>cli";

// Exchanges a message plus two file descriptors (/dev/null and /dev/zero) in
// each direction and checks, on the receiving side, that exactly two valid
// descriptors arrive and that they behave like the files they refer to.
TEST_F(UnixSocketTest, ClientAndServerExchangeFDs) {
  auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_);
  ASSERT_TRUE(srv->is_listening());

  auto cli = UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_);
  EXPECT_CALL(event_listener_, OnConnect(cli.get(), true));
  auto cli_connected = task_runner_.CreateCheckpoint("cli_connected");
  auto srv_disconnected = task_runner_.CreateCheckpoint("srv_disconnected");
  EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
      .WillOnce(Invoke([this, cli_connected, srv_disconnected](
                           UnixSocket*, UnixSocket* srv_conn) {
        EXPECT_CALL(event_listener_, OnDisconnect(srv_conn))
            .WillOnce(InvokeWithoutArgs(srv_disconnected));
        cli_connected();
      }));
  task_runner_.RunUntilCheckpoint("cli_connected");

  auto srv_conn = event_listener_.GetIncomingConnection();
  ASSERT_TRUE(srv_conn);
  ASSERT_TRUE(cli->is_connected());

  // Fail early, with a clear message, if the files to be sent cannot be
  // opened.
  ScopedFile null_fd(base::OpenFile("/dev/null", O_RDONLY));
  ASSERT_TRUE(null_fd);
  ScopedFile zero_fd(base::OpenFile("/dev/zero", O_RDONLY));
  ASSERT_TRUE(zero_fd);

  // The client receives srv_str, so its buffer is sized after srv_str.
  auto cli_did_recv = task_runner_.CreateCheckpoint("cli_did_recv");
  EXPECT_CALL(event_listener_, OnDataAvailable(cli.get()))
      .WillRepeatedly(Invoke([cli_did_recv](UnixSocket* s) {
        // One slot more than the number of fds sent, to check that no extra
        // descriptor shows up.
        ScopedFile fd_buf[3];
        char buf[sizeof(srv_str)];
        if (!s->Receive(buf, sizeof(buf), fd_buf, base::ArraySize(fd_buf)))
          return;
        ASSERT_STREQ(srv_str, buf);
        ASSERT_NE(*fd_buf[0], -1);
        ASSERT_NE(*fd_buf[1], -1);
        ASSERT_EQ(*fd_buf[2], -1);

        char rd_buf[1];
        // /dev/null
        ASSERT_EQ(read(*fd_buf[0], rd_buf, sizeof(rd_buf)), 0);
        // /dev/zero
        ASSERT_EQ(read(*fd_buf[1], rd_buf, sizeof(rd_buf)), 1);
        cli_did_recv();
      }));

  // The server-side connection receives cli_str, so its buffer is sized after
  // cli_str.
  auto srv_did_recv = task_runner_.CreateCheckpoint("srv_did_recv");
  EXPECT_CALL(event_listener_, OnDataAvailable(srv_conn.get()))
      .WillRepeatedly(Invoke([srv_did_recv](UnixSocket* s) {
        ScopedFile fd_buf[3];
        char buf[sizeof(cli_str)];
        if (!s->Receive(buf, sizeof(buf), fd_buf, base::ArraySize(fd_buf)))
          return;
        ASSERT_STREQ(cli_str, buf);
        ASSERT_NE(*fd_buf[0], -1);
        ASSERT_NE(*fd_buf[1], -1);
        ASSERT_EQ(*fd_buf[2], -1);

        char rd_buf[1];
        // /dev/null
        ASSERT_EQ(read(*fd_buf[0], rd_buf, sizeof(rd_buf)), 0);
        // /dev/zero
        ASSERT_EQ(read(*fd_buf[1], rd_buf, sizeof(rd_buf)), 1);
        srv_did_recv();
      }));

  int buf_fd[2] = {null_fd.get(), zero_fd.get()};

  ASSERT_TRUE(cli->Send(cli_str, sizeof(cli_str), buf_fd,
                        base::ArraySize(buf_fd), kBlocking));
  ASSERT_TRUE(srv_conn->Send(srv_str, sizeof(srv_str), buf_fd,
                             base::ArraySize(buf_fd), kBlocking));
  task_runner_.RunUntilCheckpoint("srv_did_recv");
  task_runner_.RunUntilCheckpoint("cli_did_recv");

  auto cli_disconnected = task_runner_.CreateCheckpoint("cli_disconnected");
  EXPECT_CALL(event_listener_, OnDisconnect(cli.get()))
      .WillOnce(InvokeWithoutArgs(cli_disconnected));
  cli->Shutdown(true);
  srv->Shutdown(true);
  task_runner_.RunUntilCheckpoint("srv_disconnected");
  task_runner_.RunUntilCheckpoint("cli_disconnected");
}

// Checks that UnixSocket::Listen() works when handed an already-bound file
// descriptor instead of a socket name.
TEST_F(UnixSocketTest, ListenWithPassedFileDescriptor) {
  auto sock_raw = UnixSocketRaw::CreateMayFail(SockType::kStream);
  // CreateMayFail() can fail: stop here rather than binding an invalid
  // socket.
  ASSERT_TRUE(sock_raw);
  ASSERT_TRUE(sock_raw.Bind(kSocketName));
  auto fd = sock_raw.ReleaseFd();
  auto srv = UnixSocket::Listen(std::move(fd), &event_listener_, &task_runner_);
  ASSERT_TRUE(srv->is_listening());

  auto cli = UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_);
  EXPECT_CALL(event_listener_, OnConnect(cli.get(), true));
  auto cli_connected = task_runner_.CreateCheckpoint("cli_connected");
  auto srv_disconnected = task_runner_.CreateCheckpoint("srv_disconnected");
  EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
      .WillOnce(Invoke([this, cli_connected, srv_disconnected](
                           UnixSocket*, UnixSocket* srv_conn) {
        // Read the EOF state.
        EXPECT_CALL(event_listener_, OnDataAvailable(srv_conn))
            .WillOnce(
                InvokeWithoutArgs([srv_conn] { srv_conn->ReceiveString(); }));
        EXPECT_CALL(event_listener_, OnDisconnect(srv_conn))
            .WillOnce(InvokeWithoutArgs(srv_disconnected));
        cli_connected();
      }));
  task_runner_.RunUntilCheckpoint("cli_connected");
  ASSERT_TRUE(cli->is_connected());
  // Destroying the client is what triggers the server-side disconnect.
  cli.reset();
  task_runner_.RunUntilCheckpoint("srv_disconnected");
}

// Mostly a stress test. Connects kNumClients clients to the same server and
// tests that all can exchange data and can see the expected sequence of events.
TEST_F(UnixSocketTest, SeveralClients) {
  auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_);
  ASSERT_TRUE(srv->is_listening());
  constexpr size_t kNumClients = 32;
  std::unique_ptr<UnixSocket> cli[kNumClients];

  // Server side: every accepted connection answers a single PING with a PONG.
  EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
      .Times(kNumClients)
      .WillRepeatedly(Invoke([this](UnixSocket*, UnixSocket* s) {
        EXPECT_CALL(event_listener_, OnDataAvailable(s))
            .WillOnce(Invoke([](UnixSocket* t) {
              ASSERT_EQ("PING", t->ReceiveString());
              ASSERT_TRUE(t->Send("PONG", kBlocking));
            }));
      }));

  // Client side: send PING on connect and hit the per-client checkpoint (named
  // after the client index) when the PONG arrives.
  for (size_t i = 0; i < kNumClients; i++) {
    cli[i] = UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_);
    EXPECT_CALL(event_listener_, OnConnect(cli[i].get(), true))
        .WillOnce(Invoke([](UnixSocket* s, bool success) {
          ASSERT_TRUE(success);
          ASSERT_TRUE(s->Send("PING", kBlocking));
        }));

    auto checkpoint = task_runner_.CreateCheckpoint(std::to_string(i));
    EXPECT_CALL(event_listener_, OnDataAvailable(cli[i].get()))
        .WillOnce(Invoke([checkpoint](UnixSocket* s) {
          ASSERT_EQ("PONG", s->ReceiveString());
          checkpoint();
        }));
  }

  for (size_t i = 0; i < kNumClients; i++)
    task_runner_.RunUntilCheckpoint(std::to_string(i));

  // The expectations live on the mock listener, not on the sockets, so the
  // listener is what has to be verified. This happens only once every client
  // is done, because verifying also clears the expectations of the others.
  ASSERT_TRUE(Mock::VerifyAndClearExpectations(&event_listener_));
}

// Creates two processes. The server process creates a file and passes it over
// the socket to the client. Both processes mmap the file in shared mode and
// check that they see the same contents.
//
// The forked child acts as the server, the parent as the client. A pipe is
// used to tell the parent when the server is listening.
// NOTE(review): a failing ASSERT_* in the child returns from the test body
// without reaching _exit().
TEST_F(UnixSocketTest, SharedMemory) {
  Pipe pipe = Pipe::Create();
  pid_t pid = fork();
  ASSERT_GE(pid, 0);
  constexpr size_t kTmpSize = 4096;

  if (pid == 0) {
    // Child process.
    TempFile scoped_tmp = TempFile::CreateUnlinked();
    int tmp_fd = scoped_tmp.fd();
    ASSERT_FALSE(ftruncate(tmp_fd, kTmpSize));
    // mmap() reports failure with MAP_FAILED, not nullptr. The result is kept
    // as void* for the check so that gtest does not try to print it as a C
    // string on failure.
    void* mapping =
        mmap(nullptr, kTmpSize, PROT_READ | PROT_WRITE, MAP_SHARED, tmp_fd, 0);
    ASSERT_NE(MAP_FAILED, mapping);
    char* mem = static_cast<char*>(mapping);
    memcpy(mem, "shm rocks", 10);

    auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_);
    ASSERT_TRUE(srv->is_listening());
    // Signal the other process that it can connect.
    ASSERT_EQ(1, base::WriteAll(*pipe.wr, ".", 1));
    auto checkpoint = task_runner_.CreateCheckpoint("change_seen_by_server");
    EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
        .WillOnce(Invoke(
            [this, tmp_fd, checkpoint, mem](UnixSocket*, UnixSocket* new_conn) {
              ASSERT_EQ(geteuid(), static_cast<uint32_t>(new_conn->peer_uid()));
              ASSERT_TRUE(new_conn->Send("txfd", 5, tmp_fd, kBlocking));
              // Wait for the client to change this again.
              EXPECT_CALL(event_listener_, OnDataAvailable(new_conn))
                  .WillOnce(Invoke([checkpoint, mem](UnixSocket* s) {
                    ASSERT_EQ("change notify", s->ReceiveString());
                    ASSERT_STREQ("rock more", mem);
                    checkpoint();
                  }));
            }));
    task_runner_.RunUntilCheckpoint("change_seen_by_server");
    ASSERT_TRUE(Mock::VerifyAndClearExpectations(&event_listener_));
    _exit(0);
  } else {
    // Parent process: wait for the child to be listening before connecting.
    char sync_cmd = '\0';
    ASSERT_EQ(1, PERFETTO_EINTR(read(*pipe.rd, &sync_cmd, 1)));
    ASSERT_EQ('.', sync_cmd);
    auto cli =
        UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_);
    EXPECT_CALL(event_listener_, OnConnect(cli.get(), true));
    auto checkpoint = task_runner_.CreateCheckpoint("change_seen_by_client");
    EXPECT_CALL(event_listener_, OnDataAvailable(cli.get()))
        .WillOnce(Invoke([checkpoint](UnixSocket* s) {
          char msg[32];
          ScopedFile fd;
          ASSERT_EQ(5u, s->Receive(msg, sizeof(msg), &fd));
          ASSERT_STREQ("txfd", msg);
          ASSERT_TRUE(fd);
          void* mapping = mmap(nullptr, kTmpSize, PROT_READ | PROT_WRITE,
                               MAP_SHARED, *fd, 0);
          ASSERT_NE(MAP_FAILED, mapping);
          char* mem = static_cast<char*>(mapping);
          mem[9] = '\0';  // Just to get a clean error in case of test failure.
          ASSERT_STREQ("shm rocks", mem);

          // Now change the shared memory and ping the other process.
          memcpy(mem, "rock more", 10);
          ASSERT_TRUE(s->Send("change notify", kBlocking));
          checkpoint();
        }));
    task_runner_.RunUntilCheckpoint("change_seen_by_client");
    int st = 0;
    // If waitpid() itself fails, |st| stays 0 and would look like a clean
    // exit, so its return value has to be checked as well.
    ASSERT_EQ(pid, PERFETTO_EINTR(waitpid(pid, &st, 0)));
    ASSERT_FALSE(WIFSIGNALED(st)) << "Server died with signal " << WTERMSIG(st);
    EXPECT_TRUE(WIFEXITED(st));
    ASSERT_EQ(0, WEXITSTATUS(st));
  }
}

// Checks that the peer_uid() is retained after the client disconnects. The IPC
// layer needs to rely on this to validate messages received immediately before
// a client disconnects.
TEST_F(UnixSocketTest, PeerCredentialsRetainedAfterDisconnect) {
  auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_);
  ASSERT_TRUE(srv->is_listening());
  // Raw pointer to the accepted connection. The object itself stays owned by
  // |event_listener_|, as GetIncomingConnection() is never called here.
  UnixSocket* srv_client_conn = nullptr;
  auto srv_connected = task_runner_.CreateCheckpoint("srv_connected");
  EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
      .WillOnce(Invoke(
          [&srv_client_conn, srv_connected](UnixSocket*, UnixSocket* srv_conn) {
            srv_client_conn = srv_conn;
            EXPECT_EQ(geteuid(), static_cast<uint32_t>(srv_conn->peer_uid()));
#if PERFETTO_BUILDFLAG(PERFETTO_OS_LINUX) || \
    PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID)
            // getpid() returns a signed pid_t: compare like with like.
            EXPECT_EQ(getpid(), static_cast<pid_t>(srv_conn->peer_pid()));
#endif
            srv_connected();
          }));
  auto cli_connected = task_runner_.CreateCheckpoint("cli_connected");
  auto cli = UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_);
  EXPECT_CALL(event_listener_, OnConnect(cli.get(), true))
      .WillOnce(InvokeWithoutArgs(cli_connected));

  task_runner_.RunUntilCheckpoint("cli_connected");
  task_runner_.RunUntilCheckpoint("srv_connected");
  ASSERT_NE(nullptr, srv_client_conn);
  ASSERT_TRUE(srv_client_conn->is_connected());

  auto cli_disconnected = task_runner_.CreateCheckpoint("cli_disconnected");
  EXPECT_CALL(event_listener_, OnDisconnect(srv_client_conn))
      .WillOnce(InvokeWithoutArgs(cli_disconnected));

  // TODO(primiano): when a peer disconnects, the other end receives a
  // spurious OnDataAvailable() that needs to be acked with a Receive() to read
  // the EOF. See b/69536434.
  EXPECT_CALL(event_listener_, OnDataAvailable(srv_client_conn))
      .WillOnce(Invoke([](UnixSocket* sock) { sock->ReceiveString(); }));

  // Destroy the client and check that the credentials recorded for it are
  // still reported once the server-side connection has seen the disconnect.
  cli.reset();
  task_runner_.RunUntilCheckpoint("cli_disconnected");
  ASSERT_FALSE(srv_client_conn->is_connected());
  EXPECT_EQ(geteuid(), static_cast<uint32_t>(srv_client_conn->peer_uid()));
#if PERFETTO_BUILDFLAG(PERFETTO_OS_LINUX) || \
    PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID)
  EXPECT_EQ(getpid(), static_cast<pid_t>(srv_client_conn->peer_pid()));
#endif
}

// Sends kTotalBytes with blocking Send() calls from a client that lives on its
// own thread (with its own task runner and listener), while the server side
// drains the data on the main thread until every byte has been received.
TEST_F(UnixSocketTest, BlockingSend) {
  auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_);
  ASSERT_TRUE(srv->is_listening());

  auto all_frames_done = task_runner_.CreateCheckpoint("all_frames_done");
  size_t total_bytes_received = 0;
  constexpr size_t kTotalBytes = 1024 * 1024 * 4;
  EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
      .WillOnce(Invoke([this, &total_bytes_received, all_frames_done](
                           UnixSocket*, UnixSocket* srv_conn) {
        EXPECT_CALL(event_listener_, OnDataAvailable(srv_conn))
            .WillRepeatedly(
                Invoke([&total_bytes_received, all_frames_done](UnixSocket* s) {
                  char buf[1024];
                  size_t res = s->Receive(buf, sizeof(buf));
                  total_bytes_received += res;
                  if (total_bytes_received == kTotalBytes)
                    all_frames_done();
                }));
      }));

  // Override default timeout as this test can take time on the emulator.
  const int kTimeoutMs = 60000 * 3;

  // Perform the blocking send from another thread.
  std::thread tx_thread([] {
    TestTaskRunner tx_task_runner;
    MockEventListener tx_events;
    auto cli = UnixSocket::Connect(kSocketName, &tx_events, &tx_task_runner);

    auto cli_connected = tx_task_runner.CreateCheckpoint("cli_connected");
    EXPECT_CALL(tx_events, OnConnect(cli.get(), true))
        .WillOnce(InvokeWithoutArgs(cli_connected));
    tx_task_runner.RunUntilCheckpoint("cli_connected");

    auto all_sent = tx_task_runner.CreateCheckpoint("all_sent");
    char buf[1024 * 32] = {};
    tx_task_runner.PostTask([&cli, &buf, all_sent] {
      for (size_t i = 0; i < kTotalBytes / sizeof(buf); i++) {
        // Report a failed send right away instead of only surfacing it as a
        // timeout on the receiving side. EXPECT (not ASSERT) so that all_sent()
        // is always reached.
        EXPECT_TRUE(cli->Send(buf, sizeof(buf), -1 /*fd*/, kBlocking));
      }
      all_sent();
    });
    tx_task_runner.RunUntilCheckpoint("all_sent", kTimeoutMs);
  });

  task_runner_.RunUntilCheckpoint("all_frames_done", kTimeoutMs);
  tx_thread.join();
}

// Regression test for b/76155349 . If the receiver end disconnects while the
// sender is in the middle of a large send(), the socket should gracefully give
// up (i.e. Shutdown()) but not crash.
TEST_F(UnixSocketTest, ReceiverDisconnectsDuringSend) {
  auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_);
  ASSERT_TRUE(srv->is_listening());
  const int kTimeoutMs = 30000;

  // Receiver: read the first 1024 bytes, then shut the connection down
  // (without notification) while the sender still has data to push.
  auto receive_done = task_runner_.CreateCheckpoint("receive_done");
  EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
      .WillOnce(Invoke([this, receive_done](UnixSocket*, UnixSocket* srv_conn) {
        EXPECT_CALL(event_listener_, OnDataAvailable(srv_conn))
            .WillOnce(Invoke([receive_done](UnixSocket* s) {
              char buf[1024];
              size_t res = s->Receive(buf, sizeof(buf));
              ASSERT_EQ(1024u, res);
              s->Shutdown(false /*notify*/);
              receive_done();
            }));
      }));

  // Perform the blocking send from another thread.
  std::thread tx_thread([] {
    // The sender has its own task runner and listener, local to this thread.
    TestTaskRunner tx_task_runner;
    MockEventListener tx_events;
    auto cli = UnixSocket::Connect(kSocketName, &tx_events, &tx_task_runner);

    auto cli_connected = tx_task_runner.CreateCheckpoint("cli_connected");
    EXPECT_CALL(tx_events, OnConnect(cli.get(), true))
        .WillOnce(InvokeWithoutArgs(cli_connected));
    tx_task_runner.RunUntilCheckpoint("cli_connected");

    auto send_done = tx_task_runner.CreateCheckpoint("send_done");
    // We need a buffer large enough that the send is still in progress when
    // the receiver shuts down (presumably well above what the kernel socket
    // buffers can hold).
    static constexpr size_t kBufSize = 32 * 1024 * 1024;
    std::unique_ptr<char[]> buf(new char[kBufSize]());
    tx_task_runner.PostTask([&cli, &buf, send_done] {
      // The send is expected to fail, since the receiver goes away mid-way.
      bool send_res = cli->Send(buf.get(), kBufSize, -1 /*fd*/, kBlocking);
      ASSERT_FALSE(send_res);
      send_done();
    });

    tx_task_runner.RunUntilCheckpoint("send_done", kTimeoutMs);
  });
  task_runner_.RunUntilCheckpoint("receive_done", kTimeoutMs);
  tx_thread.join();
}

// ShiftMsgHdr(): consume one byte of the first iovec, then the remainder of
// the first iovec, then the whole second iovec.
TEST_F(UnixSocketTest, ShiftMsgHdrSendPartialFirst) {
  char hello[] = "hello";
  char world[] = "world";
  // Both lengths include the trailing NUL.
  const size_t hello_len = base::ArraySize(hello);
  const size_t world_len = base::ArraySize(world);

  struct iovec iov[2] = {};
  iov[0].iov_base = hello;
  iov[0].iov_len = hello_len;
  iov[1].iov_base = world;
  iov[1].iov_len = world_len;

  struct msghdr hdr = {};
  hdr.msg_iov = iov;
  hdr.msg_iovlen = base::ArraySize(iov);

  // One byte consumed: still two iovecs, the first one starting at "ello".
  UnixSocketRaw::ShiftMsgHdr(1, &hdr);
  EXPECT_NE(hdr.msg_iov, nullptr);
  EXPECT_EQ(hdr.msg_iov[0].iov_base, &hello[1]);
  EXPECT_EQ(hdr.msg_iov[1].iov_base, &world[0]);
  EXPECT_EQ(hdr.msg_iovlen, 2);
  EXPECT_STREQ(static_cast<char*>(hdr.msg_iov[0].iov_base), "ello");
  EXPECT_EQ(iov[0].iov_len, hello_len - 1);

  // Rest of the first iovec consumed: only the second one is left.
  UnixSocketRaw::ShiftMsgHdr(hello_len - 1, &hdr);
  EXPECT_EQ(hdr.msg_iov, &iov[1]);
  EXPECT_EQ(hdr.msg_iovlen, 1);
  EXPECT_STREQ(static_cast<char*>(hdr.msg_iov[0].iov_base), world);
  EXPECT_EQ(hdr.msg_iov[0].iov_len, world_len);

  // Everything consumed: the header ends up with no iovecs at all.
  UnixSocketRaw::ShiftMsgHdr(world_len, &hdr);
  EXPECT_EQ(hdr.msg_iov, nullptr);
  EXPECT_EQ(hdr.msg_iovlen, 0);
}

// ShiftMsgHdr(): consume the whole first iovec plus one byte of the second in
// a single step, then the remainder.
TEST_F(UnixSocketTest, ShiftMsgHdrSendFirstAndPartial) {
  char hello[] = "hello";
  char world[] = "world";
  // Both lengths include the trailing NUL.
  const size_t hello_len = base::ArraySize(hello);
  const size_t world_len = base::ArraySize(world);

  struct iovec iov[2] = {};
  iov[0].iov_base = hello;
  iov[0].iov_len = hello_len;
  iov[1].iov_base = world;
  iov[1].iov_len = world_len;

  struct msghdr hdr = {};
  hdr.msg_iov = iov;
  hdr.msg_iovlen = base::ArraySize(iov);

  // First iovec gone, second one starting at "orld".
  UnixSocketRaw::ShiftMsgHdr(hello_len + 1, &hdr);
  EXPECT_NE(hdr.msg_iov, nullptr);
  EXPECT_EQ(hdr.msg_iovlen, 1);
  EXPECT_STREQ(static_cast<char*>(hdr.msg_iov[0].iov_base), "orld");
  EXPECT_EQ(hdr.msg_iov[0].iov_len, world_len - 1);

  // Everything consumed: the header ends up with no iovecs at all.
  UnixSocketRaw::ShiftMsgHdr(world_len - 1, &hdr);
  EXPECT_EQ(hdr.msg_iov, nullptr);
  EXPECT_EQ(hdr.msg_iovlen, 0);
}

// ShiftMsgHdr(): consuming both iovecs in a single step must leave the header
// with no iovecs at all.
TEST_F(UnixSocketTest, ShiftMsgHdrSendEverything) {
  char hello[] = "hello";
  char world[] = "world";
  // Both lengths include the trailing NUL.
  const size_t hello_len = base::ArraySize(hello);
  const size_t world_len = base::ArraySize(world);

  struct iovec iov[2] = {};
  iov[0].iov_base = hello;
  iov[0].iov_len = hello_len;
  iov[1].iov_base = world;
  iov[1].iov_len = world_len;

  struct msghdr hdr = {};
  hdr.msg_iov = iov;
  hdr.msg_iovlen = base::ArraySize(iov);

  UnixSocketRaw::ShiftMsgHdr(world_len + hello_len, &hdr);
  EXPECT_EQ(hdr.msg_iov, nullptr);
  EXPECT_EQ(hdr.msg_iovlen, 0);
}

// Signal handler that intentionally does nothing. PartialSendMsgAll installs
// it for SIGWINCH so that the signal interrupts a blocking call.
void Handler(int /*signum*/) {
}

// Installs |act| as the SIGWINCH disposition and forwards sigaction()'s return
// value. Used as the cleanup function of a ScopedResource that restores the
// previous disposition.
int RollbackSigaction(const struct sigaction* act) {
  const int result = sigaction(SIGWINCH, act, /*oldact=*/nullptr);
  return result;
}

// Checks that SendMsgAll() delivers the whole payload even when the underlying
// send is interrupted by a signal half-way through. A helper thread reads one
// byte (which proves the sender is inside the send), signals the sending
// thread with SIGWINCH and then drains the rest of the data.
TEST_F(UnixSocketTest, PartialSendMsgAll) {
  UnixSocketRaw send_sock;
  UnixSocketRaw recv_sock;
  std::tie(send_sock, recv_sock) = UnixSocketRaw::CreatePair(SockType::kStream);
  ASSERT_TRUE(send_sock);
  ASSERT_TRUE(recv_sock);

  // Set bufsize to minimum.
  int bufsize = 1024;
  ASSERT_EQ(setsockopt(send_sock.fd(), SOL_SOCKET, SO_SNDBUF, &bufsize,
                       sizeof(bufsize)),
            0);
  ASSERT_EQ(setsockopt(recv_sock.fd(), SOL_SOCKET, SO_RCVBUF, &bufsize,
                       sizeof(bufsize)),
            0);

  // Send something larger than send + recv kernel buffers combined to make
  // sendmsg block.
  char send_buf[8192];
  // Make MSAN happy.
  for (size_t i = 0; i < sizeof(send_buf); ++i)
    send_buf[i] = static_cast<char>(i % 256);
  char recv_buf[sizeof(send_buf)];

  // Need to install signal handler to cause the interrupt to happen.
  // man 3 pthread_kill:
  //   Signal dispositions are process-wide: if a signal handler is
  //   installed, the handler will be invoked in the thread thread, but if
  //   the disposition of the signal is "stop", "continue", or "terminate",
  //   this action will affect the whole process.
  struct sigaction oldact;
  struct sigaction newact = {};
  newact.sa_handler = Handler;
  ASSERT_EQ(sigaction(SIGWINCH, &newact, &oldact), 0);
  // Restores the previous SIGWINCH disposition when the test returns.
  base::ScopedResource<const struct sigaction*, RollbackSigaction, nullptr>
      rollback(&oldact);

  auto blocked_thread = pthread_self();
  std::thread th([blocked_thread, &recv_sock, &recv_buf] {
    ssize_t rd = PERFETTO_EINTR(read(recv_sock.fd(), recv_buf, 1));
    ASSERT_EQ(rd, 1);
    // We are now sure the other thread is in sendmsg, interrupt send.
    ASSERT_EQ(pthread_kill(blocked_thread, SIGWINCH), 0);
    // Drain the socket to allow SendMsgAll to succeed.
    size_t offset = 1;
    while (offset < sizeof(recv_buf)) {
      rd = PERFETTO_EINTR(
          read(recv_sock.fd(), recv_buf + offset, sizeof(recv_buf) - offset));
      // A return value of 0 means EOF before the full payload arrived.
      // Accepting it would make this loop spin forever.
      ASSERT_GT(rd, 0);
      offset += static_cast<size_t>(rd);
    }
  });

  // Test sending the send_buf in several chunks as an iov to exercise the
  // more complicated code-paths of SendMsgAll.
  struct msghdr hdr = {};
  struct iovec iov[4];
  static_assert(sizeof(send_buf) % base::ArraySize(iov) == 0,
                "Cannot split buffer into even pieces.");
  constexpr size_t kChunkSize = sizeof(send_buf) / base::ArraySize(iov);
  for (size_t i = 0; i < base::ArraySize(iov); ++i) {
    iov[i].iov_base = send_buf + i * kChunkSize;
    iov[i].iov_len = kChunkSize;
  }
  hdr.msg_iov = iov;
  hdr.msg_iovlen = base::ArraySize(iov);

  // EXPECT rather than ASSERT: the reader thread must be joined before this
  // function can return, otherwise destroying the still-joinable std::thread
  // terminates the process.
  EXPECT_EQ(send_sock.SendMsgAll(&hdr), sizeof(send_buf));
  send_sock.Shutdown();
  th.join();
  // Make sure the re-entry logic was actually triggered.
  ASSERT_EQ(hdr.msg_iov, nullptr);
  ASSERT_EQ(memcmp(send_buf, recv_buf, sizeof(send_buf)), 0);
}

// Once ReleaseSocket() has been called on the accepted connection, no
// OnDataAvailable() may be delivered any more, while the returned raw socket
// must still be able to read what the client sent.
TEST_F(UnixSocketTest, ReleaseSocket) {
  auto server =
      UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_);
  ASSERT_TRUE(server->is_listening());

  auto on_accepted = task_runner_.CreateCheckpoint("connected");
  UnixSocket* accepted = nullptr;
  // Remember the server-side end of the connection for later use.
  auto remember_conn = [on_accepted, &accepted](UnixSocket*,
                                                UnixSocket* new_conn) {
    accepted = new_conn;
    on_accepted();
  };
  EXPECT_CALL(event_listener_, OnNewIncomingConnection(server.get(), _))
      .WillOnce(Invoke(remember_conn));

  auto client =
      UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_);
  EXPECT_CALL(event_listener_, OnConnect(client.get(), true));
  task_runner_.RunUntilCheckpoint("connected");
  server->Shutdown(true);

  client->Send("test", kBlocking);

  ASSERT_NE(accepted, nullptr);
  auto released = accepted->ReleaseSocket();

  // From here on the listener must stay silent about incoming data.
  EXPECT_CALL(event_listener_, OnDataAvailable(_)).Times(0);
  task_runner_.RunUntilIdle();

  char received[sizeof("test")];
  ASSERT_TRUE(released);
  ASSERT_EQ(released.Receive(received, sizeof(received)), sizeof(received));
  ASSERT_STREQ(received, "test");
}

// TODO(primiano): add a test to check that in the case of a peer sending a fd
// and the other end just doing a recv (without taking it), the fd is closed and
// not left around.

// TODO(primiano): add a test to check that a socket can be reused after
// Shutdown().

// TODO(primiano): add a test to check that OnDisconnect() is called in all
// possible cases.

// TODO(primiano): add tests that destroy the socket in all possible stages and
// verify that no spurious EventListener callback is received.

}  // namespace
}  // namespace base
}  // namespace perfetto