// Copyright (c) 2012 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "remoting/protocol/channel_multiplexer.h" #include <string.h> #include "base/bind.h" #include "base/callback.h" #include "base/location.h" #include "base/single_thread_task_runner.h" #include "base/stl_util.h" #include "base/thread_task_runner_handle.h" #include "net/base/net_errors.h" #include "net/socket/stream_socket.h" #include "remoting/protocol/message_serialization.h" namespace remoting { namespace protocol { namespace { const int kChannelIdUnknown = -1; const int kMaxPacketSize = 1024; class PendingPacket { public: PendingPacket(scoped_ptr<MultiplexPacket> packet, const base::Closure& done_task) : packet(packet.Pass()), done_task(done_task), pos(0U) { } ~PendingPacket() { done_task.Run(); } bool is_empty() { return pos >= packet->data().size(); } int Read(char* buffer, size_t size) { size = std::min(size, packet->data().size() - pos); memcpy(buffer, packet->data().data() + pos, size); pos += size; return size; } private: scoped_ptr<MultiplexPacket> packet; base::Closure done_task; size_t pos; DISALLOW_COPY_AND_ASSIGN(PendingPacket); }; } // namespace const char ChannelMultiplexer::kMuxChannelName[] = "mux"; struct ChannelMultiplexer::PendingChannel { PendingChannel(const std::string& name, const ChannelCreatedCallback& callback) : name(name), callback(callback) { } std::string name; ChannelCreatedCallback callback; }; class ChannelMultiplexer::MuxChannel { public: MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name, int send_id); ~MuxChannel(); const std::string& name() { return name_; } int receive_id() { return receive_id_; } void set_receive_id(int id) { receive_id_ = id; } // Called by ChannelMultiplexer. scoped_ptr<net::StreamSocket> CreateSocket(); void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, const base::Closure& done_task); void OnWriteFailed(); // Called by MuxSocket. void OnSocketDestroyed(); bool DoWrite(scoped_ptr<MultiplexPacket> packet, const base::Closure& done_task); int DoRead(net::IOBuffer* buffer, int buffer_len); private: ChannelMultiplexer* multiplexer_; std::string name_; int send_id_; bool id_sent_; int receive_id_; MuxSocket* socket_; std::list<PendingPacket*> pending_packets_; DISALLOW_COPY_AND_ASSIGN(MuxChannel); }; class ChannelMultiplexer::MuxSocket : public net::StreamSocket, public base::NonThreadSafe, public base::SupportsWeakPtr<MuxSocket> { public: MuxSocket(MuxChannel* channel); virtual ~MuxSocket(); void OnWriteComplete(); void OnWriteFailed(); void OnPacketReceived(); // net::StreamSocket interface. virtual int Read(net::IOBuffer* buffer, int buffer_len, const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buffer, int buffer_len, const net::CompletionCallback& callback) OVERRIDE; virtual int SetReceiveBufferSize(int32 size) OVERRIDE { NOTIMPLEMENTED(); return net::ERR_NOT_IMPLEMENTED; } virtual int SetSendBufferSize(int32 size) OVERRIDE { NOTIMPLEMENTED(); return net::ERR_NOT_IMPLEMENTED; } virtual int Connect(const net::CompletionCallback& callback) OVERRIDE { NOTIMPLEMENTED(); return net::ERR_NOT_IMPLEMENTED; } virtual void Disconnect() OVERRIDE { NOTIMPLEMENTED(); } virtual bool IsConnected() const OVERRIDE { NOTIMPLEMENTED(); return true; } virtual bool IsConnectedAndIdle() const OVERRIDE { NOTIMPLEMENTED(); return false; } virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE { NOTIMPLEMENTED(); return net::ERR_NOT_IMPLEMENTED; } virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE { NOTIMPLEMENTED(); return net::ERR_NOT_IMPLEMENTED; } virtual const net::BoundNetLog& NetLog() const OVERRIDE { NOTIMPLEMENTED(); return net_log_; } virtual void SetSubresourceSpeculation() OVERRIDE { NOTIMPLEMENTED(); } virtual void SetOmniboxSpeculation() OVERRIDE { NOTIMPLEMENTED(); } virtual bool WasEverUsed() const OVERRIDE { return true; } virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } virtual bool WasNpnNegotiated() const OVERRIDE { return false; } virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE { return net::kProtoUnknown; } virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE { NOTIMPLEMENTED(); return false; } private: MuxChannel* channel_; net::CompletionCallback read_callback_; scoped_refptr<net::IOBuffer> read_buffer_; int read_buffer_size_; bool write_pending_; int write_result_; net::CompletionCallback write_callback_; net::BoundNetLog net_log_; DISALLOW_COPY_AND_ASSIGN(MuxSocket); }; ChannelMultiplexer::MuxChannel::MuxChannel( ChannelMultiplexer* multiplexer, const std::string& name, int send_id) : multiplexer_(multiplexer), name_(name), send_id_(send_id), id_sent_(false), receive_id_(kChannelIdUnknown), socket_(NULL) { } ChannelMultiplexer::MuxChannel::~MuxChannel() { // Socket must be destroyed before the channel. DCHECK(!socket_); STLDeleteElements(&pending_packets_); } scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() { DCHECK(!socket_); // Can't create more than one socket per channel. scoped_ptr<MuxSocket> result(new MuxSocket(this)); socket_ = result.get(); return result.PassAs<net::StreamSocket>(); } void ChannelMultiplexer::MuxChannel::OnIncomingPacket( scoped_ptr<MultiplexPacket> packet, const base::Closure& done_task) { DCHECK_EQ(packet->channel_id(), receive_id_); if (packet->data().size() > 0) { pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task)); if (socket_) { // Notify the socket that we have more data. socket_->OnPacketReceived(); } } } void ChannelMultiplexer::MuxChannel::OnWriteFailed() { if (socket_) socket_->OnWriteFailed(); } void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() { DCHECK(socket_); socket_ = NULL; } bool ChannelMultiplexer::MuxChannel::DoWrite( scoped_ptr<MultiplexPacket> packet, const base::Closure& done_task) { packet->set_channel_id(send_id_); if (!id_sent_) { packet->set_channel_name(name_); id_sent_ = true; } return multiplexer_->DoWrite(packet.Pass(), done_task); } int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer, int buffer_len) { int pos = 0; while (buffer_len > 0 && !pending_packets_.empty()) { DCHECK(!pending_packets_.front()->is_empty()); int result = pending_packets_.front()->Read( buffer->data() + pos, buffer_len); DCHECK_LE(result, buffer_len); pos += result; buffer_len -= pos; if (pending_packets_.front()->is_empty()) { delete pending_packets_.front(); pending_packets_.erase(pending_packets_.begin()); } } return pos; } ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel) : channel_(channel), read_buffer_size_(0), write_pending_(false), write_result_(0) { } ChannelMultiplexer::MuxSocket::~MuxSocket() { channel_->OnSocketDestroyed(); } int ChannelMultiplexer::MuxSocket::Read( net::IOBuffer* buffer, int buffer_len, const net::CompletionCallback& callback) { DCHECK(CalledOnValidThread()); DCHECK(read_callback_.is_null()); int result = channel_->DoRead(buffer, buffer_len); if (result == 0) { read_buffer_ = buffer; read_buffer_size_ = buffer_len; read_callback_ = callback; return net::ERR_IO_PENDING; } return result; } int ChannelMultiplexer::MuxSocket::Write( net::IOBuffer* buffer, int buffer_len, const net::CompletionCallback& callback) { DCHECK(CalledOnValidThread()); scoped_ptr<MultiplexPacket> packet(new MultiplexPacket()); size_t size = std::min(kMaxPacketSize, buffer_len); packet->mutable_data()->assign(buffer->data(), size); write_pending_ = true; bool result = channel_->DoWrite(packet.Pass(), base::Bind( &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr())); if (!result) { // Cannot complete the write, e.g. if the connection has been terminated. return net::ERR_FAILED; } // OnWriteComplete() might be called above synchronously. if (write_pending_) { DCHECK(write_callback_.is_null()); write_callback_ = callback; write_result_ = size; return net::ERR_IO_PENDING; } return size; } void ChannelMultiplexer::MuxSocket::OnWriteComplete() { write_pending_ = false; if (!write_callback_.is_null()) { net::CompletionCallback cb; std::swap(cb, write_callback_); cb.Run(write_result_); } } void ChannelMultiplexer::MuxSocket::OnWriteFailed() { if (!write_callback_.is_null()) { net::CompletionCallback cb; std::swap(cb, write_callback_); cb.Run(net::ERR_FAILED); } } void ChannelMultiplexer::MuxSocket::OnPacketReceived() { if (!read_callback_.is_null()) { int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_); read_buffer_ = NULL; DCHECK_GT(result, 0); net::CompletionCallback cb; std::swap(cb, read_callback_); cb.Run(result); } } ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory, const std::string& base_channel_name) : base_channel_factory_(factory), base_channel_name_(base_channel_name), next_channel_id_(0), weak_factory_(this) { } ChannelMultiplexer::~ChannelMultiplexer() { DCHECK(pending_channels_.empty()); STLDeleteValues(&channels_); // Cancel creation of the base channel if it hasn't finished. if (base_channel_factory_) base_channel_factory_->CancelChannelCreation(base_channel_name_); } void ChannelMultiplexer::CreateChannel(const std::string& name, const ChannelCreatedCallback& callback) { if (base_channel_.get()) { // Already have |base_channel_|. Create new multiplexed channel // synchronously. callback.Run(GetOrCreateChannel(name)->CreateSocket()); } else if (!base_channel_.get() && !base_channel_factory_) { // Fail synchronously if we failed to create |base_channel_|. callback.Run(scoped_ptr<net::StreamSocket>()); } else { // Still waiting for the |base_channel_|. pending_channels_.push_back(PendingChannel(name, callback)); // If this is the first multiplexed channel then create the base channel. if (pending_channels_.size() == 1U) { base_channel_factory_->CreateChannel( base_channel_name_, base::Bind(&ChannelMultiplexer::OnBaseChannelReady, base::Unretained(this))); } } } void ChannelMultiplexer::CancelChannelCreation(const std::string& name) { for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); it != pending_channels_.end(); ++it) { if (it->name == name) { pending_channels_.erase(it); return; } } } void ChannelMultiplexer::OnBaseChannelReady( scoped_ptr<net::StreamSocket> socket) { base_channel_factory_ = NULL; base_channel_ = socket.Pass(); if (base_channel_.get()) { // Initialize reader and writer. reader_.Init(base_channel_.get(), base::Bind(&ChannelMultiplexer::OnIncomingPacket, base::Unretained(this))); writer_.Init(base_channel_.get(), base::Bind(&ChannelMultiplexer::OnWriteFailed, base::Unretained(this))); } DoCreatePendingChannels(); } void ChannelMultiplexer::DoCreatePendingChannels() { if (pending_channels_.empty()) return; // Every time this function is called it connects a single channel and posts a // separate task to connect other channels. This is necessary because the // callback may destroy the multiplexer or somehow else modify // |pending_channels_| list (e.g. call CancelChannelCreation()). base::ThreadTaskRunnerHandle::Get()->PostTask( FROM_HERE, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels, weak_factory_.GetWeakPtr())); PendingChannel c = pending_channels_.front(); pending_channels_.erase(pending_channels_.begin()); scoped_ptr<net::StreamSocket> socket; if (base_channel_.get()) socket = GetOrCreateChannel(c.name)->CreateSocket(); c.callback.Run(socket.Pass()); } ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel( const std::string& name) { // Check if we already have a channel with the requested name. std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); if (it != channels_.end()) return it->second; // Create a new channel if we haven't found existing one. MuxChannel* channel = new MuxChannel(this, name, next_channel_id_); ++next_channel_id_; channels_[channel->name()] = channel; return channel; } void ChannelMultiplexer::OnWriteFailed(int error) { for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); it != channels_.end(); ++it) { base::ThreadTaskRunnerHandle::Get()->PostTask( FROM_HERE, base::Bind(&ChannelMultiplexer::NotifyWriteFailed, weak_factory_.GetWeakPtr(), it->second->name())); } } void ChannelMultiplexer::NotifyWriteFailed(const std::string& name) { std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); if (it != channels_.end()) { it->second->OnWriteFailed(); } } void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, const base::Closure& done_task) { DCHECK(packet->has_channel_id()); if (!packet->has_channel_id()) { LOG(ERROR) << "Received packet without channel_id."; done_task.Run(); return; } int receive_id = packet->channel_id(); MuxChannel* channel = NULL; std::map<int, MuxChannel*>::iterator it = channels_by_receive_id_.find(receive_id); if (it != channels_by_receive_id_.end()) { channel = it->second; } else { // This is a new |channel_id| we haven't seen before. Look it up by name. if (!packet->has_channel_name()) { LOG(ERROR) << "Received packet with unknown channel_id and " "without channel_name."; done_task.Run(); return; } channel = GetOrCreateChannel(packet->channel_name()); channel->set_receive_id(receive_id); channels_by_receive_id_[receive_id] = channel; } channel->OnIncomingPacket(packet.Pass(), done_task); } bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet, const base::Closure& done_task) { return writer_.Write(SerializeAndFrameMessage(*packet), done_task); } } // namespace protocol } // namespace remoting