// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/udp/udp_client_socket.h"
#include "net/udp/udp_server_socket.h"

#include "base/basictypes.h"
#include "base/metrics/histogram.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/base/net_test_suite.h"
#include "net/base/net_util.h"
#include "net/base/sys_addrinfo.h"
#include "net/base/test_completion_callback.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "testing/platform_test.h"

namespace net {

namespace {

// Test fixture for the UDP client/server socket tests. It owns a shared read
// buffer of |kMaxRead| bytes and offers helpers that drive the asynchronous
// socket calls to completion by waiting on a TestCompletionCallback.
class UDPSocketTest : public PlatformTest {
 public:
  UDPSocketTest()
      : buffer_(new IOBufferWithSize(kMaxRead)) {
  }

  // Blocks until data is read from |socket|. |recv_from_address_| is handed
  // to RecvFrom() as the address out-parameter so that a later two-argument
  // SendToSocket() call can reply to it.
  // Returns the received bytes, or an empty string if the read reported an
  // error.
  std::string RecvFromSocket(UDPServerSocket* socket) {
    TestCompletionCallback callback;

    int rv = socket->RecvFrom(buffer_, kMaxRead, &recv_from_address_,
                              &callback);
    if (rv == ERR_IO_PENDING)
      rv = callback.WaitForResult();
    if (rv < 0)
      return std::string();  // error!
    return std::string(buffer_->data(), rv);
  }

  // Sends |msg| to the address held in |recv_from_address_|, i.e. the address
  // argument of the most recent RecvFromSocket() call on this fixture.
  // See the three-argument overload for the return value.
  int SendToSocket(UDPServerSocket* socket, const std::string& msg) {
    return SendToSocket(socket, msg, recv_from_address_);
  }

  // Loops until |msg| has been written to |socket| with destination |address|,
  // or until a send reports an error (or zero bytes).
  // Returns the number of bytes sent so far; if nothing had been sent when
  // the loop stopped early, returns the result of the failing send instead.
  int SendToSocket(UDPServerSocket* socket,
                   const std::string& msg,
                   const IPEndPoint& address) {
    TestCompletionCallback callback;

    // The IOBuffer APIs take an int size; make the narrowing explicit.
    const int length = static_cast<int>(msg.length());
    scoped_refptr<StringIOBuffer> io_buffer(new StringIOBuffer(msg));
    scoped_refptr<DrainableIOBuffer> buffer(
        new DrainableIOBuffer(io_buffer, length));

    int bytes_sent = 0;
    while (buffer->BytesRemaining()) {
      int rv = socket->SendTo(buffer, buffer->BytesRemaining(),
                              address, &callback);
      if (rv == ERR_IO_PENDING)
        rv = callback.WaitForResult();
      if (rv <= 0)
        return bytes_sent > 0 ? bytes_sent : rv;
      bytes_sent += rv;
      buffer->DidConsume(rv);
    }
    return bytes_sent;
  }

  // Blocks until data is read from the client |socket| into |buffer_|.
  // Returns the received bytes, or an empty string if the read reported an
  // error.
  std::string ReadSocket(UDPClientSocket* socket) {
    TestCompletionCallback callback;

    int rv = socket->Read(buffer_, kMaxRead, &callback);
    if (rv == ERR_IO_PENDING)
      rv = callback.WaitForResult();
    if (rv < 0)
      return std::string();  // error!
    return std::string(buffer_->data(), rv);
  }

  // Loops until |msg| has been written to the client |socket| or until a
  // write reports an error (or zero bytes).
  // Returns the number of bytes written so far; if nothing had been written
  // when the loop stopped early, returns the result of the failing write.
  int WriteSocket(UDPClientSocket* socket, const std::string& msg) {
    TestCompletionCallback callback;

    // The IOBuffer APIs take an int size; make the narrowing explicit.
    const int length = static_cast<int>(msg.length());
    scoped_refptr<StringIOBuffer> io_buffer(new StringIOBuffer(msg));
    scoped_refptr<DrainableIOBuffer> buffer(
        new DrainableIOBuffer(io_buffer, length));

    int bytes_sent = 0;
    while (buffer->BytesRemaining()) {
      int rv = socket->Write(buffer, buffer->BytesRemaining(), &callback);
      if (rv == ERR_IO_PENDING)
        rv = callback.WaitForResult();
      if (rv <= 0)
        return bytes_sent > 0 ? bytes_sent : rv;
      bytes_sent += rv;
      buffer->DidConsume(rv);
    }
    return bytes_sent;
  }

