// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "remoting/protocol/connection_tester.h"

#include "base/bind.h"
#include "base/logging.h"
#include "base/message_loop/message_loop.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/socket/stream_socket.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace remoting {
namespace protocol {

// Captures the current message loop (used by Done() to stop the test), the
// two sockets, and the sizing parameters. The total payload is
// |message_size| * |message_count| bytes; buffers are created later by
// InitBuffers().
StreamConnectionTester::StreamConnectionTester(net::StreamSocket* client_socket,
                                               net::StreamSocket* host_socket,
                                               int message_size,
                                               int message_count)
    : message_loop_(base::MessageLoop::current()),
      host_socket_(host_socket),
      client_socket_(client_socket),
      message_size_(message_size),
      test_data_size_(message_size * message_count),
      done_(false),
      write_errors_(0),
      read_errors_(0) {}

StreamConnectionTester::~StreamConnectionTester() {
}

// Kicks off the transfer: allocates the send/receive buffers, arms the
// reader on |host_socket_|, then begins writing from |client_socket_|.
void StreamConnectionTester::Start() {
  InitBuffers();

  // The reader is started before the writer.
  DoRead();
  DoWrite();
}

// Verifies (via gtest) that no errors occurred and that the bytes received
// on the host side are exactly the bytes sent from the client side.
void StreamConnectionTester::CheckResults() {
  // Neither direction may have reported an error.
  EXPECT_EQ(0, write_errors_);
  EXPECT_EQ(0, read_errors_);

  // The full payload must have arrived.
  ASSERT_EQ(test_data_size_, input_buffer_->offset());

  // Rewind the drainable send buffer so data() points at the first byte sent.
  output_buffer_->SetOffset(0);
  ASSERT_EQ(test_data_size_, output_buffer_->size());

  const char* sent = output_buffer_->data();
  const char* received = input_buffer_->StartOfBuffer();
  EXPECT_EQ(0, memcmp(sent, received, test_data_size_));
}

void StreamConnectionTester::Done() {
  done_ = true;
  message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure());
}

// Allocates the outgoing buffer (test_data_size_ bytes, byte i holding the
// low byte of i) and an empty growable buffer for incoming data.
void StreamConnectionTester::InitBuffers() {
  output_buffer_ = new net::DrainableIOBuffer(
      new net::IOBuffer(test_data_size_), test_data_size_);

  // The drainable buffer is at offset 0 here, so data() is the start of the
  // payload for the whole loop.
  char* payload = output_buffer_->data();
  for (int pos = 0; pos < test_data_size_; ++pos)
    payload[pos] = static_cast<char>(pos);

  // Capacity is grown on demand by DoRead().
  input_buffer_ = new net::GrowableIOBuffer();
}

// Writes the remaining payload in chunks of at most message_size_ bytes.
// Keeps going while writes complete synchronously; returns once the buffer
// is drained or a write is pending (OnWritten() resumes) or has failed.
void StreamConnectionTester::DoWrite() {
  for (;;) {
    const int remaining = output_buffer_->BytesRemaining();
    if (remaining == 0)
      return;

    const int chunk = std::min(remaining, message_size_);
    const int result = client_socket_->Write(
        output_buffer_.get(),
        chunk,
        base::Bind(&StreamConnectionTester::OnWritten, base::Unretained(this)));
    HandleWriteResult(result);
    if (result <= 0)
      return;
  }
}

void StreamConnectionTester::OnWritten(int result) {
  HandleWriteResult(result);
  DoWrite();
}

// Processes the result of a Write(): a positive count advances the send
// buffer; ERR_IO_PENDING is ignored (OnWritten() will be called); anything
// else is logged, counted as a write error, and ends the test.
void StreamConnectionTester::HandleWriteResult(int result) {
  if (result > 0) {
    output_buffer_->DidConsume(result);
    return;
  }
  if (result == net::ERR_IO_PENDING)
    return;

  LOG(ERROR) << "Received error " << result << " when trying to write";
  write_errors_++;
  Done();
}

void StreamConnectionTester::DoRead() {
  int result = 1;
  while (result > 0) {
    input_buffer_->SetCapacity(input_buffer_->offset() + message_size_);
    result = host_socket_->Read(
        input_buffer_.get(),
        message_size_,
        base::Bind(&StreamConnectionTester::OnRead, base::Unretained(this)));
    HandleReadResult(result);
  };
}

// Completion callback for an asynchronous Read() issued by DoRead().
void StreamConnectionTester::OnRead(int result) {
  HandleReadResult(result);
  // Once Done() has run there is nothing more to read; don't re-arm.
  if (done_)
    return;
  DoRead();
}

// Processes the result of a Read(). A positive count advances the receive
// offset, and the test ends once test_data_size_ bytes have arrived.
// ERR_IO_PENDING is ignored (OnRead() will be called). Anything else is
// logged, counted as a read error, and ends the test.
void StreamConnectionTester::HandleReadResult(int result) {
  if (result > 0) {
    // Advance past the bytes just received so the next read appends after
    // them.
    const int received_so_far = input_buffer_->offset() + result;
    input_buffer_->set_offset(received_so_far);
    if (received_so_far == test_data_size_)
      Done();
    return;
  }
  if (result == net::ERR_IO_PENDING)
    return;

  LOG(ERROR) << "Received error " << result << " when trying to read";
  read_errors_++;
  Done();
}

