// Copyright (c) 2012 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "jingle/glue/chrome_async_socket.h" #include <algorithm> #include <cstdlib> #include <cstring> #include "base/basictypes.h" #include "base/bind.h" #include "base/compiler_specific.h" #include "base/logging.h" #include "base/message_loop/message_loop.h" #include "jingle/glue/resolving_client_socket_factory.h" #include "net/base/address_list.h" #include "net/base/host_port_pair.h" #include "net/base/io_buffer.h" #include "net/base/net_util.h" #include "net/socket/client_socket_handle.h" #include "net/socket/ssl_client_socket.h" #include "net/socket/tcp_client_socket.h" #include "net/ssl/ssl_config_service.h" #include "third_party/libjingle/source/talk/base/socketaddress.h" namespace jingle_glue { ChromeAsyncSocket::ChromeAsyncSocket( ResolvingClientSocketFactory* resolving_client_socket_factory, size_t read_buf_size, size_t write_buf_size) : resolving_client_socket_factory_(resolving_client_socket_factory), state_(STATE_CLOSED), error_(ERROR_NONE), net_error_(net::OK), read_state_(IDLE), read_buf_(new net::IOBufferWithSize(read_buf_size)), read_start_(0U), read_end_(0U), write_state_(IDLE), write_buf_(new net::IOBufferWithSize(write_buf_size)), write_end_(0U), weak_ptr_factory_(this) { DCHECK(resolving_client_socket_factory_.get()); DCHECK_GT(read_buf_size, 0U); DCHECK_GT(write_buf_size, 0U); } ChromeAsyncSocket::~ChromeAsyncSocket() {} ChromeAsyncSocket::State ChromeAsyncSocket::state() { return state_; } ChromeAsyncSocket::Error ChromeAsyncSocket::error() { return error_; } int ChromeAsyncSocket::GetError() { return net_error_; } bool ChromeAsyncSocket::IsOpen() const { return (state_ == STATE_OPEN) || (state_ == STATE_TLS_OPEN); } void ChromeAsyncSocket::DoNonNetError(Error error) { DCHECK_NE(error, ERROR_NONE); DCHECK_NE(error, ERROR_WINSOCK); error_ = error; net_error_ = net::OK; } void ChromeAsyncSocket::DoNetError(net::Error net_error) { error_ = ERROR_WINSOCK; net_error_ = net_error; } void ChromeAsyncSocket::DoNetErrorFromStatus(int status) { DCHECK_LT(status, net::OK); DoNetError(static_cast<net::Error>(status)); } // STATE_CLOSED -> STATE_CONNECTING bool ChromeAsyncSocket::Connect(const talk_base::SocketAddress& address) { if (state_ != STATE_CLOSED) { LOG(DFATAL) << "Connect() called on non-closed socket"; DoNonNetError(ERROR_WRONGSTATE); return false; } if (address.hostname().empty() || address.port() == 0) { DoNonNetError(ERROR_DNS); return false; } DCHECK_EQ(state_, buzz::AsyncSocket::STATE_CLOSED); DCHECK_EQ(read_state_, IDLE); DCHECK_EQ(write_state_, IDLE); state_ = STATE_CONNECTING; DCHECK(!weak_ptr_factory_.HasWeakPtrs()); weak_ptr_factory_.InvalidateWeakPtrs(); net::HostPortPair dest_host_port_pair(address.hostname(), address.port()); transport_socket_ = resolving_client_socket_factory_->CreateTransportClientSocket( dest_host_port_pair); int status = transport_socket_->Connect( base::Bind(&ChromeAsyncSocket::ProcessConnectDone, weak_ptr_factory_.GetWeakPtr())); if (status != net::ERR_IO_PENDING) { // We defer execution of ProcessConnectDone instead of calling it // directly here as the caller may not expect an error/close to // happen here. This is okay, as from the caller's point of view, // the connect always happens asynchronously. base::MessageLoop* message_loop = base::MessageLoop::current(); CHECK(message_loop); message_loop->PostTask( FROM_HERE, base::Bind(&ChromeAsyncSocket::ProcessConnectDone, weak_ptr_factory_.GetWeakPtr(), status)); } return true; } // STATE_CONNECTING -> STATE_OPEN // read_state_ == IDLE -> read_state_ == POSTED (via PostDoRead()) void ChromeAsyncSocket::ProcessConnectDone(int status) { DCHECK_NE(status, net::ERR_IO_PENDING); DCHECK_EQ(read_state_, IDLE); DCHECK_EQ(write_state_, IDLE); DCHECK_EQ(state_, STATE_CONNECTING); if (status != net::OK) { DoNetErrorFromStatus(status); DoClose(); return; } state_ = STATE_OPEN; PostDoRead(); // Write buffer should be empty. DCHECK_EQ(write_end_, 0U); SignalConnected(); } // read_state_ == IDLE -> read_state_ == POSTED void ChromeAsyncSocket::PostDoRead() { DCHECK(IsOpen()); DCHECK_EQ(read_state_, IDLE); DCHECK_EQ(read_start_, 0U); DCHECK_EQ(read_end_, 0U); base::MessageLoop* message_loop = base::MessageLoop::current(); CHECK(message_loop); message_loop->PostTask( FROM_HERE, base::Bind(&ChromeAsyncSocket::DoRead, weak_ptr_factory_.GetWeakPtr())); read_state_ = POSTED; } // read_state_ == POSTED -> read_state_ == PENDING void ChromeAsyncSocket::DoRead() { DCHECK(IsOpen()); DCHECK_EQ(read_state_, POSTED); DCHECK_EQ(read_start_, 0U); DCHECK_EQ(read_end_, 0U); // Once we call Read(), we cannot call StartTls() until the read // finishes. This is okay, as StartTls() is called only from a read // handler (i.e., after a read finishes and before another read is // done). int status = transport_socket_->Read( read_buf_.get(), read_buf_->size(), base::Bind(&ChromeAsyncSocket::ProcessReadDone, weak_ptr_factory_.GetWeakPtr())); read_state_ = PENDING; if (status != net::ERR_IO_PENDING) { ProcessReadDone(status); } } // read_state_ == PENDING -> read_state_ == IDLE void ChromeAsyncSocket::ProcessReadDone(int status) { DCHECK_NE(status, net::ERR_IO_PENDING); DCHECK(IsOpen()); DCHECK_EQ(read_state_, PENDING); DCHECK_EQ(read_start_, 0U); DCHECK_EQ(read_end_, 0U); read_state_ = IDLE; if (status > 0) { read_end_ = static_cast<size_t>(status); SignalRead(); } else if (status == 0) { // Other side closed the connection. error_ = ERROR_NONE; net_error_ = net::OK; DoClose(); } else { // status < 0 DoNetErrorFromStatus(status); DoClose(); } } // (maybe) read_state_ == IDLE -> read_state_ == POSTED (via // PostDoRead()) bool ChromeAsyncSocket::Read(char* data, size_t len, size_t* len_read) { if (!IsOpen() && (state_ != STATE_TLS_CONNECTING)) { // Read() may be called on a closed socket if a previous read // causes a socket close (e.g., client sends wrong password and // server terminates connection). // // TODO(akalin): Fix handling of this on the libjingle side. if (state_ != STATE_CLOSED) { LOG(DFATAL) << "Read() called on non-open non-tls-connecting socket"; } DoNonNetError(ERROR_WRONGSTATE); return false; } DCHECK_LE(read_start_, read_end_); if ((state_ == STATE_TLS_CONNECTING) || read_end_ == 0U) { if (state_ == STATE_TLS_CONNECTING) { DCHECK_EQ(read_state_, IDLE); DCHECK_EQ(read_end_, 0U); } else { DCHECK_NE(read_state_, IDLE); } *len_read = 0; return true; } DCHECK_EQ(read_state_, IDLE); *len_read = std::min(len, read_end_ - read_start_); DCHECK_GT(*len_read, 0U); std::memcpy(data, read_buf_->data() + read_start_, *len_read); read_start_ += *len_read; if (read_start_ == read_end_) { read_start_ = 0U; read_end_ = 0U; // We defer execution of DoRead() here for similar reasons as // ProcessConnectDone(). PostDoRead(); } return true; } // (maybe) write_state_ == IDLE -> write_state_ == POSTED (via // PostDoWrite()) bool ChromeAsyncSocket::Write(const char* data, size_t len) { if (!IsOpen() && (state_ != STATE_TLS_CONNECTING)) { LOG(DFATAL) << "Write() called on non-open non-tls-connecting socket"; DoNonNetError(ERROR_WRONGSTATE); return false; } // TODO(akalin): Avoid this check by modifying the interface to have // a "ready for writing" signal. if ((static_cast<size_t>(write_buf_->size()) - write_end_) < len) { LOG(DFATAL) << "queueing " << len << " bytes would exceed the " << "max write buffer size = " << write_buf_->size() << " by " << (len - write_buf_->size()) << " bytes"; DoNetError(net::ERR_INSUFFICIENT_RESOURCES); return false; } std::memcpy(write_buf_->data() + write_end_, data, len); write_end_ += len; // If we're TLS-connecting, the write buffer will get flushed once // the TLS-connect finishes. Otherwise, start writing if we're not // already writing and we have something to write. if ((state_ != STATE_TLS_CONNECTING) && (write_state_ == IDLE) && (write_end_ > 0U)) { // We defer execution of DoWrite() here for similar reasons as // ProcessConnectDone(). PostDoWrite(); } return true; } // write_state_ == IDLE -> write_state_ == POSTED void ChromeAsyncSocket::PostDoWrite() { DCHECK(IsOpen()); DCHECK_EQ(write_state_, IDLE); DCHECK_GT(write_end_, 0U); base::MessageLoop* message_loop = base::MessageLoop::current(); CHECK(message_loop); message_loop->PostTask( FROM_HERE, base::Bind(&ChromeAsyncSocket::DoWrite, weak_ptr_factory_.GetWeakPtr())); write_state_ = POSTED; } // write_state_ == POSTED -> write_state_ == PENDING void ChromeAsyncSocket::DoWrite() { DCHECK(IsOpen()); DCHECK_EQ(write_state_, POSTED); DCHECK_GT(write_end_, 0U); // Once we call Write(), we cannot call StartTls() until the write // finishes. This is okay, as StartTls() is called only after we // have received a reply to a message we sent to the server and // before we send the next message. int status = transport_socket_->Write( write_buf_.get(), write_end_, base::Bind(&ChromeAsyncSocket::ProcessWriteDone, weak_ptr_factory_.GetWeakPtr())); write_state_ = PENDING; if (status != net::ERR_IO_PENDING) { ProcessWriteDone(status); } } // write_state_ == PENDING -> write_state_ == IDLE or POSTED (the // latter via PostDoWrite()) void ChromeAsyncSocket::ProcessWriteDone(int status) { DCHECK_NE(status, net::ERR_IO_PENDING); DCHECK(IsOpen()); DCHECK_EQ(write_state_, PENDING); DCHECK_GT(write_end_, 0U); write_state_ = IDLE; if (status < net::OK) { DoNetErrorFromStatus(status); DoClose(); return; } size_t written = static_cast<size_t>(status); if (written > write_end_) { LOG(DFATAL) << "bytes written = " << written << " exceeds bytes requested = " << write_end_; DoNetError(net::ERR_UNEXPECTED); DoClose(); return; } // TODO(akalin): Figure out a better way to do this; perhaps a queue // of DrainableIOBuffers. This'll also allow us to not have an // artificial buffer size limit. std::memmove(write_buf_->data(), write_buf_->data() + written, write_end_ - written); write_end_ -= written; if (write_end_ > 0U) { PostDoWrite(); } } // * -> STATE_CLOSED bool ChromeAsyncSocket::Close() { DoClose(); return true; } // (not STATE_CLOSED) -> STATE_CLOSED void ChromeAsyncSocket::DoClose() { weak_ptr_factory_.InvalidateWeakPtrs(); if (transport_socket_.get()) { transport_socket_->Disconnect(); } transport_socket_.reset(); read_state_ = IDLE; read_start_ = 0U; read_end_ = 0U; write_state_ = IDLE; write_end_ = 0U; if (state_ != STATE_CLOSED) { state_ = STATE_CLOSED; SignalClosed(); } // Reset error variables after SignalClosed() so slots connected // to it can read it. error_ = ERROR_NONE; net_error_ = net::OK; } // STATE_OPEN -> STATE_TLS_CONNECTING bool ChromeAsyncSocket::StartTls(const std::string& domain_name) { if ((state_ != STATE_OPEN) || (read_state_ == PENDING) || (write_state_ != IDLE)) { LOG(DFATAL) << "StartTls() called in wrong state"; DoNonNetError(ERROR_WRONGSTATE); return false; } state_ = STATE_TLS_CONNECTING; read_state_ = IDLE; read_start_ = 0U; read_end_ = 0U; DCHECK_EQ(write_end_, 0U); // Clear out any posted DoRead() tasks. weak_ptr_factory_.InvalidateWeakPtrs(); DCHECK(transport_socket_.get()); scoped_ptr<net::ClientSocketHandle> socket_handle( new net::ClientSocketHandle()); socket_handle->SetSocket(transport_socket_.Pass()); transport_socket_ = resolving_client_socket_factory_->CreateSSLClientSocket( socket_handle.Pass(), net::HostPortPair(domain_name, 443)); int status = transport_socket_->Connect( base::Bind(&ChromeAsyncSocket::ProcessSSLConnectDone, weak_ptr_factory_.GetWeakPtr())); if (status != net::ERR_IO_PENDING) { base::MessageLoop* message_loop = base::MessageLoop::current(); CHECK(message_loop); message_loop->PostTask( FROM_HERE, base::Bind(&ChromeAsyncSocket::ProcessSSLConnectDone, weak_ptr_factory_.GetWeakPtr(), status)); } return true; } // STATE_TLS_CONNECTING -> STATE_TLS_OPEN // read_state_ == IDLE -> read_state_ == POSTED (via PostDoRead()) // (maybe) write_state_ == IDLE -> write_state_ == POSTED (via // PostDoWrite()) void ChromeAsyncSocket::ProcessSSLConnectDone(int status) { DCHECK_NE(status, net::ERR_IO_PENDING); DCHECK_EQ(state_, STATE_TLS_CONNECTING); DCHECK_EQ(read_state_, IDLE); DCHECK_EQ(read_start_, 0U); DCHECK_EQ(read_end_, 0U); DCHECK_EQ(write_state_, IDLE); if (status != net::OK) { DoNetErrorFromStatus(status); DoClose(); return; } state_ = STATE_TLS_OPEN; PostDoRead(); if (write_end_ > 0U) { PostDoWrite(); } SignalSSLConnected(); } } // namespace jingle_glue