// Copyright 2016 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "mojo/edk/system/channel.h"
#include <stdint.h>
#include <windows.h>
#include <algorithm>
#include <deque>
#include <limits>
#include <memory>
#include "base/bind.h"
#include "base/location.h"
#include "base/macros.h"
#include "base/memory/ref_counted.h"
#include "base/message_loop/message_loop.h"
#include "base/synchronization/lock.h"
#include "base/task_runner.h"
#include "base/win/win_util.h"
#include "mojo/edk/embedder/platform_handle_vector.h"
namespace mojo {
namespace edk {
namespace {
// A view over a Channel::Message object. The write queue uses these since
// large messages may need to be sent in chunks.
class MessageView {
public:
// Owns |message|. |offset| indexes the first unsent byte in the message.
MessageView(Channel::MessagePtr message, size_t offset)
: message_(std::move(message)),
offset_(offset) {
DCHECK_GT(message_->data_num_bytes(), offset_);
}
MessageView(MessageView&& other) { *this = std::move(other); }
MessageView& operator=(MessageView&& other) {
message_ = std::move(other.message_);
offset_ = other.offset_;
return *this;
}
~MessageView() {}
const void* data() const {
return static_cast<const char*>(message_->data()) + offset_;
}
size_t data_num_bytes() const { return message_->data_num_bytes() - offset_; }
size_t data_offset() const { return offset_; }
void advance_data_offset(size_t num_bytes) {
DCHECK_GE(message_->data_num_bytes(), offset_ + num_bytes);
offset_ += num_bytes;
}
Channel::MessagePtr TakeChannelMessage() { return std::move(message_); }
private:
Channel::MessagePtr message_;
size_t offset_;
DISALLOW_COPY_AND_ASSIGN(MessageView);
};
// Windows implementation of Channel built on overlapped I/O over |handle_|.
// I/O completions are delivered on the IO thread through
// base::MessageLoopForIO::IOHandler::OnIOCompleted().
//
// Lifetime: |self_| keeps the channel alive until ShutDownOnIOThread() runs.
// Each pending overlapped operation (connect, read or write) additionally
// holds a reference taken via AddRef() that OnIOCompleted() releases.
class ChannelWin : public Channel,
                   public base::MessageLoop::DestructionObserver,
                   public base::MessageLoopForIO::IOHandler {
 public:
  ChannelWin(Delegate* delegate,
             ScopedPlatformHandle handle,
             scoped_refptr<base::TaskRunner> io_task_runner)
      : Channel(delegate),
        self_(this),
        handle_(std::move(handle)),
        io_task_runner_(io_task_runner) {
    CHECK(handle_.is_valid());
    // Handles flagged |needs_connection| must complete ConnectNamedPipe()
    // before reads and writes start (see StartOnIOThread()).
    wait_for_connect_ = handle_.get().needs_connection;
  }

  void Start() override {
    io_task_runner_->PostTask(
        FROM_HERE, base::Bind(&ChannelWin::StartOnIOThread, this));
  }

  void ShutDownImpl() override {
    // Always shut down asynchronously when called through the public interface.
    io_task_runner_->PostTask(
        FROM_HERE, base::Bind(&ChannelWin::ShutDownOnIOThread, this));
  }

  // Appends |message| to the outgoing queue (under |write_lock_|). If the
  // queue was empty and writes are not delayed, a write is issued right away.
  // Messages are silently dropped once |reject_writes_| is set.
  void Write(MessagePtr message) override {
    bool write_error = false;
    {
      base::AutoLock lock(write_lock_);
      if (reject_writes_)
        return;

      bool write_now = !delay_writes_ && outgoing_messages_.empty();
      outgoing_messages_.emplace_back(std::move(message), 0);

      if (write_now && !WriteNoLock(outgoing_messages_.front()))
        reject_writes_ = write_error = true;
    }
    if (write_error) {
      // Do not synchronously invoke OnError(). Write() may have been called by
      // the delegate and we don't want to re-enter it.
      io_task_runner_->PostTask(FROM_HERE,
                                base::Bind(&ChannelWin::OnError, this));
    }
  }

  // Requests that ShutDownOnIOThread() release |handle_| instead of closing
  // it. Must be called on the IO thread.
  void LeakHandle() override {
    DCHECK(io_task_runner_->RunsTasksOnCurrentThread());
    leak_handle_ = true;
  }

  // Decodes |num_handles| HandleEntry records from |extra_header| into
  // |handles|. Returns false if the count exceeds uint16_t or the entries do
  // not fit in |extra_header_size| bytes.
  bool GetReadPlatformHandles(
      size_t num_handles,
      const void* extra_header,
      size_t extra_header_size,
      ScopedPlatformHandleVectorPtr* handles) override {
    if (num_handles > std::numeric_limits<uint16_t>::max())
      return false;
    using HandleEntry = Channel::Message::HandleEntry;
    // Cannot overflow: |num_handles| is bounded by uint16_t above.
    size_t handles_size = sizeof(HandleEntry) * num_handles;
    if (handles_size > extra_header_size)
      return false;
    DCHECK(extra_header);
    handles->reset(new PlatformHandleVector(num_handles));
    const HandleEntry* extra_header_handles =
        reinterpret_cast<const HandleEntry*>(extra_header);
    for (size_t i = 0; i < num_handles; i++) {
      (*handles)->at(i).handle =
          base::win::Uint32ToHandle(extra_header_handles[i].handle);
    }
    return true;
  }

 private:
  // May run on any thread.
  ~ChannelWin() override {}

  // Registers with the IO message loop, completes (or starts waiting for) the
  // pipe connection, then flushes queued writes and issues the first read.
  void StartOnIOThread() {
    base::MessageLoop::current()->AddDestructionObserver(this);
    base::MessageLoopForIO::current()->RegisterIOHandler(
        handle_.get().handle, this);

    if (wait_for_connect_) {
      BOOL ok = ConnectNamedPipe(handle_.get().handle,
                                 &connect_context_.overlapped);
      if (ok) {
        // Overlapped ConnectNamedPipe() is expected to return zero.
        PLOG(ERROR) << "Unexpected success while waiting for pipe connection";
        OnError();
        return;
      }

      const DWORD err = GetLastError();
      switch (err) {
        case ERROR_PIPE_CONNECTED:
          // The client connected before ConnectNamedPipe() was called; carry
          // on as if the connection completed synchronously.
          wait_for_connect_ = false;
          break;
        case ERROR_IO_PENDING:
          // Completion arrives in OnIOCompleted(), which balances this
          // reference and then starts reading and writing.
          AddRef();
          return;
        case ERROR_NO_DATA:
        default:
          // Any other result means no usable connection exists; never fall
          // through to issuing I/O on an unconnected pipe.
          OnError();
          return;
      }
    }

    // Now that we have registered our IOHandler, we can start writing.
    {
      base::AutoLock lock(write_lock_);
      if (delay_writes_) {
        delay_writes_ = false;
        WriteNextNoLock();
      }
    }

    // Keep this alive in case we synchronously run shutdown.
    scoped_refptr<ChannelWin> keep_alive(this);
    ReadMore(0);
  }

  // Stops observing the message loop, cancels pending I/O on |handle_|, and
  // closes (or, after LeakHandle(), releases) the handle.
  void ShutDownOnIOThread() {
    base::MessageLoop::current()->RemoveDestructionObserver(this);

    // BUG(crbug.com/583525): This function is expected to be called once, and
    // |handle_| should be valid at this point.
    CHECK(handle_.is_valid());
    CancelIo(handle_.get().handle);
    if (leak_handle_)
      ignore_result(handle_.release());
    handle_.reset();

    // May destroy the |this| if it was the last reference.
    self_ = nullptr;
  }

  // base::MessageLoop::DestructionObserver:
  void WillDestroyCurrentMessageLoop() override {
    DCHECK(io_task_runner_->RunsTasksOnCurrentThread());
    // |self_| is null if ShutDownOnIOThread() has already run.
    if (self_)
      ShutDownOnIOThread();
  }

  // base::MessageLoopForIO::IOHandler:
  void OnIOCompleted(base::MessageLoopForIO::IOContext* context,
                     DWORD bytes_transfered,
                     DWORD error) override {
    if (error != ERROR_SUCCESS) {
      if (context == &write_context_) {
        // Match Write() and OnWriteDone(): once a write has failed, stop
        // accepting new outgoing messages.
        base::AutoLock lock(write_lock_);
        reject_writes_ = true;
      }
      OnError();
    } else if (context == &connect_context_) {
      DCHECK(wait_for_connect_);
      wait_for_connect_ = false;
      ReadMore(0);

      base::AutoLock lock(write_lock_);
      if (delay_writes_) {
        delay_writes_ = false;
        WriteNextNoLock();
      }
    } else if (context == &read_context_) {
      OnReadDone(static_cast<size_t>(bytes_transfered));
    } else {
      CHECK(context == &write_context_);
      OnWriteDone(static_cast<size_t>(bytes_transfered));
    }
    Release();  // Balancing reference taken after ReadFile / WriteFile.
  }

  // Hands |bytes_read| bytes to Channel::OnReadComplete() and issues the next
  // read. A zero-byte read, or data OnReadComplete() rejects, is an error.
  void OnReadDone(size_t bytes_read) {
    if (bytes_read == 0) {
      OnError();
      return;
    }

    size_t next_read_size = 0;
    if (OnReadComplete(bytes_read, &next_read_size))
      ReadMore(next_read_size);
    else
      OnError();
  }

  // Advances the front queued message by |bytes_written|, drops it once fully
  // sent, and issues the next write (or the remainder of a partial write).
  void OnWriteDone(size_t bytes_written) {
    if (bytes_written == 0)
      return;

    bool write_error = false;
    {
      base::AutoLock lock(write_lock_);

      DCHECK(!outgoing_messages_.empty());

      MessageView& message_view = outgoing_messages_.front();
      message_view.advance_data_offset(bytes_written);
      if (message_view.data_num_bytes() == 0) {
        Channel::MessagePtr message = message_view.TakeChannelMessage();
        outgoing_messages_.pop_front();

        // Clear any handles so they don't get closed on destruction.
        ScopedPlatformHandleVectorPtr handles = message->TakeHandles();
        if (handles)
          handles->clear();
      }

      if (!WriteNextNoLock())
        reject_writes_ = write_error = true;
    }
    if (write_error)
      OnError();
  }

  // Issues an overlapped ReadFile() into the buffer from GetReadBuffer(),
  // sized by |next_read_size_hint|. Reports an error if the read cannot be
  // started.
  void ReadMore(size_t next_read_size_hint) {
    size_t buffer_capacity = next_read_size_hint;
    char* buffer = GetReadBuffer(&buffer_capacity);
    DCHECK_GT(buffer_capacity, 0u);

    BOOL ok = ReadFile(handle_.get().handle,
                       buffer,
                       static_cast<DWORD>(buffer_capacity),
                       nullptr,
                       &read_context_.overlapped);

    if (ok || GetLastError() == ERROR_IO_PENDING) {
      AddRef();  // Will be balanced in OnIOCompleted
    } else {
      OnError();
    }
  }

  // Issues an overlapped WriteFile() for the unsent bytes of |message_view|.
  // The message stays in |outgoing_messages_|; progress is recorded in
  // OnWriteDone() once the write completes. Returns false if the write could
  // not be started. Caller must hold |write_lock_|.
  bool WriteNoLock(const MessageView& message_view) {
    BOOL ok = WriteFile(handle_.get().handle,
                        message_view.data(),
                        static_cast<DWORD>(message_view.data_num_bytes()),
                        nullptr,
                        &write_context_.overlapped);

    if (ok || GetLastError() == ERROR_IO_PENDING) {
      AddRef();  // Will be balanced in OnIOCompleted.
      return true;
    }
    return false;
  }

  // Starts writing the front queued message, if any. Returns false only if a
  // write could not be started. Caller must hold |write_lock_|.
  bool WriteNextNoLock() {
    if (outgoing_messages_.empty())
      return true;
    return WriteNoLock(outgoing_messages_.front());
  }

  // Keeps the Channel alive at least until explicit shutdown on the IO thread.
  scoped_refptr<Channel> self_;

  ScopedPlatformHandle handle_;
  scoped_refptr<base::TaskRunner> io_task_runner_;

  base::MessageLoopForIO::IOContext connect_context_;
  base::MessageLoopForIO::IOContext read_context_;
  base::MessageLoopForIO::IOContext write_context_;

  // Protects |delay_writes_|, |reject_writes_| and |outgoing_messages_|.
  base::Lock write_lock_;
  // True until the IO handler is registered and the pipe is connected.
  bool delay_writes_ = true;
  // Set after a write failure; subsequent Write() calls drop their message.
  bool reject_writes_ = false;
  std::deque<MessageView> outgoing_messages_;

  // True while a ConnectNamedPipe() connection is still outstanding.
  bool wait_for_connect_;

  // Set by LeakHandle(); makes ShutDownOnIOThread() release |handle_|.
  bool leak_handle_ = false;

  DISALLOW_COPY_AND_ASSIGN(ChannelWin);
};
} // namespace
// static
// Builds the Windows Channel implementation. The channel handle carried by
// |connection_params| is taken over by the new channel; I/O runs on
// |io_task_runner|.
scoped_refptr<Channel> Channel::Create(
    Delegate* delegate,
    ConnectionParams connection_params,
    scoped_refptr<base::TaskRunner> io_task_runner) {
  auto channel_handle = connection_params.TakeChannelHandle();
  return new ChannelWin(delegate, std::move(channel_handle), io_task_runner);
}
} // namespace edk
} // namespace mojo