// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "jingle/glue/chrome_async_socket.h"
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include "base/basictypes.h"
#include "base/bind.h"
#include "base/compiler_specific.h"
#include "base/logging.h"
#include "base/message_loop/message_loop.h"
#include "jingle/glue/resolving_client_socket_factory.h"
#include "net/base/address_list.h"
#include "net/base/host_port_pair.h"
#include "net/base/io_buffer.h"
#include "net/base/net_util.h"
#include "net/socket/client_socket_handle.h"
#include "net/socket/ssl_client_socket.h"
#include "net/socket/tcp_client_socket.h"
#include "net/ssl/ssl_config_service.h"
#include "third_party/libjingle/source/talk/base/socketaddress.h"
namespace jingle_glue {
// Creates a socket in STATE_CLOSED with no recorded error and empty,
// idle read/write pipelines.
//
// |resolving_client_socket_factory| must be non-null. It is kept in
// resolving_client_socket_factory_ (a smart-pointer member, judging by the
// .get() call below; NOTE(review): confirm ownership semantics in the
// header). It is used to create the transport socket in Connect() and the
// SSL socket in StartTls().
// |read_buf_size| and |write_buf_size| must both be > 0. They fix the
// capacity of the internal buffers; in particular, Write() refuses data
// that would not fit in the remaining write buffer space.
ChromeAsyncSocket::ChromeAsyncSocket(
    ResolvingClientSocketFactory* resolving_client_socket_factory,
    size_t read_buf_size,
    size_t write_buf_size)
    : resolving_client_socket_factory_(resolving_client_socket_factory),
      state_(STATE_CLOSED),
      error_(ERROR_NONE),
      net_error_(net::OK),
      read_state_(IDLE),
      read_buf_(new net::IOBufferWithSize(read_buf_size)),
      read_start_(0U),
      read_end_(0U),
      write_state_(IDLE),
      write_buf_(new net::IOBufferWithSize(write_buf_size)),
      write_end_(0U),
      weak_ptr_factory_(this) {
  DCHECK(resolving_client_socket_factory_.get());
  DCHECK_GT(read_buf_size, 0U);
  DCHECK_GT(write_buf_size, 0U);
}
// No explicit teardown: any cleanup is left to the members' own
// destructors.
ChromeAsyncSocket::~ChromeAsyncSocket() {
}
// Reports the socket's current connection state.
ChromeAsyncSocket::State ChromeAsyncSocket::state() {
  const State current = state_;
  return current;
}
// Reports the most recently recorded error category. ERROR_WINSOCK means
// the underlying net error code is available from GetError().
ChromeAsyncSocket::Error ChromeAsyncSocket::error() {
  const Error current = error_;
  return current;
}
// Reports the net error code recorded with the current error. It is
// net::OK unless error() is ERROR_WINSOCK.
int ChromeAsyncSocket::GetError() {
  const int code = net_error_;
  return code;
}
// True once a connection is usable for I/O, either plain (STATE_OPEN) or
// after a completed TLS handshake (STATE_TLS_OPEN).
bool ChromeAsyncSocket::IsOpen() const {
  switch (state_) {
    case STATE_OPEN:
    case STATE_TLS_OPEN:
      return true;
    default:
      return false;
  }
}
// Records an error that did not originate from the network stack.
// |error| must be neither ERROR_NONE nor ERROR_WINSOCK; the latter is set
// only by DoNetError(). The stored net error code is cleared to net::OK.
void ChromeAsyncSocket::DoNonNetError(Error error) {
  DCHECK_NE(error, ERROR_NONE);
  DCHECK_NE(error, ERROR_WINSOCK);
  net_error_ = net::OK;
  error_ = error;
}
// Records a network-stack failure: afterwards error() is ERROR_WINSOCK and
// GetError() returns |net_error|.
void ChromeAsyncSocket::DoNetError(net::Error net_error) {
  net_error_ = net_error;
  error_ = ERROR_WINSOCK;
}
void ChromeAsyncSocket::DoNetErrorFromStatus(int status) {
DCHECK_LT(status, net::OK);
DoNetError(static_cast<net::Error>(status));
}
// Starts connecting to |address|.
//
// Returns false and records ERROR_WRONGSTATE if the socket is not closed.
// Returns false and records ERROR_DNS if |address| has an empty hostname
// or a zero port. Otherwise returns true. The outcome is always delivered
// later through ProcessConnectDone(): either as the transport socket's
// completion callback, or, if the transport finished synchronously, via a
// task posted to the current message loop. It never runs inside this call.
//
// STATE_CLOSED -> STATE_CONNECTING
bool ChromeAsyncSocket::Connect(const talk_base::SocketAddress& address) {
  if (state_ != STATE_CLOSED) {
    LOG(DFATAL) << "Connect() called on non-closed socket";
    DoNonNetError(ERROR_WRONGSTATE);
    return false;
  }
  // Only the hostname/port pair is forwarded below, so both must be set.
  if (address.hostname().empty() || address.port() == 0) {
    DoNonNetError(ERROR_DNS);
    return false;
  }
  DCHECK_EQ(state_, buzz::AsyncSocket::STATE_CLOSED);
  DCHECK_EQ(read_state_, IDLE);
  DCHECK_EQ(write_state_, IDLE);
  state_ = STATE_CONNECTING;
  // No callbacks from an earlier connection should still be alive
  // (DoClose() invalidates them); invalidate anyway as a safeguard.
  DCHECK(!weak_ptr_factory_.HasWeakPtrs());
  weak_ptr_factory_.InvalidateWeakPtrs();
  // NOTE(review): the factory presumably performs the hostname resolution
  // (per its name); that logic is not visible here.
  net::HostPortPair dest_host_port_pair(address.hostname(), address.port());
  transport_socket_ =
      resolving_client_socket_factory_->CreateTransportClientSocket(
          dest_host_port_pair);
  int status = transport_socket_->Connect(
      base::Bind(&ChromeAsyncSocket::ProcessConnectDone,
                 weak_ptr_factory_.GetWeakPtr()));
  if (status != net::ERR_IO_PENDING) {
    // We defer execution of ProcessConnectDone instead of calling it
    // directly here as the caller may not expect an error/close to
    // happen here. This is okay, as from the caller's point of view,
    // the connect always happens asynchronously.
    base::MessageLoop* message_loop = base::MessageLoop::current();
    CHECK(message_loop);
    message_loop->PostTask(
        FROM_HERE,
        base::Bind(&ChromeAsyncSocket::ProcessConnectDone,
                   weak_ptr_factory_.GetWeakPtr(), status));
  }
  return true;
}
// Completes the transport connect started by Connect().
//
// On failure, records |status| as a net error and closes the socket.
// DoClose() raises SignalClosed(), because the state is not yet
// STATE_CLOSED. On success, marks the socket open, schedules the first
// read, then raises SignalConnected().
//
// STATE_CONNECTING -> STATE_OPEN
// read_state_ == IDLE -> read_state_ == POSTED (via PostDoRead())
void ChromeAsyncSocket::ProcessConnectDone(int status) {
  DCHECK_NE(status, net::ERR_IO_PENDING);
  DCHECK_EQ(read_state_, IDLE);
  DCHECK_EQ(write_state_, IDLE);
  DCHECK_EQ(state_, STATE_CONNECTING);
  if (status != net::OK) {
    DoNetErrorFromStatus(status);
    DoClose();
    return;
  }
  state_ = STATE_OPEN;
  PostDoRead();
  // Write buffer should be empty.
  DCHECK_EQ(write_end_, 0U);
  // Signal last, after all state updates, since connected slots may call
  // back into this socket.
  SignalConnected();
}
// read_state_ == IDLE -> read_state_ == POSTED
void ChromeAsyncSocket::PostDoRead() {
DCHECK(IsOpen());
DCHECK_EQ(read_state_, IDLE);
DCHECK_EQ(read_start_, 0U);
DCHECK_EQ(read_end_, 0U);
base::MessageLoop* message_loop = base::MessageLoop::current();
CHECK(message_loop);
message_loop->PostTask(
FROM_HERE,
base::Bind(&ChromeAsyncSocket::DoRead,
weak_ptr_factory_.GetWeakPtr()));
read_state_ = POSTED;
}
// Issues a transport read into the (empty) read buffer. A synchronous
// result is handled immediately; otherwise the transport later invokes
// ProcessReadDone() through the bound callback.
//
// read_state_ == POSTED -> read_state_ == PENDING
void ChromeAsyncSocket::DoRead() {
  DCHECK(IsOpen());
  DCHECK_EQ(read_state_, POSTED);
  DCHECK_EQ(read_start_, 0U);
  DCHECK_EQ(read_end_, 0U);
  // Once we call Read(), we cannot call StartTls() until the read
  // finishes. This is okay, as StartTls() is called only from a read
  // handler (i.e., after a read finishes and before another read is
  // done).
  const int read_result = transport_socket_->Read(
      read_buf_.get(), read_buf_->size(),
      base::Bind(&ChromeAsyncSocket::ProcessReadDone,
                 weak_ptr_factory_.GetWeakPtr()));
  read_state_ = PENDING;
  if (read_result == net::ERR_IO_PENDING)
    return;  // Completion arrives via the callback bound above.
  ProcessReadDone(read_result);
}
// Handles the result of a transport read.
//   status > 0:  |status| bytes are now buffered; raises SignalRead().
//   status == 0: the peer closed the connection; closes with no error.
//   status < 0:  records the net error and closes.
//
// read_state_ == PENDING -> read_state_ == IDLE
void ChromeAsyncSocket::ProcessReadDone(int status) {
  DCHECK_NE(status, net::ERR_IO_PENDING);
  DCHECK(IsOpen());
  DCHECK_EQ(read_state_, PENDING);
  DCHECK_EQ(read_start_, 0U);
  DCHECK_EQ(read_end_, 0U);
  read_state_ = IDLE;

  if (status < 0) {
    DoNetErrorFromStatus(status);
    DoClose();
    return;
  }

  if (status == 0) {
    // Other side closed the connection.
    error_ = ERROR_NONE;
    net_error_ = net::OK;
    DoClose();
    return;
  }

  read_end_ = static_cast<size_t>(status);
  SignalRead();
}
// Copies up to |len| bytes of already-received data into |data| and
// stores the number of bytes copied in |*len_read|.
//
// Returns false (recording ERROR_WRONGSTATE) if the socket is neither
// open nor TLS-connecting. Otherwise returns true. |*len_read| is 0 when
// nothing is buffered yet, while TLS-connecting, or when |len| is 0.
// SignalRead() is raised once new data has been received. When the
// buffered data is fully consumed, the next transport read is scheduled.
//
// (maybe) read_state_ == IDLE -> read_state_ == POSTED (via
// PostDoRead())
bool ChromeAsyncSocket::Read(char* data, size_t len, size_t* len_read) {
  if (!IsOpen() && (state_ != STATE_TLS_CONNECTING)) {
    // Read() may be called on a closed socket if a previous read
    // causes a socket close (e.g., client sends wrong password and
    // server terminates connection).
    //
    // TODO(akalin): Fix handling of this on the libjingle side.
    if (state_ != STATE_CLOSED) {
      LOG(DFATAL) << "Read() called on non-open non-tls-connecting socket";
    }
    DoNonNetError(ERROR_WRONGSTATE);
    return false;
  }
  DCHECK_LE(read_start_, read_end_);
  if ((state_ == STATE_TLS_CONNECTING) || read_end_ == 0U) {
    if (state_ == STATE_TLS_CONNECTING) {
      DCHECK_EQ(read_state_, IDLE);
      DCHECK_EQ(read_end_, 0U);
    } else {
      DCHECK_NE(read_state_, IDLE);
    }
    *len_read = 0;
    return true;
  }
  DCHECK_EQ(read_state_, IDLE);
  // A zero-length request consumes nothing. Return before memcpy():
  // |data| may be null when |len| is 0 (passing null to memcpy() is
  // undefined even for a zero count), and the DCHECK below assumes at
  // least one byte is copied.
  if (len == 0U) {
    *len_read = 0;
    return true;
  }
  const size_t available = read_end_ - read_start_;
  *len_read = std::min(len, available);
  DCHECK_GT(*len_read, 0U);
  std::memcpy(data, read_buf_->data() + read_start_, *len_read);
  read_start_ += *len_read;
  if (read_start_ == read_end_) {
    // Everything buffered has been handed out; rewind and fetch more.
    read_start_ = 0U;
    read_end_ = 0U;
    // We defer execution of DoRead() here for similar reasons as
    // ProcessConnectDone().
    PostDoRead();
  }
  return true;
}
// Queues |len| bytes from |data| for sending by copying them into the
// fixed-size write buffer.
//
// Returns false (recording ERROR_WRONGSTATE) if the socket is neither
// open nor TLS-connecting. Returns false with a
// net::ERR_INSUFFICIENT_RESOURCES net error if the bytes do not fit in
// the remaining buffer space. While TLS-connecting, data is only
// buffered; ProcessSSLConnectDone() starts flushing it once the handshake
// succeeds. Otherwise a write is scheduled if none is already in
// progress.
//
// (maybe) write_state_ == IDLE -> write_state_ == POSTED (via
// PostDoWrite())
bool ChromeAsyncSocket::Write(const char* data, size_t len) {
  if (!IsOpen() && (state_ != STATE_TLS_CONNECTING)) {
    LOG(DFATAL) << "Write() called on non-open non-tls-connecting socket";
    DoNonNetError(ERROR_WRONGSTATE);
    return false;
  }
  // TODO(akalin): Avoid this check by modifying the interface to have
  // a "ready for writing" signal.
  //
  // write_end_ never exceeds the buffer size (it only grows after this
  // check passes), so |free_space| cannot underflow.
  const size_t free_space =
      static_cast<size_t>(write_buf_->size()) - write_end_;
  if (free_space < len) {
    // Report the real overflow: the bytes that do not fit in the space
    // left after the data already queued. Since free_space < len, this
    // subtraction cannot wrap.
    LOG(DFATAL) << "queueing " << len << " bytes would exceed the "
                << "max write buffer size = " << write_buf_->size()
                << " by " << (len - free_space) << " bytes";
    DoNetError(net::ERR_INSUFFICIENT_RESOURCES);
    return false;
  }
  if (len > 0U) {
    // Skip memcpy() for empty writes, where |data| may be null.
    std::memcpy(write_buf_->data() + write_end_, data, len);
    write_end_ += len;
  }
  // If we're TLS-connecting, the write buffer will get flushed once
  // the TLS-connect finishes. Otherwise, start writing if we're not
  // already writing and we have something to write.
  if ((state_ != STATE_TLS_CONNECTING) &&
      (write_state_ == IDLE) && (write_end_ > 0U)) {
    // We defer execution of DoWrite() here for similar reasons as
    // ProcessConnectDone().
    PostDoWrite();
  }
  return true;
}
// write_state_ == IDLE -> write_state_ == POSTED
void ChromeAsyncSocket::PostDoWrite() {
DCHECK(IsOpen());
DCHECK_EQ(write_state_, IDLE);
DCHECK_GT(write_end_, 0U);
base::MessageLoop* message_loop = base::MessageLoop::current();
CHECK(message_loop);
message_loop->PostTask(
FROM_HERE,
base::Bind(&ChromeAsyncSocket::DoWrite,
weak_ptr_factory_.GetWeakPtr()));
write_state_ = POSTED;
}
// Hands everything currently queued in the write buffer to the transport.
// A synchronous result is handled immediately; otherwise the transport
// later invokes ProcessWriteDone() through the bound callback.
//
// write_state_ == POSTED -> write_state_ == PENDING
void ChromeAsyncSocket::DoWrite() {
  DCHECK(IsOpen());
  DCHECK_EQ(write_state_, POSTED);
  DCHECK_GT(write_end_, 0U);
  // Once we call Write(), we cannot call StartTls() until the write
  // finishes. This is okay, as StartTls() is called only after we
  // have received a reply to a message we sent to the server and
  // before we send the next message.
  const int write_result = transport_socket_->Write(
      write_buf_.get(), write_end_,
      base::Bind(&ChromeAsyncSocket::ProcessWriteDone,
                 weak_ptr_factory_.GetWeakPtr()));
  write_state_ = PENDING;
  if (write_result == net::ERR_IO_PENDING)
    return;  // Completion arrives via the callback bound above.
  ProcessWriteDone(write_result);
}
// Handles the result of a transport write.
//
// A negative |status| records the net error and closes the socket. A
// byte count larger than what was queued is treated as ERR_UNEXPECTED and
// also closes. Otherwise the written prefix is dropped from the write
// buffer and, if anything remains, another write is scheduled.
//
// write_state_ == PENDING -> write_state_ == IDLE or POSTED (the
// latter via PostDoWrite())
void ChromeAsyncSocket::ProcessWriteDone(int status) {
  DCHECK_NE(status, net::ERR_IO_PENDING);
  DCHECK(IsOpen());
  DCHECK_EQ(write_state_, PENDING);
  DCHECK_GT(write_end_, 0U);
  write_state_ = IDLE;

  if (status < net::OK) {
    DoNetErrorFromStatus(status);
    DoClose();
    return;
  }

  const size_t bytes_written = static_cast<size_t>(status);
  if (bytes_written > write_end_) {
    LOG(DFATAL) << "bytes written = " << bytes_written
                << " exceeds bytes requested = " << write_end_;
    DoNetError(net::ERR_UNEXPECTED);
    DoClose();
    return;
  }

  // TODO(akalin): Figure out a better way to do this; perhaps a queue
  // of DrainableIOBuffers. This'll also allow us to not have an
  // artificial buffer size limit.
  const size_t bytes_left = write_end_ - bytes_written;
  char* const buffer = write_buf_->data();
  std::memmove(buffer, buffer + bytes_written, bytes_left);
  write_end_ = bytes_left;

  if (bytes_left == 0U)
    return;
  PostDoWrite();
}
// Closes the socket from any state via DoClose(). SignalClosed() is
// raised only if the socket was not already closed. Always succeeds.
//
// * -> STATE_CLOSED
bool ChromeAsyncSocket::Close() {
  DoClose();
  const bool closed = true;
  return closed;
}
// Tears down the connection. It does the following, in order:
// - invalidates all weak pointers, so pending callbacks and posted tasks
//   bound to them never run;
// - disconnects and destroys the transport socket (if any);
// - discards buffered read/write data and idles both pipelines;
// - if not already closed, enters STATE_CLOSED and raises SignalClosed();
// - clears the recorded error.
// Safe to call when already closed (no signal is raised then).
//
// (not STATE_CLOSED) -> STATE_CLOSED
void ChromeAsyncSocket::DoClose() {
  weak_ptr_factory_.InvalidateWeakPtrs();
  if (transport_socket_.get()) {
    transport_socket_->Disconnect();
  }
  transport_socket_.reset();
  read_state_ = IDLE;
  read_start_ = 0U;
  read_end_ = 0U;
  write_state_ = IDLE;
  write_end_ = 0U;
  if (state_ != STATE_CLOSED) {
    state_ = STATE_CLOSED;
    SignalClosed();
  }
  // Reset error variables after SignalClosed() so slots connected
  // to it can read it.
  error_ = ERROR_NONE;
  net_error_ = net::OK;
}
// Upgrades an open plain connection to TLS.
//
// Returns false (recording ERROR_WRONGSTATE) unless the socket is in
// STATE_OPEN with no read pending and no write posted or pending. Any
// received-but-unconsumed read data is discarded.
//
// The current transport socket is moved into a net::ClientSocketHandle
// and passed to resolving_client_socket_factory_ to create an SSL client
// socket for |domain_name|; that socket becomes the new transport. The
// handshake result always reaches ProcessSSLConnectDone() asynchronously,
// either via the socket callback or via a posted task.
//
// STATE_OPEN -> STATE_TLS_CONNECTING
bool ChromeAsyncSocket::StartTls(const std::string& domain_name) {
  if ((state_ != STATE_OPEN) || (read_state_ == PENDING) ||
      (write_state_ != IDLE)) {
    LOG(DFATAL) << "StartTls() called in wrong state";
    DoNonNetError(ERROR_WRONGSTATE);
    return false;
  }
  state_ = STATE_TLS_CONNECTING;
  read_state_ = IDLE;
  read_start_ = 0U;
  read_end_ = 0U;
  DCHECK_EQ(write_end_, 0U);
  // Clear out any posted DoRead() tasks.
  weak_ptr_factory_.InvalidateWeakPtrs();
  DCHECK(transport_socket_.get());
  scoped_ptr<net::ClientSocketHandle> socket_handle(
      new net::ClientSocketHandle());
  socket_handle->SetSocket(transport_socket_.Pass());
  // NOTE(review): port 443 is hard-coded into the SSL host/port pair
  // regardless of the port actually connected to; presumably it only
  // identifies the TLS peer for the factory — confirm.
  transport_socket_ =
      resolving_client_socket_factory_->CreateSSLClientSocket(
          socket_handle.Pass(), net::HostPortPair(domain_name, 443));
  int status = transport_socket_->Connect(
      base::Bind(&ChromeAsyncSocket::ProcessSSLConnectDone,
                 weak_ptr_factory_.GetWeakPtr()));
  if (status != net::ERR_IO_PENDING) {
    // Deferred for the same reason as in Connect(): callers never see
    // the handshake complete inside StartTls().
    base::MessageLoop* message_loop = base::MessageLoop::current();
    CHECK(message_loop);
    message_loop->PostTask(
        FROM_HERE,
        base::Bind(&ChromeAsyncSocket::ProcessSSLConnectDone,
                   weak_ptr_factory_.GetWeakPtr(), status));
  }
  return true;
}
// Completes the TLS handshake started by StartTls().
//
// On failure, records |status| as a net error and closes the socket. On
// success, enters STATE_TLS_OPEN and schedules the next read. It also
// starts flushing any data that Write() queued during the handshake, then
// raises SignalSSLConnected().
//
// STATE_TLS_CONNECTING -> STATE_TLS_OPEN
// read_state_ == IDLE -> read_state_ == POSTED (via PostDoRead())
// (maybe) write_state_ == IDLE -> write_state_ == POSTED (via
// PostDoWrite())
void ChromeAsyncSocket::ProcessSSLConnectDone(int status) {
  DCHECK_NE(status, net::ERR_IO_PENDING);
  DCHECK_EQ(state_, STATE_TLS_CONNECTING);
  DCHECK_EQ(read_state_, IDLE);
  DCHECK_EQ(read_start_, 0U);
  DCHECK_EQ(read_end_, 0U);
  DCHECK_EQ(write_state_, IDLE);
  if (status != net::OK) {
    DoNetErrorFromStatus(status);
    DoClose();
    return;
  }
  state_ = STATE_TLS_OPEN;
  PostDoRead();
  // Flush data buffered by Write() while the handshake was in progress.
  if (write_end_ > 0U) {
    PostDoWrite();
  }
  // Signal last, after all state updates, since connected slots may call
  // back into this socket.
  SignalSSLConnected();
}
} // namespace jingle_glue