 protected:
  // Size in bytes of |buffer_| and the maximum length requested per read.
  static const int kMaxRead = 1024;
  // Scratch buffer shared by all read helpers.
  scoped_refptr<IOBufferWithSize> buffer_;
  // Address out-parameter of the last RecvFromSocket() call.
  IPEndPoint recv_from_address_;
};

// Creates and address from an ip/port and returns it in |address|.
void CreateUDPAddress(std::string ip_str, int port, IPEndPoint* address) {
  IPAddressNumber ip_number;
  bool rv = ParseIPLiteralToNumber(ip_str, &ip_number);
  if (!rv)
    return;
  *address = IPEndPoint(ip_number, port);
}

// Round-trips a message between a connected client socket and a listening
// server socket: the client writes, the server receives and echoes the
// message back, and the client reads the echo.
TEST_F(UDPSocketTest, Connect) {
  const int kPort = 9999;
  std::string simple_message("hello world!");

  // Setup the server to listen.
  IPEndPoint bind_address;
  CreateUDPAddress("0.0.0.0", kPort, &bind_address);
  UDPServerSocket server(NULL, NetLog::Source());
  int rv = server.Listen(bind_address);
  EXPECT_EQ(OK, rv);

  // Setup the client.
  IPEndPoint server_address;
  CreateUDPAddress("127.0.0.1", kPort, &server_address);
  UDPClientSocket client(NULL, NetLog::Source());
  rv = client.Connect(server_address);
  EXPECT_EQ(OK, rv);

  // Client sends to the server.
  rv = WriteSocket(&client, simple_message);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Server waits for message.
  // Use a gtest expectation rather than DCHECK so the comparison is also
  // performed in release builds and a mismatch is reported as a test failure.
  std::string str = RecvFromSocket(&server);
  EXPECT_EQ(simple_message, str);

  // Server echoes reply.
  rv = SendToSocket(&server, simple_message);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Client waits for response.
  str = ReadSocket(&client);
  EXPECT_EQ(simple_message, str);
}

// In this test, we verify that connect() on a socket will have the effect
// of filtering reads on this socket only to data read from the destination
// we connected to.
//
// The purpose of this test is that some documentation indicates that connect
// binds the client's sends to send to a particular server endpoint, but does
// not bind the client's reads to only be from that endpoint, and that we need
// to always use recvfrom() to disambiguate.
//
// Two servers send to the client; the client is connected only to server1 and
// is expected to read server1's message, not the one sent first by server2.
TEST_F(UDPSocketTest, VerifyConnectBindsAddr) {
  const int kPort1 = 9999;
  const int kPort2 = 10000;
  std::string simple_message("hello world!");
  std::string foreign_message("BAD MESSAGE TO GET!!");

  // Setup the first server to listen.
  IPEndPoint bind_address;
  CreateUDPAddress("0.0.0.0", kPort1, &bind_address);
  UDPServerSocket server1(NULL, NetLog::Source());
  int rv = server1.Listen(bind_address);
  EXPECT_EQ(OK, rv);

  // Setup the second server to listen.
  CreateUDPAddress("0.0.0.0", kPort2, &bind_address);
  UDPServerSocket server2(NULL, NetLog::Source());
  rv = server2.Listen(bind_address);
  EXPECT_EQ(OK, rv);

  // Setup the client, connected to server 1.
  IPEndPoint server_address;
  CreateUDPAddress("127.0.0.1", kPort1, &server_address);
  UDPClientSocket client(NULL, NetLog::Source());
  rv = client.Connect(server_address);
  EXPECT_EQ(OK, rv);

  // Client sends to server1.
  rv = WriteSocket(&client, simple_message);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Server1 waits for message.
  // gtest expectations (not DCHECK) are used for the payload checks so they
  // run in release builds and report mismatches as test failures.
  std::string str = RecvFromSocket(&server1);
  EXPECT_EQ(simple_message, str);

  // Get the client's address.
  IPEndPoint client_address;
  rv = client.GetLocalAddress(&client_address);
  EXPECT_EQ(OK, rv);

  // Server2 sends reply.
  rv = SendToSocket(&server2, foreign_message,
                    client_address);
  EXPECT_EQ(foreign_message.length(), static_cast<size_t>(rv));

  // Server1 sends reply.
  rv = SendToSocket(&server1, simple_message,
                    client_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Client waits for response. It must be server1's message; server2's
  // earlier datagram is expected to have been filtered out.
  str = ReadSocket(&client);
  EXPECT_EQ(simple_message, str);
}

// Connects a client socket to several IPv4 and IPv6 destinations and checks
// that GetLocalAddress() succeeds and that GetPeerAddress() reports exactly
// the endpoint that was passed to Connect().
TEST_F(UDPSocketTest, ClientGetLocalPeerAddresses) {
  struct TestData {
    std::string remote_address;
    std::string local_address;
    bool may_fail;
  } tests[] = {
    { "127.0.00.1", "127.0.0.1", false },
    { "192.168.1.1", "127.0.0.1", false },
    { "::1", "::1", true },
    { "2001:db8:0::42", "::1", true },
  };
  const size_t kNumTests = ARRAYSIZE_UNSAFE(tests);
  for (size_t index = 0; index < kNumTests; ++index) {
    const TestData& test = tests[index];
    SCOPED_TRACE("Connecting from " + test.local_address + " to " +
                 test.remote_address);

    // Both endpoints use port 80; the same scratch IP number is reused for
    // the two parses.
    IPAddressNumber parsed_ip;
    ParseIPLiteralToNumber(test.remote_address, &parsed_ip);
    IPEndPoint remote_endpoint(parsed_ip, 80);
    ParseIPLiteralToNumber(test.local_address, &parsed_ip);
    IPEndPoint local_endpoint(parsed_ip, 80);

    UDPClientSocket client(NULL, NetLog::Source());
    const int connect_result = client.Connect(remote_endpoint);

    // Connect() may return ERR_ADDRESS_UNREACHABLE for IPv6 addresses if
    // IPv6 is not configured; such rows are skipped rather than failed.
    const bool skip_unreachable =
        test.may_fail && connect_result == ERR_ADDRESS_UNREACHABLE;
    if (skip_unreachable)
      continue;

    EXPECT_LE(ERR_IO_PENDING, connect_result);

    IPEndPoint fetched_local_address;
    const int local_result = client.GetLocalAddress(&fetched_local_address);
    EXPECT_EQ(OK, local_result);

    // TODO(mbelshe): figure out how to verify the IP and port.
    //                The port is dynamically generated by the udp stack.
    //                The IP is the real IP of the client, not necessarily
    //                loopback.
    //EXPECT_EQ(local_endpoint.address(), fetched_local_address.address());

    IPEndPoint fetched_remote_address;
    const int peer_result = client.GetPeerAddress(&fetched_remote_address);
    EXPECT_EQ(OK, peer_result);

    EXPECT_EQ(remote_endpoint, fetched_remote_address);
  }
}

// Listens on 127.0.0.1 with port 0 and checks that GetLocalAddress() reports
// the bound IP together with a non-zero port.
TEST_F(UDPSocketTest, ServerGetLocalAddress) {
  IPEndPoint bind_address;
  CreateUDPAddress("127.0.0.1", 0, &bind_address);
  UDPServerSocket server(NULL, NetLog::Source());
  int rv = server.Listen(bind_address);
  EXPECT_EQ(OK, rv);

  // Expected value goes first, matching the other expectations in this file
  // so that gtest labels "expected" and "actual" correctly on failure.
  IPEndPoint local_address;
  rv = server.GetLocalAddress(&local_address);
  EXPECT_EQ(OK, rv);

  // Verify that port was allocated.
  EXPECT_GT(local_address.port(), 0);
  EXPECT_EQ(bind_address.address(), local_address.address());
}

// Checks that GetPeerAddress() on a listening server socket reports
// ERR_SOCKET_NOT_CONNECTED.
TEST_F(UDPSocketTest, ServerGetPeerAddress) {
  IPEndPoint bind_address;
  CreateUDPAddress("127.0.0.1", 0, &bind_address);
  UDPServerSocket server(NULL, NetLog::Source());
  int rv = server.Listen(bind_address);
  EXPECT_EQ(OK, rv);

  // Expected value goes first, matching the other expectations in this file.
  IPEndPoint peer_address;
  rv = server.GetPeerAddress(&peer_address);
  EXPECT_EQ(ERR_SOCKET_NOT_CONNECTED, rv);
}

// Close the socket while read is pending. The pending read's callback must
// not have been given a result by the time Close() returns.
TEST_F(UDPSocketTest, CloseWithPendingRead) {
  IPEndPoint bind_address;
  CreateUDPAddress("127.0.0.1", 0, &bind_address);
  UDPServerSocket server(NULL, NetLog::Source());
  int rv = server.Listen(bind_address);
  EXPECT_EQ(OK, rv);

  // Nothing has been sent to the server, so the read is expected to stay
  // pending. Expected value goes first, matching the rest of this file.
  TestCompletionCallback callback;
  IPEndPoint from;
  rv = server.RecvFrom(buffer_, kMaxRead, &from, &callback);
  EXPECT_EQ(ERR_IO_PENDING, rv);

  server.Close();

  EXPECT_FALSE(callback.have_result());
}

}  // namespace

}  // namespace net