// Captures the current message loop (used by Done() to stop the test), the
// two sockets, and the traffic parameters: |message_count| packets of
// |message_size| bytes each, sent |delay_ms| milliseconds apart.
DatagramConnectionTester::DatagramConnectionTester(net::Socket* client_socket,
                                                   net::Socket* host_socket,
                                                   int message_size,
                                                   int message_count,
                                                   int delay_ms)
    : message_loop_(base::MessageLoop::current()),
      host_socket_(host_socket),
      client_socket_(client_socket),
      message_size_(message_size),
      message_count_(message_count),
      delay_ms_(delay_ms),
      done_(false),
      write_errors_(0),
      read_errors_(0),
      packets_sent_(0),
      packets_received_(0),
      bad_packets_received_(0) {
  // DoWrite() copies sizeof(packets_sent_) bytes of packet index into a
  // message_size_-byte buffer, and HandleReadResult() copies an int back out.
  // A smaller packet would overflow the buffer, so reject it up front.
  CHECK_GE(message_size_, static_cast<int>(sizeof(packets_sent_)));
  // One slot per packet; DoWrite() stores each sent packet at its index.
  sent_packets_.resize(message_count_);
}

// Nothing to release explicitly; members clean themselves up.
DatagramConnectionTester::~DatagramConnectionTester() {}

// Kicks off the test: arms the reader on |host_socket_| first, then sends
// the first packet from |client_socket_|.
void DatagramConnectionTester::Start() {
  DoRead();
  DoWrite();
}

// Verifies (via gtest) that no socket errors occurred, that no corrupted
// packets were received, and that at least one packet got through.
void DatagramConnectionTester::CheckResults() {
  EXPECT_EQ(0, write_errors_);
  EXPECT_EQ(0, read_errors_);
  EXPECT_EQ(0, bad_packets_received_);

  // Only one delivered packet is required, presumably because datagram
  // delivery may drop packets. The delivery count is logged for diagnostics.
  EXPECT_GT(packets_received_, 0);
  VLOG(0) << "Received " << packets_received_ << " packets out of "
          << message_count_;
}

void DatagramConnectionTester::Done() {
  done_ = true;
  message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure());
}

void DatagramConnectionTester::DoWrite() {
  if (packets_sent_ >= message_count_) {
    Done();
    return;
  }

  scoped_refptr<net::IOBuffer> packet(new net::IOBuffer(message_size_));
  for (int i = 0; i < message_size_; ++i) {
    packet->data()[i] = static_cast<char>(i);
  }
  sent_packets_[packets_sent_] = packet;
  // Put index of this packet in the beginning of the packet body.
  memcpy(packet->data(), &packets_sent_, sizeof(packets_sent_));

  int result = client_socket_->Write(
      packet.get(),
      message_size_,
      base::Bind(&DatagramConnectionTester::OnWritten, base::Unretained(this)));
  HandleWriteResult(result);
}

// Completion callback for an asynchronous Write() issued by DoWrite(); the
// result is processed exactly like a synchronous one.
void DatagramConnectionTester::OnWritten(int result) { HandleWriteResult(result); }

void DatagramConnectionTester::HandleWriteResult(int result) {
  if (result <= 0 && result != net::ERR_IO_PENDING) {
    LOG(ERROR) << "Received error " << result << " when trying to write";
    write_errors_++;
    Done();
  } else if (result > 0) {
    EXPECT_EQ(message_size_, result);
    packets_sent_++;
    message_loop_->PostDelayedTask(
        FROM_HERE,
        base::Bind(&DatagramConnectionTester::DoWrite, base::Unretained(this)),
        base::TimeDelta::FromMilliseconds(delay_ms_));
  }
}

// Reads packets from |host_socket_| while reads complete synchronously.
// Stops when a read is pending (OnRead() resumes) or fails. Every read gets
// a freshly allocated read_buffer_ twice the expected packet size. Presumably
// this is so an oversized datagram shows up as a size mismatch in
// HandleReadResult().
void DatagramConnectionTester::DoRead() {
  int result;
  do {
    const int read_size = message_size_ * 2;
    read_buffer_ = new net::IOBuffer(read_size);

    result = host_socket_->Read(
        read_buffer_.get(),
        read_size,
        base::Bind(&DatagramConnectionTester::OnRead, base::Unretained(this)));
    HandleReadResult(result);
  } while (result > 0);
}

void DatagramConnectionTester::OnRead(int result) {
  HandleReadResult(result);
  DoRead();
}

// Processes the result of a packet Read().
// - ERR_IO_PENDING: ignored; completion arrives later via OnRead().
// - Other non-positive values: logged, counted as a read error, and the test
//   ends.
// - Success: the packet is counted and validated against the copy kept in
//   sent_packets_. Any mismatch increments bad_packets_received_.
void DatagramConnectionTester::HandleReadResult(int result) {
  if (result <= 0 && result != net::ERR_IO_PENDING) {
    // Error will be received after the socket is closed.
    LOG(ERROR) << "Received error " << result << " when trying to read";
    read_errors_++;
    Done();
  } else if (result > 0) {
    packets_received_++;
    if (message_size_ != result) {
      // Invalid packet size.
      bad_packets_received_++;
    } else {
      // The leading bytes carry the sender's packet index (see DoWrite()).
      int packet_id;
      memcpy(&packet_id, read_buffer_->data(), sizeof(packet_id));
      if (packet_id < 0 || packet_id >= message_count_ ||
          !sent_packets_[packet_id].get()) {
        // The index is out of range, or it names a slot that was never
        // filled (no such packet sent yet). Either way the header is
        // corrupted. Counting it as bad avoids dereferencing an empty slot.
        bad_packets_received_++;
      } else if (memcmp(read_buffer_->data(), sent_packets_[packet_id]->data(),
                        message_size_) != 0) {
        // Header is valid but the payload differs from what was sent.
        bad_packets_received_++;
      }
    }
  }
}

}  // namespace protocol
}  // namespace remoting