// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/socket/socket_test_util.h"
#include <algorithm>
#include <vector>
#include "base/basictypes.h"
#include "base/compiler_specific.h"
#include "base/message_loop.h"
#include "base/time.h"
#include "net/base/address_family.h"
#include "net/base/auth.h"
#include "net/base/host_resolver_proc.h"
#include "net/base/ssl_cert_request_info.h"
#include "net/base/ssl_info.h"
#include "net/http/http_network_session.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/socket/client_socket_pool_histograms.h"
#include "net/socket/socket.h"
#include "net/socket/ssl_host_info.h"
#include "testing/gtest/include/gtest/gtest.h"
#define NET_TRACE(level, s) DLOG(level) << s << __FUNCTION__ << "() "
namespace net {
namespace {
inline char AsciifyHigh(char x) {
char nybble = static_cast<char>((x >> 4) & 0x0F);
return nybble + ((nybble < 0x0A) ? '0' : 'A' - 10);
}
inline char AsciifyLow(char x) {
char nybble = static_cast<char>((x >> 0) & 0x0F);
return nybble + ((nybble < 0x0A) ? '0' : 'A' - 10);
}
inline char Asciify(char x) {
if ((x < 0) || !isprint(x))
return '.';
return x;
}
void DumpData(const char* data, int data_len) {
if (logging::LOG_INFO < logging::GetMinLogLevel())
return;
DVLOG(1) << "Length: " << data_len;
const char* pfx = "Data: ";
if (!data || (data_len <= 0)) {
DVLOG(1) << pfx << "<None>";
} else {
int i;
for (i = 0; i <= (data_len - 4); i += 4) {
DVLOG(1) << pfx
<< AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
<< AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
<< AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
<< AsciifyHigh(data[i + 3]) << AsciifyLow(data[i + 3])
<< " '"
<< Asciify(data[i + 0])
<< Asciify(data[i + 1])
<< Asciify(data[i + 2])
<< Asciify(data[i + 3])
<< "'";
pfx = " ";
}
// Take care of any 'trailing' bytes, if data_len was not a multiple of 4.
switch (data_len - i) {
case 3:
DVLOG(1) << pfx
<< AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
<< AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
<< AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
<< " '"
<< Asciify(data[i + 0])
<< Asciify(data[i + 1])
<< Asciify(data[i + 2])
<< " '";
break;
case 2:
DVLOG(1) << pfx
<< AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
<< AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
<< " '"
<< Asciify(data[i + 0])
<< Asciify(data[i + 1])
<< " '";
break;
case 1:
DVLOG(1) << pfx
<< AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
<< " '"
<< Asciify(data[i + 0])
<< " '";
break;
}
}
}
void DumpMockRead(const MockRead& r) {
if (logging::LOG_INFO < logging::GetMinLogLevel())
return;
DVLOG(1) << "Async: " << r.async
<< "\nResult: " << r.result;
DumpData(r.data, r.data_len);
const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : "";
DVLOG(1) << "Stage: " << (r.sequence_number & ~MockRead::STOPLOOP) << stop
<< "\nTime: " << r.time_stamp.ToInternalValue();
}
} // namespace
StaticSocketDataProvider::StaticSocketDataProvider()
: reads_(NULL),
read_index_(0),
read_count_(0),
writes_(NULL),
write_index_(0),
write_count_(0) {
}
StaticSocketDataProvider::StaticSocketDataProvider(MockRead* reads,
size_t reads_count,
MockWrite* writes,
size_t writes_count)
: reads_(reads),
read_index_(0),
read_count_(reads_count),
writes_(writes),
write_index_(0),
write_count_(writes_count) {
}
StaticSocketDataProvider::~StaticSocketDataProvider() {}
const MockRead& StaticSocketDataProvider::PeekRead() const {
DCHECK(!at_read_eof());
return reads_[read_index_];
}
const MockWrite& StaticSocketDataProvider::PeekWrite() const {
DCHECK(!at_write_eof());
return writes_[write_index_];
}
const MockRead& StaticSocketDataProvider::PeekRead(size_t index) const {
DCHECK_LT(index, read_count_);
return reads_[index];
}
const MockWrite& StaticSocketDataProvider::PeekWrite(size_t index) const {
DCHECK_LT(index, write_count_);
return writes_[index];
}
MockRead StaticSocketDataProvider::GetNextRead() {
DCHECK(!at_read_eof());
reads_[read_index_].time_stamp = base::Time::Now();
return reads_[read_index_++];
}
MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
if (!writes_) {
// Not using mock writes; succeed synchronously.
return MockWriteResult(false, data.length());
}
DCHECK(!at_write_eof());
// Check that what we are writing matches the expectation.
// Then give the mocked return value.
net::MockWrite* w = &writes_[write_index_++];
w->time_stamp = base::Time::Now();
int result = w->result;
if (w->data) {
// Note - we can simulate a partial write here. If the expected data
// is a match, but shorter than the write actually written, that is legal.
// Example:
// Application writes "foobarbaz" (9 bytes)
// Expected write was "foo" (3 bytes)
// This is a success, and we return 3 to the application.
std::string expected_data(w->data, w->data_len);
EXPECT_GE(data.length(), expected_data.length());
std::string actual_data(data.substr(0, w->data_len));
EXPECT_EQ(expected_data, actual_data);
if (expected_data != actual_data)
return MockWriteResult(false, net::ERR_UNEXPECTED);
if (result == net::OK)
result = w->data_len;
}
return MockWriteResult(w->async, result);
}
void StaticSocketDataProvider::Reset() {
read_index_ = 0;
write_index_ = 0;
}
DynamicSocketDataProvider::DynamicSocketDataProvider()
: short_read_limit_(0),
allow_unconsumed_reads_(false) {
}
DynamicSocketDataProvider::~DynamicSocketDataProvider() {}
MockRead DynamicSocketDataProvider::GetNextRead() {
if (reads_.empty())
return MockRead(false, ERR_UNEXPECTED);
MockRead result = reads_.front();
if (short_read_limit_ == 0 || result.data_len <= short_read_limit_) {
reads_.pop_front();
} else {
result.data_len = short_read_limit_;
reads_.front().data += result.data_len;
reads_.front().data_len -= result.data_len;
}
return result;
}
void DynamicSocketDataProvider::Reset() {
reads_.clear();
}
void DynamicSocketDataProvider::SimulateRead(const char* data,
const size_t length) {
if (!allow_unconsumed_reads_) {
EXPECT_TRUE(reads_.empty()) << "Unconsumed read: " << reads_.front().data;
}
reads_.push_back(MockRead(true, data, length));
}
SSLSocketDataProvider::SSLSocketDataProvider(bool async, int result)
: connect(async, result),
next_proto_status(SSLClientSocket::kNextProtoUnsupported),
was_npn_negotiated(false),
cert_request_info(NULL) {
}
SSLSocketDataProvider::~SSLSocketDataProvider() {
}
DelayedSocketData::DelayedSocketData(
int write_delay, MockRead* reads, size_t reads_count,
MockWrite* writes, size_t writes_count)
: StaticSocketDataProvider(reads, reads_count, writes, writes_count),
write_delay_(write_delay),
ALLOW_THIS_IN_INITIALIZER_LIST(factory_(this)) {
DCHECK_GE(write_delay_, 0);
}
DelayedSocketData::DelayedSocketData(
const MockConnect& connect, int write_delay, MockRead* reads,
size_t reads_count, MockWrite* writes, size_t writes_count)
: StaticSocketDataProvider(reads, reads_count, writes, writes_count),
write_delay_(write_delay),
ALLOW_THIS_IN_INITIALIZER_LIST(factory_(this)) {
DCHECK_GE(write_delay_, 0);
set_connect_data(connect);
}
DelayedSocketData::~DelayedSocketData() {
}
void DelayedSocketData::ForceNextRead() {
write_delay_ = 0;
CompleteRead();
}
MockRead DelayedSocketData::GetNextRead() {
if (write_delay_ > 0)
return MockRead(true, ERR_IO_PENDING);
return StaticSocketDataProvider::GetNextRead();
}
MockWriteResult DelayedSocketData::OnWrite(const std::string& data) {
MockWriteResult rv = StaticSocketDataProvider::OnWrite(data);
// Now that our write has completed, we can allow reads to continue.
if (!--write_delay_)
MessageLoop::current()->PostDelayedTask(FROM_HERE,
factory_.NewRunnableMethod(&DelayedSocketData::CompleteRead), 100);
return rv;
}
void DelayedSocketData::Reset() {
set_socket(NULL);
factory_.RevokeAll();
StaticSocketDataProvider::Reset();
}
void DelayedSocketData::CompleteRead() {
if (socket())
socket()->OnReadComplete(GetNextRead());
}
OrderedSocketData::OrderedSocketData(
MockRead* reads, size_t reads_count, MockWrite* writes, size_t writes_count)
: StaticSocketDataProvider(reads, reads_count, writes, writes_count),
sequence_number_(0), loop_stop_stage_(0), callback_(NULL),
blocked_(false), ALLOW_THIS_IN_INITIALIZER_LIST(factory_(this)) {
}
OrderedSocketData::OrderedSocketData(
const MockConnect& connect,
MockRead* reads, size_t reads_count,
MockWrite* writes, size_t writes_count)
: StaticSocketDataProvider(reads, reads_count, writes, writes_count),
sequence_number_(0), loop_stop_stage_(0), callback_(NULL),
blocked_(false), ALLOW_THIS_IN_INITIALIZER_LIST(factory_(this)) {
set_connect_data(connect);
}
void OrderedSocketData::EndLoop() {
// If we've already stopped the loop, don't do it again until we've advanced
// to the next sequence_number.
NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ << ": EndLoop()";
if (loop_stop_stage_ > 0) {
const MockRead& next_read = StaticSocketDataProvider::PeekRead();
if ((next_read.sequence_number & ~MockRead::STOPLOOP) >
loop_stop_stage_) {
NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
<< ": Clearing stop index";
loop_stop_stage_ = 0;
} else {
return;
}
}
// Record the sequence_number at which we stopped the loop.
NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
<< ": Posting Quit at read " << read_index();
loop_stop_stage_ = sequence_number_;
if (callback_)
callback_->RunWithParams(Tuple1<int>(ERR_IO_PENDING));
}
MockRead OrderedSocketData::GetNextRead() {
factory_.RevokeAll();
blocked_ = false;
const MockRead& next_read = StaticSocketDataProvider::PeekRead();
if (next_read.sequence_number & MockRead::STOPLOOP)
EndLoop();
if ((next_read.sequence_number & ~MockRead::STOPLOOP) <=
sequence_number_++) {
NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1
<< ": Read " << read_index();
DumpMockRead(next_read);
return StaticSocketDataProvider::GetNextRead();
}
NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1
<< ": I/O Pending";
MockRead result = MockRead(true, ERR_IO_PENDING);
DumpMockRead(result);
blocked_ = true;
return result;
}
MockWriteResult OrderedSocketData::OnWrite(const std::string& data) {
NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
<< ": Write " << write_index();
DumpMockRead(PeekWrite());
++sequence_number_;
if (blocked_) {
// TODO(willchan): This 100ms delay seems to work around some weirdness. We
// should probably fix the weirdness. One example is in SpdyStream,
// DoSendRequest() will return ERR_IO_PENDING, and there's a race. If the
// SYN_REPLY causes OnResponseReceived() to get called before
// SpdyStream::ReadResponseHeaders() is called, we hit a NOTREACHED().
MessageLoop::current()->PostDelayedTask(
FROM_HERE,
factory_.NewRunnableMethod(&OrderedSocketData::CompleteRead), 100);
}
return StaticSocketDataProvider::OnWrite(data);
}
void OrderedSocketData::Reset() {
NET_TRACE(INFO, " *** ") << "Stage "
<< sequence_number_ << ": Reset()";
sequence_number_ = 0;
loop_stop_stage_ = 0;
set_socket(NULL);
factory_.RevokeAll();
StaticSocketDataProvider::Reset();
}
void OrderedSocketData::CompleteRead() {
if (socket()) {
NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_;
socket()->OnReadComplete(GetNextRead());
}
}
OrderedSocketData::~OrderedSocketData() {}
DeterministicSocketData::DeterministicSocketData(MockRead* reads,
size_t reads_count, MockWrite* writes, size_t writes_count)
: StaticSocketDataProvider(reads, reads_count, writes, writes_count),
sequence_number_(0),
current_read_(),
current_write_(),
stopping_sequence_number_(0),
stopped_(false),
print_debug_(false) {}
DeterministicSocketData::~DeterministicSocketData() {}
void DeterministicSocketData::Run() {
SetStopped(false);
int counter = 0;
// Continue to consume data until all data has run out, or the stopped_ flag
// has been set. Consuming data requires two separate operations -- running
// the tasks in the message loop, and explicitly invoking the read/write
// callbacks (simulating network I/O). We check our conditions between each,
// since they can change in either.
while ((!at_write_eof() || !at_read_eof()) && !stopped()) {
if (counter % 2 == 0)
MessageLoop::current()->RunAllPending();
if (counter % 2 == 1) {
InvokeCallbacks();
}
counter++;
}
// We're done consuming new data, but it is possible there are still some
// pending callbacks which we expect to complete before returning.
while (socket_ && (socket_->write_pending() || socket_->read_pending()) &&
!stopped()) {
InvokeCallbacks();
MessageLoop::current()->RunAllPending();
}
SetStopped(false);
}
void DeterministicSocketData::RunFor(int steps) {
StopAfter(steps);
Run();
}
void DeterministicSocketData::SetStop(int seq) {
DCHECK_LT(sequence_number_, seq);
stopping_sequence_number_ = seq;
stopped_ = false;
}
void DeterministicSocketData::StopAfter(int seq) {
SetStop(sequence_number_ + seq);
}
MockRead DeterministicSocketData::GetNextRead() {
current_read_ = StaticSocketDataProvider::PeekRead();
EXPECT_LE(sequence_number_, current_read_.sequence_number);
// Synchronous read while stopped is an error
if (stopped() && !current_read_.async) {
LOG(ERROR) << "Unable to perform synchronous IO while stopped";
return MockRead(false, ERR_UNEXPECTED);
}
// Async read which will be called back in a future step.
if (sequence_number_ < current_read_.sequence_number) {
NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
<< ": I/O Pending";
MockRead result = MockRead(false, ERR_IO_PENDING);
if (!current_read_.async) {
LOG(ERROR) << "Unable to perform synchronous read: "
<< current_read_.sequence_number
<< " at stage: " << sequence_number_;
result = MockRead(false, ERR_UNEXPECTED);
}
if (print_debug_)
DumpMockRead(result);
return result;
}
NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
<< ": Read " << read_index();
if (print_debug_)
DumpMockRead(current_read_);
// Increment the sequence number if IO is complete
if (!current_read_.async)
NextStep();
DCHECK_NE(ERR_IO_PENDING, current_read_.result);
StaticSocketDataProvider::GetNextRead();
return current_read_;
}
MockWriteResult DeterministicSocketData::OnWrite(const std::string& data) {
const MockWrite& next_write = StaticSocketDataProvider::PeekWrite();
current_write_ = next_write;
// Synchronous write while stopped is an error
if (stopped() && !next_write.async) {
LOG(ERROR) << "Unable to perform synchronous IO while stopped";
return MockWriteResult(false, ERR_UNEXPECTED);
}
// Async write which will be called back in a future step.
if (sequence_number_ < next_write.sequence_number) {
NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
<< ": I/O Pending";
if (!next_write.async) {
LOG(ERROR) << "Unable to perform synchronous write: "
<< next_write.sequence_number << " at stage: " << sequence_number_;
return MockWriteResult(false, ERR_UNEXPECTED);
}
} else {
NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
<< ": Write " << write_index();
}
if (print_debug_)
DumpMockRead(next_write);
// Move to the next step if I/O is synchronous, since the operation will
// complete when this method returns.
if (!next_write.async)
NextStep();
// This is either a sync write for this step, or an async write.
return StaticSocketDataProvider::OnWrite(data);
}
void DeterministicSocketData::Reset() {
NET_TRACE(INFO, " *** ") << "Stage "
<< sequence_number_ << ": Reset()";
sequence_number_ = 0;
StaticSocketDataProvider::Reset();
NOTREACHED();
}
void DeterministicSocketData::InvokeCallbacks() {
if (socket_ && socket_->write_pending() &&
(current_write().sequence_number == sequence_number())) {
socket_->CompleteWrite();
NextStep();
return;
}
if (socket_ && socket_->read_pending() &&
(current_read().sequence_number == sequence_number())) {
socket_->CompleteRead();
NextStep();
return;
}
}
void DeterministicSocketData::NextStep() {
// Invariant: Can never move *past* the stopping step.
DCHECK_LT(sequence_number_, stopping_sequence_number_);
sequence_number_++;
if (sequence_number_ == stopping_sequence_number_)
SetStopped(true);
}
MockClientSocketFactory::MockClientSocketFactory() {}
MockClientSocketFactory::~MockClientSocketFactory() {}
void MockClientSocketFactory::AddSocketDataProvider(
SocketDataProvider* data) {
mock_data_.Add(data);
}
void MockClientSocketFactory::AddSSLSocketDataProvider(
SSLSocketDataProvider* data) {
mock_ssl_data_.Add(data);
}
void MockClientSocketFactory::ResetNextMockIndexes() {
mock_data_.ResetNextIndex();
mock_ssl_data_.ResetNextIndex();
}
MockTCPClientSocket* MockClientSocketFactory::GetMockTCPClientSocket(
size_t index) const {
DCHECK_LT(index, tcp_client_sockets_.size());
return tcp_client_sockets_[index];
}
MockSSLClientSocket* MockClientSocketFactory::GetMockSSLClientSocket(
size_t index) const {
DCHECK_LT(index, ssl_client_sockets_.size());
return ssl_client_sockets_[index];
}
ClientSocket* MockClientSocketFactory::CreateTransportClientSocket(
const AddressList& addresses,
net::NetLog* net_log,
const NetLog::Source& source) {
SocketDataProvider* data_provider = mock_data_.GetNext();
MockTCPClientSocket* socket =
new MockTCPClientSocket(addresses, net_log, data_provider);
data_provider->set_socket(socket);
tcp_client_sockets_.push_back(socket);
return socket;
}
SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket(
ClientSocketHandle* transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
SSLHostInfo* ssl_host_info,
CertVerifier* cert_verifier,
DnsCertProvenanceChecker* dns_cert_checker) {
MockSSLClientSocket* socket =
new MockSSLClientSocket(transport_socket, host_and_port, ssl_config,
ssl_host_info, mock_ssl_data_.GetNext());
ssl_client_sockets_.push_back(socket);
return socket;
}
void MockClientSocketFactory::ClearSSLSessionCache() {
}
MockClientSocket::MockClientSocket(net::NetLog* net_log)
: ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)),
connected_(false),
net_log_(NetLog::Source(), net_log) {
}
bool MockClientSocket::SetReceiveBufferSize(int32 size) {
return true;
}
bool MockClientSocket::SetSendBufferSize(int32 size) {
return true;
}
void MockClientSocket::Disconnect() {
connected_ = false;
}
bool MockClientSocket::IsConnected() const {
return connected_;
}
bool MockClientSocket::IsConnectedAndIdle() const {
return connected_;
}
int MockClientSocket::GetPeerAddress(AddressList* address) const {
return net::SystemHostResolverProc("192.0.2.33", ADDRESS_FAMILY_UNSPECIFIED,
0, address, NULL);
}
int MockClientSocket::GetLocalAddress(IPEndPoint* address) const {
IPAddressNumber ip;
if (!ParseIPLiteralToNumber("192.0.2.33", &ip))
return ERR_FAILED;
*address = IPEndPoint(ip, 123);
return OK;
}
const BoundNetLog& MockClientSocket::NetLog() const {
return net_log_;
}
void MockClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
NOTREACHED();
}
void MockClientSocket::GetSSLCertRequestInfo(
net::SSLCertRequestInfo* cert_request_info) {
}
SSLClientSocket::NextProtoStatus
MockClientSocket::GetNextProto(std::string* proto) {
proto->clear();
return SSLClientSocket::kNextProtoUnsupported;
}
MockClientSocket::~MockClientSocket() {}
void MockClientSocket::RunCallbackAsync(net::CompletionCallback* callback,
int result) {
MessageLoop::current()->PostTask(FROM_HERE,
method_factory_.NewRunnableMethod(
&MockClientSocket::RunCallback, callback, result));
}
void MockClientSocket::RunCallback(net::CompletionCallback* callback,
int result) {
if (callback)
callback->Run(result);
}
MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses,
net::NetLog* net_log,
net::SocketDataProvider* data)
: MockClientSocket(net_log),
addresses_(addresses),
data_(data),
read_offset_(0),
read_data_(false, net::ERR_UNEXPECTED),
need_read_data_(true),
peer_closed_connection_(false),
pending_buf_(NULL),
pending_buf_len_(0),
pending_callback_(NULL),
was_used_to_convey_data_(false) {
DCHECK(data_);
data_->Reset();
}
int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len,
net::CompletionCallback* callback) {
if (!connected_)
return net::ERR_UNEXPECTED;
// If the buffer is already in use, a read is already in progress!
DCHECK(pending_buf_ == NULL);
// Store our async IO data.
pending_buf_ = buf;
pending_buf_len_ = buf_len;
pending_callback_ = callback;
if (need_read_data_) {
read_data_ = data_->GetNextRead();
if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
// This MockRead is just a marker to instruct us to set
// peer_closed_connection_. Skip it and get the next one.
read_data_ = data_->GetNextRead();
peer_closed_connection_ = true;
}
// ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
// to complete the async IO manually later (via OnReadComplete).
if (read_data_.result == ERR_IO_PENDING) {
DCHECK(callback); // We need to be using async IO in this case.
return ERR_IO_PENDING;
}
need_read_data_ = false;
}
return CompleteRead();
}
int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len,
net::CompletionCallback* callback) {
DCHECK(buf);
DCHECK_GT(buf_len, 0);
if (!connected_)
return net::ERR_UNEXPECTED;
std::string data(buf->data(), buf_len);
net::MockWriteResult write_result = data_->OnWrite(data);
was_used_to_convey_data_ = true;
if (write_result.async) {
RunCallbackAsync(callback, write_result.result);
return net::ERR_IO_PENDING;
}
return write_result.result;
}
int MockTCPClientSocket::Connect(net::CompletionCallback* callback) {
if (connected_)
return net::OK;
connected_ = true;
peer_closed_connection_ = false;
if (data_->connect_data().async) {
RunCallbackAsync(callback, data_->connect_data().result);
return net::ERR_IO_PENDING;
}
return data_->connect_data().result;
}
void MockTCPClientSocket::Disconnect() {
MockClientSocket::Disconnect();
pending_callback_ = NULL;
}
bool MockTCPClientSocket::IsConnected() const {
return connected_ && !peer_closed_connection_;
}
bool MockTCPClientSocket::IsConnectedAndIdle() const {
return IsConnected();
}
int MockTCPClientSocket::GetPeerAddress(AddressList* address) const {
if (!IsConnected())
return ERR_SOCKET_NOT_CONNECTED;
return MockClientSocket::GetPeerAddress(address);
}
bool MockTCPClientSocket::WasEverUsed() const {
return was_used_to_convey_data_;
}
bool MockTCPClientSocket::UsingTCPFastOpen() const {
return false;
}
void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
// There must be a read pending.
DCHECK(pending_buf_);
// You can't complete a read with another ERR_IO_PENDING status code.
DCHECK_NE(ERR_IO_PENDING, data.result);
// Since we've been waiting for data, need_read_data_ should be true.
DCHECK(need_read_data_);
read_data_ = data;
need_read_data_ = false;
// The caller is simulating that this IO completes right now. Don't
// let CompleteRead() schedule a callback.
read_data_.async = false;
net::CompletionCallback* callback = pending_callback_;
int rv = CompleteRead();
RunCallback(callback, rv);
}
int MockTCPClientSocket::CompleteRead() {
DCHECK(pending_buf_);
DCHECK(pending_buf_len_ > 0);
was_used_to_convey_data_ = true;
// Save the pending async IO data and reset our |pending_| state.
net::IOBuffer* buf = pending_buf_;
int buf_len = pending_buf_len_;
net::CompletionCallback* callback = pending_callback_;
pending_buf_ = NULL;
pending_buf_len_ = 0;
pending_callback_ = NULL;
int result = read_data_.result;
DCHECK(result != ERR_IO_PENDING);
if (read_data_.data) {
if (read_data_.data_len - read_offset_ > 0) {
result = std::min(buf_len, read_data_.data_len - read_offset_);
memcpy(buf->data(), read_data_.data + read_offset_, result);
read_offset_ += result;
if (read_offset_ == read_data_.data_len) {
need_read_data_ = true;
read_offset_ = 0;
}
} else {
result = 0; // EOF
}
}
if (read_data_.async) {
DCHECK(callback);
RunCallbackAsync(callback, result);
return net::ERR_IO_PENDING;
}
return result;
}
DeterministicMockTCPClientSocket::DeterministicMockTCPClientSocket(
net::NetLog* net_log, net::DeterministicSocketData* data)
: MockClientSocket(net_log),
write_pending_(false),
write_callback_(NULL),
write_result_(0),
read_data_(),
read_buf_(NULL),
read_buf_len_(0),
read_pending_(false),
read_callback_(NULL),
data_(data),
was_used_to_convey_data_(false) {}
DeterministicMockTCPClientSocket::~DeterministicMockTCPClientSocket() {}
void DeterministicMockTCPClientSocket::CompleteWrite() {
was_used_to_convey_data_ = true;
write_pending_ = false;
write_callback_->Run(write_result_);
}
int DeterministicMockTCPClientSocket::CompleteRead() {
DCHECK_GT(read_buf_len_, 0);
DCHECK_LE(read_data_.data_len, read_buf_len_);
DCHECK(read_buf_);
was_used_to_convey_data_ = true;
if (read_data_.result == ERR_IO_PENDING)
read_data_ = data_->GetNextRead();
DCHECK_NE(ERR_IO_PENDING, read_data_.result);
// If read_data_.async is true, we do not need to wait, since this is already
// the callback. Therefore we don't even bother to check it.
int result = read_data_.result;
if (read_data_.data_len > 0) {
DCHECK(read_data_.data);
result = std::min(read_buf_len_, read_data_.data_len);
memcpy(read_buf_->data(), read_data_.data, result);
}
if (read_pending_) {
read_pending_ = false;
read_callback_->Run(result);
}
return result;
}
int DeterministicMockTCPClientSocket::Write(
net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) {
DCHECK(buf);
DCHECK_GT(buf_len, 0);
if (!connected_)
return net::ERR_UNEXPECTED;
std::string data(buf->data(), buf_len);
net::MockWriteResult write_result = data_->OnWrite(data);
if (write_result.async) {
write_callback_ = callback;
write_result_ = write_result.result;
DCHECK(write_callback_ != NULL);
write_pending_ = true;
return net::ERR_IO_PENDING;
}
was_used_to_convey_data_ = true;
write_pending_ = false;
return write_result.result;
}
int DeterministicMockTCPClientSocket::Read(
net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) {
if (!connected_)
return net::ERR_UNEXPECTED;
read_data_ = data_->GetNextRead();
// The buffer should always be big enough to contain all the MockRead data. To
// use small buffers, split the data into multiple MockReads.
DCHECK_LE(read_data_.data_len, buf_len);
read_buf_ = buf;
read_buf_len_ = buf_len;
read_callback_ = callback;
if (read_data_.async || (read_data_.result == ERR_IO_PENDING)) {
read_pending_ = true;
DCHECK(read_callback_);
return ERR_IO_PENDING;
}
was_used_to_convey_data_ = true;
return CompleteRead();
}
// TODO(erikchen): Support connect sequencing.
int DeterministicMockTCPClientSocket::Connect(
net::CompletionCallback* callback) {
if (connected_)
return net::OK;
connected_ = true;
if (data_->connect_data().async) {
RunCallbackAsync(callback, data_->connect_data().result);
return net::ERR_IO_PENDING;
}
return data_->connect_data().result;
}
void DeterministicMockTCPClientSocket::Disconnect() {
MockClientSocket::Disconnect();
}
bool DeterministicMockTCPClientSocket::IsConnected() const {
return connected_;
}
bool DeterministicMockTCPClientSocket::IsConnectedAndIdle() const {
return IsConnected();
}
bool DeterministicMockTCPClientSocket::WasEverUsed() const {
return was_used_to_convey_data_;
}
bool DeterministicMockTCPClientSocket::UsingTCPFastOpen() const {
return false;
}
void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {}
class MockSSLClientSocket::ConnectCallback
: public net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback> {
public:
ConnectCallback(MockSSLClientSocket *ssl_client_socket,
net::CompletionCallback* user_callback,
int rv)
: ALLOW_THIS_IN_INITIALIZER_LIST(
net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback>(
this, &ConnectCallback::Wrapper)),
ssl_client_socket_(ssl_client_socket),
user_callback_(user_callback),
rv_(rv) {
}
private:
void Wrapper(int rv) {
if (rv_ == net::OK)
ssl_client_socket_->connected_ = true;
user_callback_->Run(rv_);
delete this;
}
MockSSLClientSocket* ssl_client_socket_;
net::CompletionCallback* user_callback_;
int rv_;
};
MockSSLClientSocket::MockSSLClientSocket(
net::ClientSocketHandle* transport_socket,
const HostPortPair& host_port_pair,
const net::SSLConfig& ssl_config,
SSLHostInfo* ssl_host_info,
net::SSLSocketDataProvider* data)
: MockClientSocket(transport_socket->socket()->NetLog().net_log()),
transport_(transport_socket),
data_(data),
is_npn_state_set_(false),
new_npn_value_(false) {
DCHECK(data_);
delete ssl_host_info; // we take ownership but don't use it.
}
MockSSLClientSocket::~MockSSLClientSocket() {
Disconnect();
}
int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len,
net::CompletionCallback* callback) {
return transport_->socket()->Read(buf, buf_len, callback);
}
int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len,
net::CompletionCallback* callback) {
return transport_->socket()->Write(buf, buf_len, callback);
}
int MockSSLClientSocket::Connect(net::CompletionCallback* callback) {
ConnectCallback* connect_callback = new ConnectCallback(
this, callback, data_->connect.result);
int rv = transport_->socket()->Connect(connect_callback);
if (rv == net::OK) {
delete connect_callback;
if (data_->connect.result == net::OK)
connected_ = true;
if (data_->connect.async) {
RunCallbackAsync(callback, data_->connect.result);
return net::ERR_IO_PENDING;
}
return data_->connect.result;
}
return rv;
}
void MockSSLClientSocket::Disconnect() {
MockClientSocket::Disconnect();
if (transport_->socket() != NULL)
transport_->socket()->Disconnect();
}
bool MockSSLClientSocket::IsConnected() const {
return transport_->socket()->IsConnected();
}
bool MockSSLClientSocket::WasEverUsed() const {
return transport_->socket()->WasEverUsed();
}
bool MockSSLClientSocket::UsingTCPFastOpen() const {
return transport_->socket()->UsingTCPFastOpen();
}
void MockSSLClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
ssl_info->Reset();
ssl_info->cert = data_->cert_;
}
void MockSSLClientSocket::GetSSLCertRequestInfo(
net::SSLCertRequestInfo* cert_request_info) {
DCHECK(cert_request_info);
if (data_->cert_request_info) {
cert_request_info->host_and_port =
data_->cert_request_info->host_and_port;
cert_request_info->client_certs = data_->cert_request_info->client_certs;
} else {
cert_request_info->Reset();
}
}
SSLClientSocket::NextProtoStatus MockSSLClientSocket::GetNextProto(
std::string* proto) {
*proto = data_->next_proto;
return data_->next_proto_status;
}
bool MockSSLClientSocket::was_npn_negotiated() const {
if (is_npn_state_set_)
return new_npn_value_;
return data_->was_npn_negotiated;
}
bool MockSSLClientSocket::set_was_npn_negotiated(bool negotiated) {
is_npn_state_set_ = true;
return new_npn_value_ = negotiated;
}
void MockSSLClientSocket::OnReadComplete(const MockRead& data) {
NOTIMPLEMENTED();
}
TestSocketRequest::TestSocketRequest(
std::vector<TestSocketRequest*>* request_order,
size_t* completion_count)
: request_order_(request_order),
completion_count_(completion_count) {
DCHECK(request_order);
DCHECK(completion_count);
}
TestSocketRequest::~TestSocketRequest() {
}
int TestSocketRequest::WaitForResult() {
return callback_.WaitForResult();
}
void TestSocketRequest::RunWithParams(const Tuple1<int>& params) {
callback_.RunWithParams(params);
(*completion_count_)++;
request_order_->push_back(this);
}
// static
const int ClientSocketPoolTest::kIndexOutOfBounds = -1;
// static
const int ClientSocketPoolTest::kRequestNotFound = -2;
ClientSocketPoolTest::ClientSocketPoolTest() : completion_count_(0) {}
ClientSocketPoolTest::~ClientSocketPoolTest() {}
int ClientSocketPoolTest::GetOrderOfRequest(size_t index) const {
index--;
if (index >= requests_.size())
return kIndexOutOfBounds;
for (size_t i = 0; i < request_order_.size(); i++)
if (requests_[index] == request_order_[i])
return i + 1;
return kRequestNotFound;
}
bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) {
ScopedVector<TestSocketRequest>::iterator i;
for (i = requests_.begin(); i != requests_.end(); ++i) {
if ((*i)->handle()->is_initialized()) {
if (keep_alive == NO_KEEP_ALIVE)
(*i)->handle()->socket()->Disconnect();
(*i)->handle()->Reset();
MessageLoop::current()->RunAllPending();
return true;
}
}
return false;
}
void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
bool released_one;
do {
released_one = ReleaseOneConnection(keep_alive);
} while (released_one);
}
MockTransportClientSocketPool::MockConnectJob::MockConnectJob(
ClientSocket* socket,
ClientSocketHandle* handle,
CompletionCallback* callback)
: socket_(socket),
handle_(handle),
user_callback_(callback),
ALLOW_THIS_IN_INITIALIZER_LIST(
connect_callback_(this, &MockConnectJob::OnConnect)) {
}
MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() {}
int MockTransportClientSocketPool::MockConnectJob::Connect() {
int rv = socket_->Connect(&connect_callback_);
if (rv == OK) {
user_callback_ = NULL;
OnConnect(OK);
}
return rv;
}
bool MockTransportClientSocketPool::MockConnectJob::CancelHandle(
const ClientSocketHandle* handle) {
if (handle != handle_)
return false;
socket_.reset();
handle_ = NULL;
user_callback_ = NULL;
return true;
}
void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) {
if (!socket_.get())
return;
if (rv == OK) {
handle_->set_socket(socket_.release());
} else {
socket_.reset();
}
handle_ = NULL;
if (user_callback_) {
CompletionCallback* callback = user_callback_;
user_callback_ = NULL;
callback->Run(rv);
}
}
MockTransportClientSocketPool::MockTransportClientSocketPool(
int max_sockets,
int max_sockets_per_group,
ClientSocketPoolHistograms* histograms,
ClientSocketFactory* socket_factory)
: TransportClientSocketPool(max_sockets, max_sockets_per_group, histograms,
NULL, NULL, NULL),
client_socket_factory_(socket_factory),
release_count_(0),
cancel_count_(0) {
}
MockTransportClientSocketPool::~MockTransportClientSocketPool() {}
int MockTransportClientSocketPool::RequestSocket(const std::string& group_name,
const void* socket_params,
RequestPriority priority,
ClientSocketHandle* handle,
CompletionCallback* callback,
const BoundNetLog& net_log) {
ClientSocket* socket = client_socket_factory_->CreateTransportClientSocket(
AddressList(), net_log.net_log(), net::NetLog::Source());
MockConnectJob* job = new MockConnectJob(socket, handle, callback);
job_list_.push_back(job);
handle->set_pool_id(1);
return job->Connect();
}
void MockTransportClientSocketPool::CancelRequest(const std::string& group_name,
ClientSocketHandle* handle) {
std::vector<MockConnectJob*>::iterator i;
for (i = job_list_.begin(); i != job_list_.end(); ++i) {
if ((*i)->CancelHandle(handle)) {
cancel_count_++;
break;
}
}
}
void MockTransportClientSocketPool::ReleaseSocket(const std::string& group_name,
ClientSocket* socket, int id) {
EXPECT_EQ(1, id);
release_count_++;
delete socket;
}
DeterministicMockClientSocketFactory::DeterministicMockClientSocketFactory() {}
DeterministicMockClientSocketFactory::~DeterministicMockClientSocketFactory() {}
void DeterministicMockClientSocketFactory::AddSocketDataProvider(
DeterministicSocketData* data) {
mock_data_.Add(data);
}
void DeterministicMockClientSocketFactory::AddSSLSocketDataProvider(
SSLSocketDataProvider* data) {
mock_ssl_data_.Add(data);
}
void DeterministicMockClientSocketFactory::ResetNextMockIndexes() {
mock_data_.ResetNextIndex();
mock_ssl_data_.ResetNextIndex();
}
MockSSLClientSocket* DeterministicMockClientSocketFactory::
GetMockSSLClientSocket(size_t index) const {
DCHECK_LT(index, ssl_client_sockets_.size());
return ssl_client_sockets_[index];
}
ClientSocket* DeterministicMockClientSocketFactory::CreateTransportClientSocket(
const AddressList& addresses,
net::NetLog* net_log,
const net::NetLog::Source& source) {
DeterministicSocketData* data_provider = mock_data().GetNext();
DeterministicMockTCPClientSocket* socket =
new DeterministicMockTCPClientSocket(net_log, data_provider);
data_provider->set_socket(socket->AsWeakPtr());
tcp_client_sockets().push_back(socket);
return socket;
}
SSLClientSocket* DeterministicMockClientSocketFactory::CreateSSLClientSocket(
ClientSocketHandle* transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
SSLHostInfo* ssl_host_info,
CertVerifier* cert_verifier,
DnsCertProvenanceChecker* dns_cert_checker) {
MockSSLClientSocket* socket =
new MockSSLClientSocket(transport_socket, host_and_port, ssl_config,
ssl_host_info, mock_ssl_data_.GetNext());
ssl_client_sockets_.push_back(socket);
return socket;
}
void DeterministicMockClientSocketFactory::ClearSSLSessionCache() {
}
MockSOCKSClientSocketPool::MockSOCKSClientSocketPool(
int max_sockets,
int max_sockets_per_group,
ClientSocketPoolHistograms* histograms,
TransportClientSocketPool* transport_pool)
: SOCKSClientSocketPool(max_sockets, max_sockets_per_group, histograms,
NULL, transport_pool, NULL),
transport_pool_(transport_pool) {
}
MockSOCKSClientSocketPool::~MockSOCKSClientSocketPool() {}
int MockSOCKSClientSocketPool::RequestSocket(const std::string& group_name,
const void* socket_params,
RequestPriority priority,
ClientSocketHandle* handle,
CompletionCallback* callback,
const BoundNetLog& net_log) {
return transport_pool_->RequestSocket(group_name,
socket_params,
priority,
handle,
callback,
net_log);
}
void MockSOCKSClientSocketPool::CancelRequest(
const std::string& group_name,
ClientSocketHandle* handle) {
return transport_pool_->CancelRequest(group_name, handle);
}
void MockSOCKSClientSocketPool::ReleaseSocket(const std::string& group_name,
ClientSocket* socket, int id) {
return transport_pool_->ReleaseSocket(group_name, socket, id);
}
const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 };
const int kSOCKS5GreetRequestLength = arraysize(kSOCKS5GreetRequest);
const char kSOCKS5GreetResponse[] = { 0x05, 0x00 };
const int kSOCKS5GreetResponseLength = arraysize(kSOCKS5GreetResponse);
const char kSOCKS5OkRequest[] =
{ 0x05, 0x01, 0x00, 0x03, 0x04, 'h', 'o', 's', 't', 0x00, 0x50 };
const int kSOCKS5OkRequestLength = arraysize(kSOCKS5OkRequest);
const char kSOCKS5OkResponse[] =
{ 0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0x00, 0x50 };
const int kSOCKS5OkResponseLength = arraysize(kSOCKS5OkResponse);
} // namespace net