// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/dns/dns_session.h"

#include <list>

#include "base/bind.h"
#include "base/memory/scoped_ptr.h"
#include "base/rand_util.h"
#include "base/stl_util.h"
#include "net/base/net_log.h"
#include "net/dns/dns_protocol.h"
#include "net/dns/dns_socket_pool.h"
#include "net/socket/socket_test_util.h"
#include "net/socket/ssl_client_socket.h"
#include "net/socket/stream_socket.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace net {

namespace {

// ClientSocketFactory used by these tests. Only datagram (UDP) sockets are
// supported; they are mock sockets backed by data providers that this
// factory owns. Requests for any other socket type hit NOTIMPLEMENTED() and
// return an empty pointer.
class TestClientSocketFactory : public ClientSocketFactory {
 public:
  // Deletes every data provider handed out by CreateDatagramClientSocket().
  virtual ~TestClientSocketFactory();

  // Returns a MockUDPClientSocket wired to a fresh, empty data provider.
  // Defined out of line further down in this file.
  virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
      DatagramSocket::BindType bind_type,
      const RandIntCallback& rand_int_cb,
      net::NetLog* net_log,
      const net::NetLog::Source& source) OVERRIDE;

  // TCP sockets are not used by these tests.
  virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
      const AddressList& addresses,
      NetLog*, const NetLog::Source&) OVERRIDE {
    NOTIMPLEMENTED();
    scoped_ptr<StreamSocket> no_socket;
    return no_socket.Pass();
  }

  // SSL sockets are not used by these tests.
  virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
      scoped_ptr<ClientSocketHandle> transport_socket,
      const HostPortPair& host_and_port,
      const SSLConfig& ssl_config,
      const SSLClientSocketContext& context) OVERRIDE {
    NOTIMPLEMENTED();
    scoped_ptr<SSLClientSocket> no_socket;
    return no_socket.Pass();
  }

  // There is no SSL session cache to clear.
  virtual void ClearSSLSessionCache() OVERRIDE { NOTIMPLEMENTED(); }

 private:
  // Providers backing the sockets created above. Owned (raw pointers);
  // released in the destructor.
  std::list<SocketDataProvider*> data_providers_;
};

// A single call observed on the mock socket pool: the operation that ran
// and the nameserver index it ran for.
struct PoolEvent {
  enum Action { ALLOCATE, FREE };

  Action action;
  unsigned server_index;
};

// Fixture that builds a DnsSession on top of a MockDnsSocketPool and records,
// in order, every socket allocation and release the pool reports, so tests
// can assert on the exact sequence of pool events.
class DnsSessionTest : public testing::Test {
 public:
  // Callbacks invoked by MockDnsSocketPool; public so the pool can reach them.
  // Each appends one PoolEvent to |events_|.
  void OnSocketAllocated(unsigned server_index);
  void OnSocketFreed(unsigned server_index);

 protected:
  // (Re)creates |config_| nameservers, the socket factory and |session_|.
  void Initialize(unsigned num_servers);
  // Requests a socket lease for |server_index| from |session_|.
  scoped_ptr<DnsSession::SocketLease> Allocate(unsigned server_index);
  // Each returns true and consumes the oldest recorded event only if that
  // event is an ALLOCATE / FREE for |server_index|.
  bool DidAllocate(unsigned server_index);
  bool DidFree(unsigned server_index);
  // True when every recorded event has been consumed.
  bool NoMoreEvents();

  // NOTE: declaration order matters. Members are destroyed in reverse order,
  // so |session_| drops its reference before |test_client_socket_factory_|
  // (which owns the sockets' data providers) is destroyed.
  DnsConfig config_;
  scoped_ptr<TestClientSocketFactory> test_client_socket_factory_;
  scoped_refptr<DnsSession> session_;
  NetLog::Source source_;

 private:
  // Shared matcher behind DidAllocate()/DidFree().
  bool ExpectEvent(const PoolEvent& event);
  // Pool events in the order they were reported, oldest first.
  std::list<PoolEvent> events_;
};

