// Copyright 2016 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "mojo/edk/system/broker_host.h"

#include <utility>

#include "base/logging.h"
#include "base/memory/ref_counted.h"
#include "base/message_loop/message_loop.h"
#include "base/threading/thread_task_runner_handle.h"
#include "mojo/edk/embedder/named_platform_channel_pair.h"
#include "mojo/edk/embedder/named_platform_handle.h"
#include "mojo/edk/embedder/platform_handle_vector.h"
#include "mojo/edk/embedder/platform_shared_buffer.h"
#include "mojo/edk/system/broker_messages.h"

namespace mojo {
namespace edk {

BrokerHost::BrokerHost(base::ProcessHandle client_process,
                       ScopedPlatformHandle platform_handle)
#if defined(OS_WIN)
    : client_process_(client_process)
#endif
{
  CHECK(platform_handle.is_valid());

  base::MessageLoop::current()->AddDestructionObserver(this);

  channel_ = Channel::Create(this, ConnectionParams(std::move(platform_handle)),
                             base::ThreadTaskRunnerHandle::Get());
  channel_->Start();
}

BrokerHost::~BrokerHost() {
  // We're always destroyed on the creation thread, which is the IO thread.
  base::MessageLoop::current()->RemoveDestructionObserver(this);

  if (channel_)
    channel_->ShutDown();
}

bool BrokerHost::PrepareHandlesForClient(PlatformHandleVector* handles) {
#if defined(OS_WIN)
  if (!Channel::Message::RewriteHandles(
      base::GetCurrentProcessHandle(), client_process_, handles)) {
    // NOTE: We only log an error here. We do not signal a logical error or
    // prevent any message from being sent. The client should handle unexpected
    // invalid handles appropriately.
    DLOG(ERROR) << "Failed to rewrite one or more handles to broker client.";
    return false;
  }
#endif
  return true;
}

bool BrokerHost::SendChannel(ScopedPlatformHandle handle) {
  CHECK(handle.is_valid());
  CHECK(channel_);

#if defined(OS_WIN)
  InitData* data;
  Channel::MessagePtr message =
      CreateBrokerMessage(BrokerMessageType::INIT, 1, 0, &data);
  data->pipe_name_length = 0;
#else
  Channel::MessagePtr message =
      CreateBrokerMessage(BrokerMessageType::INIT, 1, nullptr);
#endif
  ScopedPlatformHandleVectorPtr handles;
  handles.reset(new PlatformHandleVector(1));
  handles->at(0) = handle.release();

  // This may legitimately fail on Windows if the client process is in another
  // session, e.g., is an elevated process.
  if (!PrepareHandlesForClient(handles.get()))
    return false;

   message->SetHandles(std::move(handles));
  channel_->Write(std::move(message));
  return true;
}

#if defined(OS_WIN)

void BrokerHost::SendNamedChannel(const base::StringPiece16& pipe_name) {
  InitData* data;
  base::char16* name_data;
  Channel::MessagePtr message = CreateBrokerMessage(
      BrokerMessageType::INIT, 0, sizeof(*name_data) * pipe_name.length(),
      &data, reinterpret_cast<void**>(&name_data));
  data->pipe_name_length = static_cast<uint32_t>(pipe_name.length());
  std::copy(pipe_name.begin(), pipe_name.end(), name_data);
  channel_->Write(std::move(message));
}

#endif  // defined(OS_WIN)

void BrokerHost::OnBufferRequest(uint32_t num_bytes) {
  scoped_refptr<PlatformSharedBuffer> read_only_buffer;
  scoped_refptr<PlatformSharedBuffer> buffer =
      PlatformSharedBuffer::Create(num_bytes);
  if (buffer)
    read_only_buffer = buffer->CreateReadOnlyDuplicate();
  if (!read_only_buffer)
    buffer = nullptr;

  Channel::MessagePtr message = CreateBrokerMessage(
      BrokerMessageType::BUFFER_RESPONSE, buffer ? 2 : 0, nullptr);
  if (buffer) {
    ScopedPlatformHandleVectorPtr handles;
    handles.reset(new PlatformHandleVector(2));
    handles->at(0) = buffer->PassPlatformHandle().release();
    handles->at(1) = read_only_buffer->PassPlatformHandle().release();
    PrepareHandlesForClient(handles.get());
    message->SetHandles(std::move(handles));
  }

  channel_->Write(std::move(message));
}

void BrokerHost::OnChannelMessage(const void* payload,
                                  size_t payload_size,
                                  ScopedPlatformHandleVectorPtr handles) {
  if (payload_size < sizeof(BrokerMessageHeader))
    return;

  const BrokerMessageHeader* header =
      static_cast<const BrokerMessageHeader*>(payload);
  switch (header->type) {
    case BrokerMessageType::BUFFER_REQUEST:
      if (payload_size ==
            sizeof(BrokerMessageHeader) + sizeof(BufferRequestData)) {
        const BufferRequestData* request =
            reinterpret_cast<const BufferRequestData*>(header + 1);
        OnBufferRequest(request->size);
      }
      break;

    default:
      LOG(ERROR) << "Unexpected broker message type: " << header->type;
      break;
  }
}

void BrokerHost::OnChannelError() { delete this; }

void BrokerHost::WillDestroyCurrentMessageLoop() { delete this; }

}  // namespace edk
}  // namespace mojo