// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "content/browser/renderer_host/websocket_host.h" #include "base/basictypes.h" #include "base/strings/string_util.h" #include "content/browser/renderer_host/websocket_dispatcher_host.h" #include "content/common/websocket_messages.h" #include "ipc/ipc_message_macros.h" #include "net/websockets/websocket_channel.h" #include "net/websockets/websocket_event_interface.h" #include "net/websockets/websocket_frame.h" // for WebSocketFrameHeader::OpCode namespace content { namespace { typedef net::WebSocketEventInterface::ChannelState ChannelState; // Convert a content::WebSocketMessageType to a // net::WebSocketFrameHeader::OpCode net::WebSocketFrameHeader::OpCode MessageTypeToOpCode( WebSocketMessageType type) { DCHECK(type == WEB_SOCKET_MESSAGE_TYPE_CONTINUATION || type == WEB_SOCKET_MESSAGE_TYPE_TEXT || type == WEB_SOCKET_MESSAGE_TYPE_BINARY); typedef net::WebSocketFrameHeader::OpCode OpCode; // These compile asserts verify that the same underlying values are used for // both types, so we can simply cast between them. COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_CONTINUATION) == net::WebSocketFrameHeader::kOpCodeContinuation, enum_values_must_match_for_opcode_continuation); COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_TEXT) == net::WebSocketFrameHeader::kOpCodeText, enum_values_must_match_for_opcode_text); COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_BINARY) == net::WebSocketFrameHeader::kOpCodeBinary, enum_values_must_match_for_opcode_binary); return static_cast<OpCode>(type); } WebSocketMessageType OpCodeToMessageType( net::WebSocketFrameHeader::OpCode opCode) { DCHECK(opCode == net::WebSocketFrameHeader::kOpCodeContinuation || opCode == net::WebSocketFrameHeader::kOpCodeText || opCode == net::WebSocketFrameHeader::kOpCodeBinary); // This cast is guaranteed valid by the COMPILE_ASSERT() statements above. return static_cast<WebSocketMessageType>(opCode); } ChannelState StateCast(WebSocketDispatcherHost::WebSocketHostState host_state) { const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_ALIVE = WebSocketDispatcherHost::WEBSOCKET_HOST_ALIVE; const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_DELETED = WebSocketDispatcherHost::WEBSOCKET_HOST_DELETED; DCHECK(host_state == WEBSOCKET_HOST_ALIVE || host_state == WEBSOCKET_HOST_DELETED); // These compile asserts verify that we can get away with using static_cast<> // for the conversion. COMPILE_ASSERT(static_cast<ChannelState>(WEBSOCKET_HOST_ALIVE) == net::WebSocketEventInterface::CHANNEL_ALIVE, enum_values_must_match_for_state_alive); COMPILE_ASSERT(static_cast<ChannelState>(WEBSOCKET_HOST_DELETED) == net::WebSocketEventInterface::CHANNEL_DELETED, enum_values_must_match_for_state_deleted); return static_cast<ChannelState>(host_state); } // Implementation of net::WebSocketEventInterface. Receives events from our // WebSocketChannel object. Each event is translated to an IPC and sent to the // renderer or child process via WebSocketDispatcherHost. class WebSocketEventHandler : public net::WebSocketEventInterface { public: WebSocketEventHandler(WebSocketDispatcherHost* dispatcher, int routing_id); virtual ~WebSocketEventHandler(); // net::WebSocketEventInterface implementation // TODO(ricea): Add |extensions| parameter to pass the list of enabled // WebSocket extensions through to the renderer to make it visible to // Javascript. virtual ChannelState OnAddChannelResponse( bool fail, const std::string& selected_subprotocol) OVERRIDE; virtual ChannelState OnDataFrame(bool fin, WebSocketMessageType type, const std::vector<char>& data) OVERRIDE; virtual ChannelState OnClosingHandshake() OVERRIDE; virtual ChannelState OnFlowControl(int64 quota) OVERRIDE; virtual ChannelState OnDropChannel(uint16 code, const std::string& reason) OVERRIDE; private: WebSocketDispatcherHost* const dispatcher_; const int routing_id_; DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler); }; WebSocketEventHandler::WebSocketEventHandler( WebSocketDispatcherHost* dispatcher, int routing_id) : dispatcher_(dispatcher), routing_id_(routing_id) {} WebSocketEventHandler::~WebSocketEventHandler() { DVLOG(1) << "WebSocketEventHandler destroyed routing_id=" << routing_id_; } ChannelState WebSocketEventHandler::OnAddChannelResponse( bool fail, const std::string& selected_protocol) { DVLOG(3) << "WebSocketEventHandler::OnAddChannelResponse" << " routing_id=" << routing_id_ << " fail=" << fail << " selected_protocol=\"" << selected_protocol << "\""; return StateCast(dispatcher_->SendAddChannelResponse( routing_id_, fail, selected_protocol, std::string())); } ChannelState WebSocketEventHandler::OnDataFrame( bool fin, net::WebSocketFrameHeader::OpCode type, const std::vector<char>& data) { DVLOG(3) << "WebSocketEventHandler::OnDataFrame" << " routing_id=" << routing_id_ << " fin=" << fin << " type=" << type << " data is " << data.size() << " bytes"; return StateCast(dispatcher_->SendFrame( routing_id_, fin, OpCodeToMessageType(type), data)); } ChannelState WebSocketEventHandler::OnClosingHandshake() { DVLOG(3) << "WebSocketEventHandler::OnClosingHandshake" << " routing_id=" << routing_id_; return StateCast(dispatcher_->SendClosing(routing_id_)); } ChannelState WebSocketEventHandler::OnFlowControl(int64 quota) { DVLOG(3) << "WebSocketEventHandler::OnFlowControl" << " routing_id=" << routing_id_ << " quota=" << quota; return StateCast(dispatcher_->SendFlowControl(routing_id_, quota)); } ChannelState WebSocketEventHandler::OnDropChannel(uint16 code, const std::string& reason) { DVLOG(3) << "WebSocketEventHandler::OnDropChannel" << " routing_id=" << routing_id_ << " code=" << code << " reason=\"" << reason << "\""; return StateCast(dispatcher_->DoDropChannel(routing_id_, code, reason)); } } // namespace WebSocketHost::WebSocketHost(int routing_id, WebSocketDispatcherHost* dispatcher, net::URLRequestContext* url_request_context) : routing_id_(routing_id) { DVLOG(1) << "WebSocketHost: created routing_id=" << routing_id; scoped_ptr<net::WebSocketEventInterface> event_interface( new WebSocketEventHandler(dispatcher, routing_id)); channel_.reset( new net::WebSocketChannel(event_interface.Pass(), url_request_context)); } WebSocketHost::~WebSocketHost() {} bool WebSocketHost::OnMessageReceived(const IPC::Message& message, bool* message_was_ok) { bool handled = true; IPC_BEGIN_MESSAGE_MAP_EX(WebSocketHost, message, *message_was_ok) IPC_MESSAGE_HANDLER(WebSocketHostMsg_AddChannelRequest, OnAddChannelRequest) IPC_MESSAGE_HANDLER(WebSocketMsg_SendFrame, OnSendFrame) IPC_MESSAGE_HANDLER(WebSocketMsg_FlowControl, OnFlowControl) IPC_MESSAGE_HANDLER(WebSocketMsg_DropChannel, OnDropChannel) IPC_MESSAGE_UNHANDLED(handled = false) IPC_END_MESSAGE_MAP_EX() return handled; } void WebSocketHost::OnAddChannelRequest( const GURL& socket_url, const std::vector<std::string>& requested_protocols, const GURL& origin) { DVLOG(3) << "WebSocketHost::OnAddChannelRequest" << " routing_id=" << routing_id_ << " socket_url=\"" << socket_url << "\" requested_protocols=\"" << JoinString(requested_protocols, ", ") << "\" origin=\"" << origin << "\""; channel_->SendAddChannelRequest(socket_url, requested_protocols, origin); } void WebSocketHost::OnSendFrame(bool fin, WebSocketMessageType type, const std::vector<char>& data) { DVLOG(3) << "WebSocketHost::OnSendFrame" << " routing_id=" << routing_id_ << " fin=" << fin << " type=" << type << " data is " << data.size() << " bytes"; channel_->SendFrame(fin, MessageTypeToOpCode(type), data); } void WebSocketHost::OnFlowControl(int64 quota) { DVLOG(3) << "WebSocketHost::OnFlowControl" << " routing_id=" << routing_id_ << " quota=" << quota; channel_->SendFlowControl(quota); } void WebSocketHost::OnDropChannel(bool was_clean, uint16 code, const std::string& reason) { DVLOG(3) << "WebSocketHost::OnDropChannel" << " routing_id=" << routing_id_ << " was_clean=" << was_clean << " code=" << code << " reason=\"" << reason << "\""; // TODO(yhirano): Handle |was_clean| appropriately. channel_->StartClosingHandshake(code, reason); } } // namespace content