// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <sys/socket.h>
#include "base/bind.h"
#include "base/files/file_path.h"
#include "base/path_service.h"
#include "base/posix/eintr_wrapper.h"
#include "base/synchronization/waitable_event.h"
#include "base/threading/thread.h"
#include "base/threading/thread_restrictions.h"
#include "ipc/unix_domain_socket_util.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace {
// Watches a listening socket on |target_thread| and accepts a single incoming
// connection from it. The descriptor watch is created and torn down on the
// target thread (which must run a MessageLoopForIO), while WaitUntilReady()
// and WaitForAccept() block the calling thread until the corresponding step
// has happened there.
class SocketAcceptor : public base::MessageLoopForIO::Watcher {
 public:
  // Posts a task to |target_thread| that starts watching |fd| for
  // readability. Call WaitUntilReady() before relying on the watch being
  // active.
  SocketAcceptor(int fd, base::MessageLoopProxy* target_thread)
      : server_fd_(-1),
        target_thread_(target_thread),
        started_watching_event_(false, false),
        accepted_event_(false, false) {
    target_thread->PostTask(FROM_HERE,
        base::Bind(&SocketAcceptor::StartWatching, base::Unretained(this), fd));
  }

  virtual ~SocketAcceptor() {
    Close();
  }

  // Descriptor of the accepted connection, or -1 if nothing was accepted.
  // Only meaningful after WaitForAccept() has returned.
  int server_fd() const { return server_fd_; }

  // Blocks until StartWatching() has run on the target thread.
  void WaitUntilReady() {
    started_watching_event_.Wait();
  }

  // Blocks until OnFileCanReadWithoutBlocking() has attempted an accept.
  void WaitForAccept() {
    accepted_event_.Wait();
  }

  // Hands ownership of the watcher to the target thread, where it is stopped
  // and deleted. StopWatching is static so the posted task never touches
  // |this|: Close() is called from the destructor, so |this| is typically
  // gone by the time the task runs.
  void Close() {
    if (watcher_.get()) {
      target_thread_->PostTask(FROM_HERE,
          base::Bind(&SocketAcceptor::StopWatching, watcher_.release()));
    }
  }

 private:
  // Runs on the target thread: registers a persistent read watch on |fd|
  // and then unblocks WaitUntilReady().
  void StartWatching(int fd) {
    watcher_.reset(new base::MessageLoopForIO::FileDescriptorWatcher);
    base::MessageLoopForIO::current()->WatchFileDescriptor(
        fd, true, base::MessageLoopForIO::WATCH_READ, watcher_.get(), this);
    started_watching_event_.Signal();
  }

  // Runs on the target thread: stops and deletes a watcher released by
  // Close(). Takes ownership of |watcher|.
  static void StopWatching(
      base::MessageLoopForIO::FileDescriptorWatcher* watcher) {
    watcher->StopWatchingFileDescriptor();
    delete watcher;
  }

  // Called on the target thread when the listening socket is readable.
  // Accepts one connection, stops watching so no second accept happens, and
  // unblocks WaitForAccept().
  virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE {
    ASSERT_EQ(-1, server_fd_);
    IPC::ServerAcceptConnection(fd, &server_fd_);
    watcher_->StopWatchingFileDescriptor();
    accepted_event_.Signal();
  }

  virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE {}

  int server_fd_;
  base::MessageLoopProxy* target_thread_;
  scoped_ptr<base::MessageLoopForIO::FileDescriptorWatcher> watcher_;
  base::WaitableEvent started_watching_event_;
  base::WaitableEvent accepted_event_;

  DISALLOW_COPY_AND_ASSIGN(SocketAcceptor);
};
// Returns the directory in which the test socket file is created: the
// DIR_CACHE directory from PathService on Android, /var/tmp elsewhere.
const base::FilePath GetChannelDir() {
  base::FilePath channel_dir;
#if defined(OS_ANDROID)
  PathService::Get(base::DIR_CACHE, &channel_dir);
#else
  channel_dir = base::FilePath("/var/tmp");
#endif
  return channel_dir;
}
class TestUnixSocketConnection {
public:
TestUnixSocketConnection()
: worker_("WorkerThread"),
server_listen_fd_(-1),
server_fd_(-1),
client_fd_(-1) {
socket_name_ = GetChannelDir().Append("TestSocket");
base::Thread::Options options;
options.message_loop_type = base::MessageLoop::TYPE_IO;
worker_.StartWithOptions(options);
}
bool CreateServerSocket() {
IPC::CreateServerUnixDomainSocket(socket_name_, &server_listen_fd_);
if (server_listen_fd_ < 0)
return false;
struct stat socket_stat;
stat(socket_name_.value().c_str(), &socket_stat);
EXPECT_TRUE(S_ISSOCK(socket_stat.st_mode));
acceptor_.reset(new SocketAcceptor(server_listen_fd_,
worker_.message_loop_proxy().get()));
acceptor_->WaitUntilReady();
return true;
}
bool CreateClientSocket() {
DCHECK(server_listen_fd_ >= 0);
IPC::CreateClientUnixDomainSocket(socket_name_, &client_fd_);
if (client_fd_ < 0)
return false;
acceptor_->WaitForAccept();
server_fd_ = acceptor_->server_fd();
return server_fd_ >= 0;
}
virtual ~TestUnixSocketConnection() {
if (client_fd_ >= 0)
close(client_fd_);
if (server_fd_ >= 0)
close(server_fd_);
if (server_listen_fd_ >= 0) {
close(server_listen_fd_);
unlink(socket_name_.value().c_str());
}
}
int client_fd() const { return client_fd_; }
int server_fd() const { return server_fd_; }
private:
base::Thread worker_;
base::FilePath socket_name_;
int server_listen_fd_;
int server_fd_;
int client_fd_;
scoped_ptr<SocketAcceptor> acceptor_;
};
// Verifies that a socket created by IPC::CreateServerUnixDomainSocket accepts
// a connection from IPC::CreateClientUnixDomainSocket.
TEST(UnixDomainSocketUtil, Connect) {
  TestUnixSocketConnection socket_pair;
  // Blocks until the worker thread is watching the listening socket.
  ASSERT_TRUE(socket_pair.CreateServerSocket());
  // Blocks until the worker thread has accepted the client's connection.
  ASSERT_TRUE(socket_pair.CreateClientSocket());
}
// Ensure that a message sent from the client end of the connection arrives
// intact at the server end.
TEST(UnixDomainSocketUtil, SendReceive) {
  TestUnixSocketConnection connection;
  ASSERT_TRUE(connection.CreateServerSocket());
  ASSERT_TRUE(connection.CreateClientSocket());

  const char buffer[] = "Hello, server!";
  // sizeof includes the trailing NUL, so the terminator is round-tripped too.
  const size_t buf_len = sizeof(buffer);

  // send()/recv() return ssize_t. Keep the signed type so an error (-1) is
  // reported as -1 instead of wrapping around to SIZE_MAX.
  const ssize_t sent_bytes =
      HANDLE_EINTR(send(connection.client_fd(), buffer, buf_len, 0));
  ASSERT_EQ(static_cast<ssize_t>(buf_len), sent_bytes);

  char recv_buf[sizeof(buffer)];
  const ssize_t received_bytes =
      HANDLE_EINTR(recv(connection.server_fd(), recv_buf, buf_len, 0));
  ASSERT_EQ(static_cast<ssize_t>(buf_len), received_bytes);
  ASSERT_EQ(0, memcmp(recv_buf, buffer, buf_len));
}
} // namespace