// DnsSocketPool that does no pooling. Each allocation gets a new socket from
// DnsSocketPool::CreateConnectedSocket(), and each freed socket is simply
// destroyed. Every allocate/free call is reported to the owning test fixture.
class MockDnsSocketPool : public DnsSocketPool {
 public:
  MockDnsSocketPool(ClientSocketFactory* factory, DnsSessionTest* test)
      : DnsSocketPool(factory),
        fixture_(test) {
  }

  virtual ~MockDnsSocketPool() {}

  virtual void Initialize(const std::vector<IPEndPoint>* nameservers,
                          NetLog* net_log) OVERRIDE {
    InitializeInternal(nameservers, net_log);
  }

  virtual scoped_ptr<DatagramClientSocket> AllocateSocket(
      unsigned server_index) OVERRIDE {
    // The event is recorded before the socket is created.
    fixture_->OnSocketAllocated(server_index);
    return CreateConnectedSocket(server_index);
  }

  virtual void FreeSocket(unsigned server_index,
                          scoped_ptr<DatagramClientSocket> socket) OVERRIDE {
    fixture_->OnSocketFreed(server_index);
    // |socket| is not retained; it is destroyed when this function returns.
  }

 private:
  // Fixture receiving the allocate/free notifications. Not owned.
  DnsSessionTest* fixture_;
};

void DnsSessionTest::Initialize(unsigned num_servers) {
  CHECK(num_servers < 256u);
  config_.nameservers.clear();
  IPAddressNumber dns_ip;
  bool rv = ParseIPLiteralToNumber("192.168.1.0", &dns_ip);
  EXPECT_TRUE(rv);
  for (unsigned char i = 0; i < num_servers; ++i) {
    dns_ip[3] = i;
    IPEndPoint dns_endpoint(dns_ip, dns_protocol::kDefaultPort);
    config_.nameservers.push_back(dns_endpoint);
  }

  test_client_socket_factory_.reset(new TestClientSocketFactory());

  DnsSocketPool* dns_socket_pool =
      new MockDnsSocketPool(test_client_socket_factory_.get(), this);

  session_ = new DnsSession(config_,
                            scoped_ptr<DnsSocketPool>(dns_socket_pool),
                            base::Bind(&base::RandInt),
                            NULL /* NetLog */);

  events_.clear();
}

// Asks |session_| for a socket lease on |server_index|, tagged with the
// fixture's NetLog source.
scoped_ptr<DnsSession::SocketLease> DnsSessionTest::Allocate(
    unsigned server_index) {
  scoped_ptr<DnsSession::SocketLease> lease =
      session_->AllocateSocket(server_index, source_);
  return lease.Pass();
}

bool DnsSessionTest::DidAllocate(unsigned server_index) {
  PoolEvent expected_event = { PoolEvent::ALLOCATE, server_index };
  return ExpectEvent(expected_event);
}

bool DnsSessionTest::DidFree(unsigned server_index) {
  PoolEvent expected_event = { PoolEvent::FREE, server_index };
  return ExpectEvent(expected_event);
}

// True when every recorded pool event has already been matched and consumed.
bool DnsSessionTest::NoMoreEvents() {
  return events_.begin() == events_.end();
}

void DnsSessionTest::OnSocketAllocated(unsigned server_index) {
  PoolEvent event = { PoolEvent::ALLOCATE, server_index };
  events_.push_back(event);
}

void DnsSessionTest::OnSocketFreed(unsigned server_index) {
  PoolEvent event = { PoolEvent::FREE, server_index };
  events_.push_back(event);
}

bool DnsSessionTest::ExpectEvent(const PoolEvent& expected) {
  if (events_.empty()) {
    return false;
  }

  const PoolEvent actual = events_.front();
  if ((expected.action != actual.action)
      || (expected.server_index != actual.server_index)) {
    return false;
  }
  events_.pop_front();

  return true;
}

