// Copyright (c) 2012 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/dns/mock_host_resolver.h" #include <string> #include <vector> #include "base/bind.h" #include "base/memory/ref_counted.h" #include "base/message_loop/message_loop.h" #include "base/stl_util.h" #include "base/strings/string_split.h" #include "base/strings/string_util.h" #include "base/threading/platform_thread.h" #include "net/base/net_errors.h" #include "net/base/net_util.h" #include "net/base/test_completion_callback.h" #include "net/dns/host_cache.h" #if defined(OS_WIN) #include "net/base/winsock_init.h" #endif namespace net { namespace { // Cache size for the MockCachingHostResolver. const unsigned kMaxCacheEntries = 100; // TTL for the successful resolutions. Failures are not cached. const unsigned kCacheEntryTTLSeconds = 60; } // namespace int ParseAddressList(const std::string& host_list, const std::string& canonical_name, AddressList* addrlist) { *addrlist = AddressList(); std::vector<std::string> addresses; base::SplitString(host_list, ',', &addresses); addrlist->set_canonical_name(canonical_name); for (size_t index = 0; index < addresses.size(); ++index) { IPAddressNumber ip_number; if (!ParseIPLiteralToNumber(addresses[index], &ip_number)) { LOG(WARNING) << "Not a supported IP literal: " << addresses[index]; return ERR_UNEXPECTED; } addrlist->push_back(IPEndPoint(ip_number, -1)); } return OK; } struct MockHostResolverBase::Request { Request(const RequestInfo& req_info, AddressList* addr, const CompletionCallback& cb) : info(req_info), addresses(addr), callback(cb) {} RequestInfo info; AddressList* addresses; CompletionCallback callback; }; MockHostResolverBase::~MockHostResolverBase() { STLDeleteValues(&requests_); } int MockHostResolverBase::Resolve(const RequestInfo& info, RequestPriority priority, AddressList* addresses, const CompletionCallback& callback, RequestHandle* handle, const BoundNetLog& net_log) { DCHECK(CalledOnValidThread()); last_request_priority_ = priority; num_resolve_++; size_t id = next_request_id_++; int rv = ResolveFromIPLiteralOrCache(info, addresses); if (rv != ERR_DNS_CACHE_MISS) { return rv; } if (synchronous_mode_) { return ResolveProc(id, info, addresses); } // Store the request for asynchronous resolution Request* req = new Request(info, addresses, callback); requests_[id] = req; if (handle) *handle = reinterpret_cast<RequestHandle>(id); if (!ondemand_mode_) { base::MessageLoop::current()->PostTask( FROM_HERE, base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id)); } return ERR_IO_PENDING; } int MockHostResolverBase::ResolveFromCache(const RequestInfo& info, AddressList* addresses, const BoundNetLog& net_log) { num_resolve_from_cache_++; DCHECK(CalledOnValidThread()); next_request_id_++; int rv = ResolveFromIPLiteralOrCache(info, addresses); return rv; } void MockHostResolverBase::CancelRequest(RequestHandle handle) { DCHECK(CalledOnValidThread()); size_t id = reinterpret_cast<size_t>(handle); RequestMap::iterator it = requests_.find(id); if (it != requests_.end()) { scoped_ptr<Request> req(it->second); requests_.erase(it); } else { NOTREACHED() << "CancelRequest must NOT be called after request is " "complete or canceled."; } } HostCache* MockHostResolverBase::GetHostCache() { return cache_.get(); } void MockHostResolverBase::ResolveAllPending() { DCHECK(CalledOnValidThread()); DCHECK(ondemand_mode_); for (RequestMap::iterator i = requests_.begin(); i != requests_.end(); ++i) { base::MessageLoop::current()->PostTask( FROM_HERE, base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), i->first)); } } // start id from 1 to distinguish from NULL RequestHandle MockHostResolverBase::MockHostResolverBase(bool use_caching) : last_request_priority_(DEFAULT_PRIORITY), synchronous_mode_(false), ondemand_mode_(false), next_request_id_(1), num_resolve_(0), num_resolve_from_cache_(0) { rules_ = CreateCatchAllHostResolverProc(); if (use_caching) { cache_.reset(new HostCache(kMaxCacheEntries)); } } int MockHostResolverBase::ResolveFromIPLiteralOrCache(const RequestInfo& info, AddressList* addresses) { IPAddressNumber ip; if (ParseIPLiteralToNumber(info.hostname(), &ip)) { *addresses = AddressList::CreateFromIPAddress(ip, info.port()); if (info.host_resolver_flags() & HOST_RESOLVER_CANONNAME) addresses->SetDefaultCanonicalName(); return OK; } int rv = ERR_DNS_CACHE_MISS; if (cache_.get() && info.allow_cached_response()) { HostCache::Key key(info.hostname(), info.address_family(), info.host_resolver_flags()); const HostCache::Entry* entry = cache_->Lookup(key, base::TimeTicks::Now()); if (entry) { rv = entry->error; if (rv == OK) *addresses = AddressList::CopyWithPort(entry->addrlist, info.port()); } } return rv; } int MockHostResolverBase::ResolveProc(size_t id, const RequestInfo& info, AddressList* addresses) { AddressList addr; int rv = rules_->Resolve(info.hostname(), info.address_family(), info.host_resolver_flags(), &addr, NULL); if (cache_.get()) { HostCache::Key key(info.hostname(), info.address_family(), info.host_resolver_flags()); // Storing a failure with TTL 0 so that it overwrites previous value. base::TimeDelta ttl; if (rv == OK) ttl = base::TimeDelta::FromSeconds(kCacheEntryTTLSeconds); cache_->Set(key, HostCache::Entry(rv, addr), base::TimeTicks::Now(), ttl); } if (rv == OK) *addresses = AddressList::CopyWithPort(addr, info.port()); return rv; } void MockHostResolverBase::ResolveNow(size_t id) { RequestMap::iterator it = requests_.find(id); if (it == requests_.end()) return; // was canceled scoped_ptr<Request> req(it->second); requests_.erase(it); int rv = ResolveProc(id, req->info, req->addresses); if (!req->callback.is_null()) req->callback.Run(rv); } //----------------------------------------------------------------------------- struct RuleBasedHostResolverProc::Rule { enum ResolverType { kResolverTypeFail, kResolverTypeSystem, kResolverTypeIPLiteral, }; ResolverType resolver_type; std::string host_pattern; AddressFamily address_family; HostResolverFlags host_resolver_flags; std::string replacement; std::string canonical_name; int latency_ms; // In milliseconds. Rule(ResolverType resolver_type, const std::string& host_pattern, AddressFamily address_family, HostResolverFlags host_resolver_flags, const std::string& replacement, const std::string& canonical_name, int latency_ms) : resolver_type(resolver_type), host_pattern(host_pattern), address_family(address_family), host_resolver_flags(host_resolver_flags), replacement(replacement), canonical_name(canonical_name), latency_ms(latency_ms) {} }; RuleBasedHostResolverProc::RuleBasedHostResolverProc(HostResolverProc* previous) : HostResolverProc(previous) { } void RuleBasedHostResolverProc::AddRule(const std::string& host_pattern, const std::string& replacement) { AddRuleForAddressFamily(host_pattern, ADDRESS_FAMILY_UNSPECIFIED, replacement); } void RuleBasedHostResolverProc::AddRuleForAddressFamily( const std::string& host_pattern, AddressFamily address_family, const std::string& replacement) { DCHECK(!replacement.empty()); HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; Rule rule(Rule::kResolverTypeSystem, host_pattern, address_family, flags, replacement, std::string(), 0); rules_.push_back(rule); } void RuleBasedHostResolverProc::AddIPLiteralRule( const std::string& host_pattern, const std::string& ip_literal, const std::string& canonical_name) { // Literals are always resolved to themselves by HostResolverImpl, // consequently we do not support remapping them. IPAddressNumber ip_number; DCHECK(!ParseIPLiteralToNumber(host_pattern, &ip_number)); HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; if (!canonical_name.empty()) flags |= HOST_RESOLVER_CANONNAME; Rule rule(Rule::kResolverTypeIPLiteral, host_pattern, ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal, canonical_name, 0); rules_.push_back(rule); } void RuleBasedHostResolverProc::AddRuleWithLatency( const std::string& host_pattern, const std::string& replacement, int latency_ms) { DCHECK(!replacement.empty()); HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED, flags, replacement, std::string(), latency_ms); rules_.push_back(rule); } void RuleBasedHostResolverProc::AllowDirectLookup( const std::string& host_pattern) { HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED, flags, std::string(), std::string(), 0); rules_.push_back(rule); } void RuleBasedHostResolverProc::AddSimulatedFailure( const std::string& host_pattern) { HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; Rule rule(Rule::kResolverTypeFail, host_pattern, ADDRESS_FAMILY_UNSPECIFIED, flags, std::string(), std::string(), 0); rules_.push_back(rule); } void RuleBasedHostResolverProc::ClearRules() { rules_.clear(); } int RuleBasedHostResolverProc::Resolve(const std::string& host, AddressFamily address_family, HostResolverFlags host_resolver_flags, AddressList* addrlist, int* os_error) { RuleList::iterator r; for (r = rules_.begin(); r != rules_.end(); ++r) { bool matches_address_family = r->address_family == ADDRESS_FAMILY_UNSPECIFIED || r->address_family == address_family; // Ignore HOST_RESOLVER_SYSTEM_ONLY, since it should have no impact on // whether a rule matches. HostResolverFlags flags = host_resolver_flags & ~HOST_RESOLVER_SYSTEM_ONLY; // Flags match if all of the bitflags in host_resolver_flags are enabled // in the rule's host_resolver_flags. However, the rule may have additional // flags specified, in which case the flags should still be considered a // match. bool matches_flags = (r->host_resolver_flags & flags) == flags; if (matches_flags && matches_address_family && MatchPattern(host, r->host_pattern)) { if (r->latency_ms != 0) { base::PlatformThread::Sleep( base::TimeDelta::FromMilliseconds(r->latency_ms)); } // Remap to a new host. const std::string& effective_host = r->replacement.empty() ? host : r->replacement; // Apply the resolving function to the remapped hostname. switch (r->resolver_type) { case Rule::kResolverTypeFail: return ERR_NAME_NOT_RESOLVED; case Rule::kResolverTypeSystem: #if defined(OS_WIN) net::EnsureWinsockInit(); #endif return SystemHostResolverCall(effective_host, address_family, host_resolver_flags, addrlist, os_error); case Rule::kResolverTypeIPLiteral: return ParseAddressList(effective_host, r->canonical_name, addrlist); default: NOTREACHED(); return ERR_UNEXPECTED; } } } return ResolveUsingPrevious(host, address_family, host_resolver_flags, addrlist, os_error); } RuleBasedHostResolverProc::~RuleBasedHostResolverProc() { } RuleBasedHostResolverProc* CreateCatchAllHostResolverProc() { RuleBasedHostResolverProc* catchall = new RuleBasedHostResolverProc(NULL); catchall->AddIPLiteralRule("*", "127.0.0.1", "localhost"); // Next add a rules-based layer the use controls. return new RuleBasedHostResolverProc(catchall); } //----------------------------------------------------------------------------- int HangingHostResolver::Resolve(const RequestInfo& info, RequestPriority priority, AddressList* addresses, const CompletionCallback& callback, RequestHandle* out_req, const BoundNetLog& net_log) { return ERR_IO_PENDING; } int HangingHostResolver::ResolveFromCache(const RequestInfo& info, AddressList* addresses, const BoundNetLog& net_log) { return ERR_DNS_CACHE_MISS; } //----------------------------------------------------------------------------- ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc() {} ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc( HostResolverProc* proc) { Init(proc); } ScopedDefaultHostResolverProc::~ScopedDefaultHostResolverProc() { HostResolverProc* old_proc = HostResolverProc::SetDefault(previous_proc_.get()); // The lifetimes of multiple instances must be nested. CHECK_EQ(old_proc, current_proc_); } void ScopedDefaultHostResolverProc::Init(HostResolverProc* proc) { current_proc_ = proc; previous_proc_ = HostResolverProc::SetDefault(current_proc_.get()); current_proc_->SetLastProc(previous_proc_.get()); } } // namespace net