// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/dns/dns_socket_pool.h"

#include "base/bind.h"
#include "base/logging.h"
#include "base/rand_util.h"
#include "base/stl_util.h"
#include "net/base/address_list.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/base/rand_callback.h"
#include "net/socket/client_socket_factory.h"
#include "net/socket/stream_socket.h"
#include "net/udp/datagram_client_socket.h"
namespace net {
namespace {

// Pool sizing / binding policy.
//
// On Windows, asking for specific (random) ports triggers firewall prompts,
// so sockets are bound with the default bind type instead. To still have
// some choice of source port, a pile of them is kept: each server's pool is
// filled with kInitialPoolSize sockets at initialization, and every
// allocation first tops the pool up to kAllocateMinSize sockets before one
// is picked at random.
//
// Everywhere else, each socket is bound to a fresh random port, so nothing
// is created up front and exactly one socket is created per allocation.
//
// In both cases, freed sockets are not retained.
#if defined(OS_WIN)
const DatagramSocket::BindType kBindType = DatagramSocket::DEFAULT_BIND;
#else
const DatagramSocket::BindType kBindType = DatagramSocket::RANDOM_BIND;
#endif

#if defined(OS_WIN)
const unsigned kInitialPoolSize = 256;
const unsigned kAllocateMinSize = 256;
#else
const unsigned kInitialPoolSize = 0;
const unsigned kAllocateMinSize = 1;
#endif

}  // namespace
// Remembers |socket_factory| for later socket creation. The nameserver list
// and NetLog stay NULL (and |initialized_| false) until InitializeInternal()
// is called.
DnsSocketPool::DnsSocketPool(ClientSocketFactory* socket_factory)
    : socket_factory_(socket_factory),
      net_log_(NULL),
      nameservers_(NULL),
      initialized_(false) {}
// Records the nameserver list and NetLog used by the socket-creation helpers.
// |nameservers| must be non-NULL. Only the pointer is stored, so it
// presumably has to outlive this pool (NOTE(review): confirm with callers).
// Must be called at most once.
void DnsSocketPool::InitializeInternal(
    const std::vector<IPEndPoint>* nameservers,
    NetLog* net_log) {
  DCHECK(nameservers);
  DCHECK(!initialized_);

  nameservers_ = nameservers;
  net_log_ = net_log;
  initialized_ = true;
}
// Asks the socket factory for a transport (stream) socket addressed to
// nameserver |server_index|, logging to the pool's NetLog under |source|.
// Requires that the pool has been initialized.
scoped_ptr<StreamSocket> DnsSocketPool::CreateTCPSocket(
    unsigned server_index,
    const NetLog::Source& source) {
  // Before initialization |nameservers_| is NULL; fail with a clear message
  // in debug builds instead of a null dereference on the next line.
  DCHECK(initialized_);
  DCHECK_LT(server_index, nameservers_->size());
  return scoped_ptr<StreamSocket>(
      socket_factory_->CreateTransportClientSocket(
          AddressList((*nameservers_)[server_index]), net_log_, source));
}
// Creates a datagram socket with bind type kBindType (passing base::RandInt
// as the factory's random-number callback) and connects it to nameserver
// |server_index|.
//
// Returns an empty scoped_ptr if the factory returns no socket or if
// Connect() does not return OK. Both failures are logged and otherwise
// swallowed, so callers must check the result.
scoped_ptr<DatagramClientSocket> DnsSocketPool::CreateConnectedSocket(
    unsigned server_index) {
  // Before initialization |nameservers_| is NULL; fail with a clear message
  // in debug builds instead of a null dereference in the DCHECK_LT below.
  DCHECK(initialized_);
  DCHECK_LT(server_index, nameservers_->size());

  NetLog::Source no_source;
  scoped_ptr<DatagramClientSocket> socket(
      socket_factory_->CreateDatagramClientSocket(
          kBindType, base::Bind(&base::RandInt), net_log_, no_source));
  if (!socket.get()) {
    LOG(WARNING) << "Failed to create socket.";
    return socket.Pass();
  }

  int rv = socket->Connect((*nameservers_)[server_index]);
  if (rv != OK) {
    VLOG(1) << "Failed to connect socket: " << rv;
    socket.reset();
  }
  return socket.Pass();
}
// A DnsSocketPool that caches nothing: every AllocateSocket() call creates
// and connects a brand-new socket, and FreeSocket() just lets the socket it
// is handed be destroyed.
class NullDnsSocketPool : public DnsSocketPool {
 public:
  // Single-argument constructor; explicit to prevent implicit conversion
  // from ClientSocketFactory*.
  explicit NullDnsSocketPool(ClientSocketFactory* factory)
      : DnsSocketPool(factory) {
  }

  virtual void Initialize(
      const std::vector<IPEndPoint>* nameservers,
      NetLog* net_log) OVERRIDE {
    InitializeInternal(nameservers, net_log);
  }

  // Returns a freshly created, connected socket, or an empty scoped_ptr if
  // creation or connection failed.
  virtual scoped_ptr<DatagramClientSocket> AllocateSocket(
      unsigned server_index) OVERRIDE {
    return CreateConnectedSocket(server_index);
  }

