// Copyright 2016 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "mojo/edk/system/channel.h"

#include <stdint.h>
#include <windows.h>

#include <algorithm>
#include <deque>
#include <limits>
#include <memory>

#include "base/bind.h"
#include "base/location.h"
#include "base/macros.h"
#include "base/memory/ref_counted.h"
#include "base/message_loop/message_loop.h"
#include "base/synchronization/lock.h"
#include "base/task_runner.h"
#include "base/win/win_util.h"
#include "mojo/edk/embedder/platform_handle_vector.h"

namespace mojo {
namespace edk {

namespace {

// A view over a Channel::Message object. The write queue uses these since
// large messages may need to be sent in chunks.
class MessageView {
 public:
  // Owns |message|. |offset| indexes the first unsent byte in the message.
  MessageView(Channel::MessagePtr message, size_t offset)
      : message_(std::move(message)),
        offset_(offset) {
    DCHECK_GT(message_->data_num_bytes(), offset_);
  }

  MessageView(MessageView&& other) { *this = std::move(other); }

  MessageView& operator=(MessageView&& other) {
    message_ = std::move(other.message_);
    offset_ = other.offset_;
    return *this;
  }

  ~MessageView() {}

  const void* data() const {
    return static_cast<const char*>(message_->data()) + offset_;
  }

  size_t data_num_bytes() const { return message_->data_num_bytes() - offset_; }

  size_t data_offset() const { return offset_; }
  void advance_data_offset(size_t num_bytes) {
    DCHECK_GE(message_->data_num_bytes(), offset_ + num_bytes);
    offset_ += num_bytes;
  }

  Channel::MessagePtr TakeChannelMessage() { return std::move(message_); }

 private:
  Channel::MessagePtr message_;
  size_t offset_;