// Builds a MockUDPClientSocket on top of an empty StaticSocketDataProvider.
// No traffic is expected, so the provider carries no reads or writes. The
// provider is kept in |data_providers_| and deleted with the factory.
// |bind_type|, |rand_int_cb| and |source| are ignored.
scoped_ptr<DatagramClientSocket>
TestClientSocketFactory::CreateDatagramClientSocket(
    DatagramSocket::BindType bind_type,
    const RandIntCallback& rand_int_cb,
    net::NetLog* net_log,
    const net::NetLog::Source& source) {
  SocketDataProvider* provider = new StaticSocketDataProvider();
  data_providers_.push_back(provider);

  scoped_ptr<MockUDPClientSocket> udp_socket(
      new MockUDPClientSocket(provider, net_log));
  // Link the provider back to the socket it serves.
  provider->set_socket(udp_socket.get());
  return scoped_ptr<DatagramClientSocket>(udp_socket.release());
}

// Frees every data provider created by CreateDatagramClientSocket() and
// empties the list.
TestClientSocketFactory::~TestClientSocketFactory() {
  STLDeleteContainerPointers(data_providers_.begin(), data_providers_.end());
  data_providers_.clear();
}

// Each Allocate() must reach the pool as exactly one ALLOCATE event for the
// requested server. Destroying the lease must produce exactly one matching
// FREE event.
TEST_F(DnsSessionTest, AllocateFree) {
  Initialize(2);
  EXPECT_TRUE(NoMoreEvents());

  // Lease a socket for server 0.
  scoped_ptr<DnsSession::SocketLease> first_lease(Allocate(0));
  EXPECT_TRUE(DidAllocate(0));
  EXPECT_TRUE(NoMoreEvents());

  // Lease a socket for server 1.
  scoped_ptr<DnsSession::SocketLease> second_lease(Allocate(1));
  EXPECT_TRUE(DidAllocate(1));
  EXPECT_TRUE(NoMoreEvents());

  // Release them in allocation order.
  first_lease.reset();
  EXPECT_TRUE(DidFree(0));
  EXPECT_TRUE(NoMoreEvents());

  second_lease.reset();
  EXPECT_TRUE(DidFree(1));
  EXPECT_TRUE(NoMoreEvents());
}

// Expect the calculated timeout for the default DnsConfig timeout to be
// within 10ms, in either direction, of the configured value.
TEST_F(DnsSessionTest, HistogramTimeoutNormal) {
  Initialize(2);
  base::TimeDelta timeoutDelta = session_->NextTimeout(0, 0) - config_.timeout;
  // Check both bounds: an upper bound alone would also accept a timeout far
  // below the configured one.
  EXPECT_GT(timeoutDelta.InMilliseconds(), -10);
  EXPECT_LT(timeoutDelta.InMilliseconds(), 10);
}

// Expect the calculated timeout for a short (15ms) configured timeout to be
// within 10ms, in either direction, of the configured value.
TEST_F(DnsSessionTest, HistogramTimeoutShort) {
  config_.timeout = base::TimeDelta::FromMilliseconds(15);
  Initialize(2);
  base::TimeDelta timeoutDelta = session_->NextTimeout(0, 0) - config_.timeout;
  // Check both bounds: an upper bound alone would also accept a timeout far
  // below the configured one.
  EXPECT_GT(timeoutDelta.InMilliseconds(), -10);
  EXPECT_LT(timeoutDelta.InMilliseconds(), 10);
}

// Expect a long (15s) configured timeout to be returned by NextTimeout()
// unchanged, compared at millisecond precision.
TEST_F(DnsSessionTest, HistogramTimeoutLong) {
  config_.timeout = base::TimeDelta::FromSeconds(15);
  Initialize(2);
  base::TimeDelta expected = config_.timeout;
  base::TimeDelta actual = session_->NextTimeout(0, 0);
  EXPECT_EQ(expected.InMilliseconds(), actual.InMilliseconds());
}

}  // namespace

} // namespace net