  // Intentionally empty: nothing is retained, so |socket| is destroyed when
  // the by-value scoped_ptr goes out of scope.
  virtual void FreeSocket(
      unsigned server_index,
      scoped_ptr<DatagramClientSocket> socket) OVERRIDE {
  }

 private:
  DISALLOW_COPY_AND_ASSIGN(NullDnsSocketPool);
};
// static
// Factory for a pool that never keeps sockets around (NullDnsSocketPool).
scoped_ptr<DnsSocketPool> DnsSocketPool::CreateNull(
    ClientSocketFactory* factory) {
  scoped_ptr<DnsSocketPool> pool(new NullDnsSocketPool(factory));
  return pool.Pass();
}
// A DnsSocketPool that keeps, for each nameserver, a vector of connected
// sockets and hands out a randomly chosen one on each allocation, sized
// according to kInitialPoolSize / kAllocateMinSize.
class DefaultDnsSocketPool : public DnsSocketPool {
 public:
  // Single-argument constructor; explicit to prevent implicit conversion
  // from ClientSocketFactory*.
  explicit DefaultDnsSocketPool(ClientSocketFactory* factory)
      : DnsSocketPool(factory) {
  }

  // Deletes every socket still held in |pools_|.
  virtual ~DefaultDnsSocketPool();

  virtual void Initialize(
      const std::vector<IPEndPoint>* nameservers,
      NetLog* net_log) OVERRIDE;

  virtual scoped_ptr<DatagramClientSocket> AllocateSocket(
      unsigned server_index) OVERRIDE;

  virtual void FreeSocket(
      unsigned server_index,
      scoped_ptr<DatagramClientSocket> socket) OVERRIDE;

 private:
  // Adds connected sockets to pool |server_index| until it holds |size|
  // sockets or a socket cannot be created.
  void FillPool(unsigned server_index, unsigned size);

  typedef std::vector<DatagramClientSocket*> SocketVector;

  // One SocketVector per nameserver, indexed by server index. The raw
  // pointers are owned by this object and deleted in the destructor.
  std::vector<SocketVector> pools_;

  DISALLOW_COPY_AND_ASSIGN(DefaultDnsSocketPool);
};
// static
// Factory for the standard pool implementation (DefaultDnsSocketPool).
scoped_ptr<DnsSocketPool> DnsSocketPool::CreateDefault(
    ClientSocketFactory* factory) {
  scoped_ptr<DnsSocketPool> pool(new DefaultDnsSocketPool(factory));
  return pool.Pass();
}
// Stores the nameserver list, then creates one pool per nameserver and fills
// each with kInitialPoolSize connected sockets (as many as can be created).
void DefaultDnsSocketPool::Initialize(
    const std::vector<IPEndPoint>* nameservers,
    NetLog* net_log) {
  InitializeInternal(nameservers, net_log);

  DCHECK(pools_.empty());
  pools_.resize(nameservers->size());
  for (unsigned index = 0; index < pools_.size(); ++index)
    FillPool(index, kInitialPoolSize);
}
// |pools_| holds owning raw pointers; delete every socket still cached.
DefaultDnsSocketPool::~DefaultDnsSocketPool() {
  for (std::vector<SocketVector>::iterator it = pools_.begin();
       it != pools_.end(); ++it) {
    STLDeleteElements(&*it);
  }
}
// Tops pool |server_index| up to kAllocateMinSize sockets, then removes and
// returns a uniformly random one. Returns an empty scoped_ptr if the pool is
// still empty after refilling. Logs a warning when fewer than
// kAllocateMinSize candidates were available.
scoped_ptr<DatagramClientSocket> DefaultDnsSocketPool::AllocateSocket(
    unsigned server_index) {
  DCHECK_LT(server_index, pools_.size());
  SocketVector& candidates = pools_[server_index];

  FillPool(server_index, kAllocateMinSize);

  if (candidates.empty()) {
    LOG(WARNING) << "No DNS sockets available in pool " << server_index << "!";
    return scoped_ptr<DatagramClientSocket>();
  }

  if (candidates.size() < kAllocateMinSize) {
    LOG(WARNING) << "Low DNS port entropy: wanted " << kAllocateMinSize
                 << " sockets to choose from, but only have "
                 << candidates.size() << " in pool " << server_index << ".";
  }

  // Remove the chosen entry in O(1) by overwriting it with the last element
  // and dropping the tail; ownership passes to the returned scoped_ptr.
  const unsigned picked = base::RandInt(0, candidates.size() - 1);
  DatagramClientSocket* chosen = candidates[picked];
  candidates[picked] = candidates.back();
  candidates.pop_back();
  return scoped_ptr<DatagramClientSocket>(chosen);
}
void DefaultDnsSocketPool::FreeSocket(
unsigned server_index,
scoped_ptr<DatagramClientSocket> socket) {
DCHECK_LT(server_index, pools_.size());
}
void DefaultDnsSocketPool::FillPool(unsigned server_index, unsigned size) {
SocketVector& pool = pools_[server_index];
for (unsigned pool_index = pool.size(); pool_index < size; ++pool_index) {
DatagramClientSocket* socket =
CreateConnectedSocket(server_index).release();
if (!socket)
break;
pool.push_back(socket);
}
}
} // namespace net