  DISALLOW_COPY_AND_ASSIGN(MessageView);
};

// Channel implementation for Windows, driving |handle_| with overlapped I/O.
// Connect, read and write completions are delivered to OnIOCompleted() by the
// MessageLoopForIO that StartOnIOThread() registers this handler with.
//
// Lifetime: |self_| keeps the channel alive until ShutDownOnIOThread(), and
// every pending overlapped operation (connect, read or write) holds an extra
// reference via AddRef() that OnIOCompleted() drops with Release().
class ChannelWin : public Channel,
                   public base::MessageLoop::DestructionObserver,
                   public base::MessageLoopForIO::IOHandler {
 public:
  // Takes ownership of |handle|, which must be valid. When the handle is
  // flagged |needs_connection|, StartOnIOThread() first waits for a client to
  // connect (ConnectNamedPipe) before issuing reads or flushing queued writes.
  ChannelWin(Delegate* delegate,
             ScopedPlatformHandle handle,
             scoped_refptr<base::TaskRunner> io_task_runner)
      : Channel(delegate),
        self_(this),
        handle_(std::move(handle)),
        io_task_runner_(std::move(io_task_runner)) {
    CHECK(handle_.is_valid());

    wait_for_connect_ = handle_.get().needs_connection;
  }

  void Start() override {
    io_task_runner_->PostTask(
        FROM_HERE, base::Bind(&ChannelWin::StartOnIOThread, this));
  }

  void ShutDownImpl() override {
    // Always shut down asynchronously when called through the public interface.
    io_task_runner_->PostTask(
        FROM_HERE, base::Bind(&ChannelWin::ShutDownOnIOThread, this));
  }

  // Appends |message| to the outgoing queue. The write is started right away
  // only if writes are no longer delayed and nothing else is queued; otherwise
  // it is picked up by OnWriteDone() or when the initial write delay ends.
  // Messages are silently dropped once writes have been rejected.
  void Write(MessagePtr message) override {
    bool write_error = false;
    {
      base::AutoLock lock(write_lock_);
      if (reject_writes_)
        return;

      bool write_now = !delay_writes_ && outgoing_messages_.empty();
      outgoing_messages_.emplace_back(std::move(message), 0);

      if (write_now && !WriteNoLock(outgoing_messages_.front()))
        reject_writes_ = write_error = true;
    }
    if (write_error) {
      // Do not synchronously invoke OnError(). Write() may have been called by
      // the delegate and we don't want to re-enter it.
      io_task_runner_->PostTask(FROM_HERE,
                                base::Bind(&ChannelWin::OnError, this));
    }
  }

  // Makes ShutDownOnIOThread() release |handle_| without closing it. Must be
  // called on the I/O thread.
  void LeakHandle() override {
    DCHECK(io_task_runner_->RunsTasksOnCurrentThread());
    leak_handle_ = true;
  }

  // Decodes |num_handles| HandleEntry records from |extra_header| into a new
  // PlatformHandleVector stored in |*handles|. Returns false if |num_handles|
  // exceeds the uint16_t range or the entries do not fit in
  // |extra_header_size| bytes.
  bool GetReadPlatformHandles(
      size_t num_handles,
      const void* extra_header,
      size_t extra_header_size,
      ScopedPlatformHandleVectorPtr* handles) override {
    if (num_handles > std::numeric_limits<uint16_t>::max())
      return false;
    using HandleEntry = Channel::Message::HandleEntry;
    size_t handles_size = sizeof(HandleEntry) * num_handles;
    if (handles_size > extra_header_size)
      return false;
    DCHECK(extra_header);
    handles->reset(new PlatformHandleVector(num_handles));
    const HandleEntry* extra_header_handles =
        reinterpret_cast<const HandleEntry*>(extra_header);
    for (size_t i = 0; i < num_handles; i++) {
      (*handles)->at(i).handle =
          base::win::Uint32ToHandle(extra_header_handles[i].handle);
    }
    return true;
  }

 private:
  // May run on any thread.
  ~ChannelWin() override {}

  // Registers for I/O completions and loop destruction, then either starts an
  // asynchronous pipe connect or (if already connected / no connect needed)
  // flushes delayed writes and issues the first read.
  void StartOnIOThread() {
    base::MessageLoop::current()->AddDestructionObserver(this);
    base::MessageLoopForIO::current()->RegisterIOHandler(
        handle_.get().handle, this);

    if (wait_for_connect_) {
      BOOL ok = ConnectNamedPipe(handle_.get().handle,
                                 &connect_context_.overlapped);
      if (ok) {
        PLOG(ERROR) << "Unexpected success while waiting for pipe connection";
        OnError();
        return;
      }

      const DWORD err = GetLastError();
      switch (err) {
        case ERROR_PIPE_CONNECTED:
          // Already connected; proceed without waiting for a completion.
          wait_for_connect_ = false;
          break;
        case ERROR_IO_PENDING:
          // The connect completes in OnIOCompleted(&connect_context_), which
          // drops this reference.
          AddRef();
          return;
        case ERROR_NO_DATA:
        default:
          // Any other failure means the pipe cannot be connected. Treat it as
          // fatal rather than falling through to reads/writes with
          // |wait_for_connect_| still set.
          OnError();
          return;
      }
    }

    // Keep this alive in case we synchronously run shutdown.
    scoped_refptr<ChannelWin> keep_alive(this);

    // Now that we have registered our IOHandler, we can start writing.
    if (!FlushDelayedWrites()) {
      OnError();
      return;
    }

    ReadMore(0);
  }

  // Cancels pending I/O, closes (or leaks, see LeakHandle()) |handle_| and
  // drops the self-reference.
  void ShutDownOnIOThread() {
    base::MessageLoop::current()->RemoveDestructionObserver(this);

    // BUG(crbug.com/583525): This function is expected to be called once, and
    // |handle_| should be valid at this point.
    CHECK(handle_.is_valid());
    CancelIo(handle_.get().handle);
    if (leak_handle_)
      ignore_result(handle_.release());
    handle_.reset();

    // May destroy the |this| if it was the last reference.
    self_ = nullptr;
  }

  // base::MessageLoop::DestructionObserver:
  // Shuts down if that has not happened yet (|self_| is still set).
  void WillDestroyCurrentMessageLoop() override {
    DCHECK(io_task_runner_->RunsTasksOnCurrentThread());
    if (self_)
      ShutDownOnIOThread();
  }

  // base::MessageLoop::IOHandler:
  // Dispatches a completed connect, read or write, then drops the reference
  // that was taken when that operation was started.
  void OnIOCompleted(base::MessageLoopForIO::IOContext* context,
                     DWORD bytes_transfered,
                     DWORD error) override {
    if (error != ERROR_SUCCESS) {
      OnError();
    } else if (context == &connect_context_) {
      DCHECK(wait_for_connect_);
      wait_for_connect_ = false;
      ReadMore(0);

      if (!FlushDelayedWrites())
        OnError();
    } else if (context == &read_context_) {
      OnReadDone(static_cast<size_t>(bytes_transfered));
    } else {
      CHECK(context == &write_context_);
      OnWriteDone(static_cast<size_t>(bytes_transfered));
    }
    Release();  // Balances the AddRef() taken for ConnectNamedPipe / ReadFile /
                // WriteFile.
  }

  // Hands |bytes_read| to Channel::OnReadComplete() and starts the next read.
  // A zero-byte read, or a failure to process the data, is treated as an
  // error.
  void OnReadDone(size_t bytes_read) {
    if (bytes_read == 0) {
      OnError();
      return;
    }

    size_t next_read_size = 0;
    if (OnReadComplete(bytes_read, &next_read_size))
      ReadMore(next_read_size);
    else
      OnError();
  }

  // Advances the front queued message by |bytes_written|, dequeues it once it
  // is fully sent, and starts the next write.
  void OnWriteDone(size_t bytes_written) {
    // NOTE(review): a zero-byte completion leaves the queue without an
    // in-flight write; confirm this cannot happen for a non-empty request.
    if (bytes_written == 0)
      return;

    bool write_error = false;
    {
      base::AutoLock lock(write_lock_);

      DCHECK(!outgoing_messages_.empty());

      MessageView& message_view = outgoing_messages_.front();
      message_view.advance_data_offset(bytes_written);
      if (message_view.data_num_bytes() == 0) {
        Channel::MessagePtr message = message_view.TakeChannelMessage();
        outgoing_messages_.pop_front();

        // Clear any handles so they don't get closed on destruction.
        ScopedPlatformHandleVectorPtr handles = message->TakeHandles();
        if (handles)
          handles->clear();
      }

      if (!WriteNextNoLock())
        reject_writes_ = write_error = true;
    }
    if (write_error)
      OnError();
  }

  // Starts an overlapped read into a buffer from GetReadBuffer(), sized using
  // |next_read_size_hint|. Reports OnError() if the read cannot be started.
  void ReadMore(size_t next_read_size_hint) {
    size_t buffer_capacity = next_read_size_hint;
    char* buffer = GetReadBuffer(&buffer_capacity);
    DCHECK_GT(buffer_capacity, 0u);

    BOOL ok = ReadFile(handle_.get().handle,
                       buffer,
                       static_cast<DWORD>(buffer_capacity),
                       NULL,
                       &read_context_.overlapped);

    if (ok || GetLastError() == ERROR_IO_PENDING) {
      AddRef();  // Will be balanced in OnIOCompleted
    } else {
      OnError();
    }
  }

  // Starts an overlapped write of the unsent part of |message_view|. The
  // message stays at the front of |outgoing_messages_| until OnWriteDone()
  // observes that it has been fully written. Returns false if the write could
  // not be started. Requires |write_lock_|.
  bool WriteNoLock(const MessageView& message_view) {
    BOOL ok = WriteFile(handle_.get().handle,
                        message_view.data(),
                        static_cast<DWORD>(message_view.data_num_bytes()),
                        NULL,
                        &write_context_.overlapped);

    if (ok || GetLastError() == ERROR_IO_PENDING) {
      AddRef();  // Will be balanced in OnIOCompleted.
      return true;
    }
    return false;
  }

  // Starts writing the front queued message, if any. Requires |write_lock_|.
  bool WriteNextNoLock() {
    if (outgoing_messages_.empty())
      return true;
    return WriteNoLock(outgoing_messages_.front());
  }

  // Ends the initial write delay (if still in effect) and starts writing any
  // messages queued before the channel was ready. If the first write cannot be
  // started, further writes are rejected and false is returned so the caller
  // can report the error; otherwise the queue would stall forever.
  bool FlushDelayedWrites() {
    base::AutoLock lock(write_lock_);
    if (!delay_writes_)
      return true;
    delay_writes_ = false;
    if (WriteNextNoLock())
      return true;
    reject_writes_ = true;
    return false;
  }

  // Keeps the Channel alive at least until explicit shutdown on the IO thread.
  scoped_refptr<Channel> self_;

  ScopedPlatformHandle handle_;
  scoped_refptr<base::TaskRunner> io_task_runner_;

  base::MessageLoopForIO::IOContext connect_context_;
  base::MessageLoopForIO::IOContext read_context_;
  base::MessageLoopForIO::IOContext write_context_;

  // Protects |delay_writes_|, |reject_writes_| and |outgoing_messages_|.
  base::Lock write_lock_;

  // True until the handler is registered and any pipe connection completes.
  bool delay_writes_ = true;

  // Set after a write failure; subsequent Write() calls drop their messages.
  bool reject_writes_ = false;
  std::deque<MessageView> outgoing_messages_;

  // True while waiting for ConnectNamedPipe to complete.
  bool wait_for_connect_;

  // See LeakHandle().
  bool leak_handle_ = false;

  DISALLOW_COPY_AND_ASSIGN(ChannelWin);
};

}  // namespace

// static
// Creates the Windows Channel over the handle taken from |connection_params|.
// |io_task_runner| is moved into the channel, since it is a sink parameter and
// copying it would cost an extra refcount round-trip.
scoped_refptr<Channel> Channel::Create(
    Delegate* delegate,
    ConnectionParams connection_params,
    scoped_refptr<base::TaskRunner> io_task_runner) {
  return new ChannelWin(delegate, connection_params.TakeChannelHandle(),
                        std::move(io_task_runner));
}

}  // namespace edk
}  // namespace mojo