// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/dns/dns_test_util.h"
#include <string>
#include "base/bind.h"
#include "base/memory/weak_ptr.h"
#include "base/message_loop/message_loop.h"
#include "base/sys_byteorder.h"
#include "net/base/big_endian.h"
#include "net/base/dns_util.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/dns/address_sorter.h"
#include "net/dns/dns_query.h"
#include "net/dns/dns_response.h"
#include "net/dns/dns_transaction.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace net {
namespace {
// AddressSorter stand-in for tests: it never reorders anything and reports
// success synchronously, handing the caller's list straight back.
class MockAddressSorter : public AddressSorter {
 public:
  virtual ~MockAddressSorter() {}

  virtual void Sort(const AddressList& list,
                    const CallbackType& callback) const OVERRIDE {
    // Intentionally a no-op sort: echo the input list with a success flag.
    callback.Run(true, list);
  }
};
// A DnsTransaction which uses MockDnsClientRuleList to determine the response.
// The first rule whose |qtype| equals the requested type and whose |prefix| is
// a prefix of the requested hostname selects the outcome (and whether the
// transaction is delayed). If no rule matches, the transaction fails with
// ERR_NAME_NOT_RESOLVED.
class MockTransaction : public DnsTransaction,
                        public base::SupportsWeakPtr<MockTransaction> {
 public:
  MockTransaction(const MockDnsClientRuleList& rules,
                  const std::string& hostname,
                  uint16 qtype,
                  const DnsTransactionFactory::CallbackType& callback)
      : result_(MockDnsClientRule::FAIL),
        hostname_(hostname),
        qtype_(qtype),
        callback_(callback),
        started_(false),
        delayed_(false) {
    // Find the relevant rule which matches |qtype| and prefix of |hostname|.
    for (size_t i = 0; i < rules.size(); ++i) {
      const std::string& prefix = rules[i].prefix;
      if ((rules[i].qtype == qtype) &&
          (hostname.size() >= prefix.size()) &&
          (hostname.compare(0, prefix.size(), prefix) == 0)) {
        result_ = rules[i].result;
        delayed_ = rules[i].delay;
        break;
      }
    }
  }

  virtual const std::string& GetHostname() const OVERRIDE {
    return hostname_;
  }

  virtual uint16 GetType() const OVERRIDE {
    return qtype_;
  }

  // Posts completion to the current message loop, unless the matching rule
  // asked for a delay; delayed transactions wait for
  // FinishDelayedTransaction().
  virtual void Start() OVERRIDE {
    EXPECT_FALSE(started_);
    started_ = true;
    if (delayed_)
      return;
    // Using WeakPtr to cleanly cancel when transaction is destroyed.
    base::MessageLoop::current()->PostTask(
        FROM_HERE, base::Bind(&MockTransaction::Finish, AsWeakPtr()));
  }

  // Releases a delayed transaction. If Start() has already been called, the
  // callback runs synchronously from here. Otherwise only the delay is
  // cleared, so the transaction completes through the normal posted task once
  // Start() is called. This keeps the callback from running before Start()
  // and from running twice.
  void FinishDelayedTransaction() {
    EXPECT_TRUE(delayed_);
    delayed_ = false;
    if (started_)
      Finish();
  }

  bool delayed() const { return delayed_; }

 private:
  // Runs |callback_| with the outcome selected by |result_|:
  //  - EMPTY: a well-formed response with no answer records.
  //  - OK:    the same plus one answer carrying a loopback address
  //           (127.0.0.1 for A queries, ::1 otherwise).
  //  - FAIL:    ERR_NAME_NOT_RESOLVED and no response.
  //  - TIMEOUT: ERR_DNS_TIMED_OUT and no response.
  void Finish() {
    switch (result_) {
      case MockDnsClientRule::EMPTY:
      case MockDnsClientRule::OK: {
        std::string qname;
        // An unencodable hostname would produce a malformed query; report it
        // here instead of silently continuing with an empty name.
        EXPECT_TRUE(DNSDomainFromDot(hostname_, &qname));
        DnsQuery query(0, qname, qtype_);
        DnsResponse response;
        // The response starts as a byte copy of the query (header plus
        // question section) with the response flag set.
        char* buffer = response.io_buffer()->data();
        int nbytes = query.io_buffer()->size();
        memcpy(buffer, query.io_buffer()->data(), nbytes);
        dns_protocol::Header* header =
            reinterpret_cast<dns_protocol::Header*>(buffer);
        header->flags |= dns_protocol::kFlagResponse;

        if (MockDnsClientRule::OK == result_) {
          // Compression pointer to the question name, which immediately
          // follows the fixed-size header.
          const uint16 kPointerToQueryName =
              static_cast<uint16>(0xc000 | sizeof(*header));
          const uint32 kTTL = 86400;  // One day.
          // Size of RDATA which is a IPv4 or IPv6 address.
          const size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ?
              net::kIPv4AddressSize : net::kIPv6AddressSize;
          // 12 is the sum of sizes of the compressed name reference, TYPE,
          // CLASS, TTL and RDLENGTH.
          const size_t answer_size = 12 + rdata_size;

          // Write answer with loopback IP address.
          header->ancount = base::HostToNet16(1);
          BigEndianWriter writer(buffer + nbytes, answer_size);
          writer.WriteU16(kPointerToQueryName);
          writer.WriteU16(qtype_);
          writer.WriteU16(net::dns_protocol::kClassIN);
          writer.WriteU32(kTTL);
          writer.WriteU16(static_cast<uint16>(rdata_size));
          if (qtype_ == net::dns_protocol::kTypeA) {
            static const char kIPv4Loopback[] = { 0x7f, 0, 0, 1 };
            writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback));
          } else {
            static const char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0,
                                                  0, 0, 0, 0, 0, 0, 0, 1 };
            writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback));
          }
          nbytes += answer_size;
        }
        EXPECT_TRUE(response.InitParse(nbytes, query));
        callback_.Run(this, OK, &response);
      } break;
      case MockDnsClientRule::FAIL:
        callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL);
        break;
      case MockDnsClientRule::TIMEOUT:
        callback_.Run(this, ERR_DNS_TIMED_OUT, NULL);
        break;
      default:
        NOTREACHED();
        break;
    }
  }

  MockDnsClientRule::Result result_;
  const std::string hostname_;
  const uint16 qtype_;
  DnsTransactionFactory::CallbackType callback_;
  bool started_;
  bool delayed_;
};
} // namespace
// A DnsTransactionFactory which creates MockTransaction.
class MockTransactionFactory : public DnsTransactionFactory {
public:
explicit MockTransactionFactory(const MockDnsClientRuleList& rules)
: rules_(rules) {}
virtual ~MockTransactionFactory() {}
virtual scoped_ptr<DnsTransaction> CreateTransaction(
const std::string& hostname,
uint16 qtype,
const DnsTransactionFactory::CallbackType& callback,
const BoundNetLog&) OVERRIDE {
MockTransaction* transaction =
new MockTransaction(rules_, hostname, qtype, callback);
if (transaction->delayed())
delayed_transactions_.push_back(transaction->AsWeakPtr());
return scoped_ptr<DnsTransaction>(transaction);
}
void CompleteDelayedTransactions() {
DelayedTransactionList old_delayed_transactions;
old_delayed_transactions.swap(delayed_transactions_);
for (DelayedTransactionList::iterator it = old_delayed_transactions.begin();
it != old_delayed_transactions.end(); ++it) {
if (it->get())
(*it)->FinishDelayedTransaction();
}
}
private:
typedef std::vector<base::WeakPtr<MockTransaction> > DelayedTransactionList;
MockDnsClientRuleList rules_;
DelayedTransactionList delayed_transactions_;
};
// Sets up a client whose transactions are answered according to |rules| and
// whose address sorter is the pass-through MockAddressSorter.
MockDnsClient::MockDnsClient(const DnsConfig& config,
                             const MockDnsClientRuleList& rules)
    : config_(config),
      factory_(new MockTransactionFactory(rules)),
      address_sorter_(new MockAddressSorter()) {}
// Nothing to do explicitly; owned members clean up via their own destructors.
MockDnsClient::~MockDnsClient() {
}
// Replaces the stored configuration. While it is invalid, GetConfig() and
// GetTransactionFactory() both return NULL.
void MockDnsClient::SetConfig(const DnsConfig& config) {
  config_ = config;
}
// Returns the stored configuration, or NULL if it is not valid.
const DnsConfig* MockDnsClient::GetConfig() const {
  if (!config_.IsValid())
    return NULL;
  return &config_;
}
// Returns the mock transaction factory, or NULL while the configuration is
// invalid (mirroring GetConfig()).
DnsTransactionFactory* MockDnsClient::GetTransactionFactory() {
  if (!config_.IsValid())
    return NULL;
  return factory_.get();
}
// Returns the pass-through sorter owned by this client; it is available
// regardless of whether the configuration is valid.
AddressSorter* MockDnsClient::GetAddressSorter() {
  AddressSorter* sorter = address_sorter_.get();
  return sorter;
}
// Releases the transactions held back by delayed rules; forwards directly to
// the mock factory that created them.
void MockDnsClient::CompleteDelayedTransactions() {
  factory_->CompleteDelayedTransactions();
}
} // namespace net