// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/websockets/websocket_job.h"
#include <algorithm>
#include "base/bind.h"
#include "base/lazy_instance.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/base/net_log.h"
#include "net/cookies/cookie_store.h"
#include "net/http/http_network_session.h"
#include "net/http/http_transaction_factory.h"
#include "net/http/http_util.h"
#include "net/spdy/spdy_session.h"
#include "net/spdy/spdy_session_pool.h"
#include "net/url_request/url_request_context.h"
#include "net/websockets/websocket_handshake_handler.h"
#include "net/websockets/websocket_net_log_params.h"
#include "net/websockets/websocket_throttle.h"
#include "url/gurl.h"
// Send allowance passed to OnConnected() when the connection is carried over
// a SPDY stream (see TrySpdyStream()). 32 * 1024 == 32768 bytes.
static const int kMaxPendingSendAllowed = 32 * 1024;
namespace {
// lower-case header names.
const char* const kCookieHeaders[] = {
"cookie", "cookie2"
};
const char* const kSetCookieHeaders[] = {
"set-cookie", "set-cookie2"
};
net::SocketStreamJob* WebSocketJobFactory(
const GURL& url, net::SocketStream::Delegate* delegate,
net::URLRequestContext* context, net::CookieStore* cookie_store) {
net::WebSocketJob* job = new net::WebSocketJob(delegate);
job->InitSocketStream(new net::SocketStream(url, job, context, cookie_store));
return job;
}
class WebSocketJobInitSingleton {
private:
friend struct base::DefaultLazyInstanceTraits<WebSocketJobInitSingleton>;
WebSocketJobInitSingleton() {
net::SocketStreamJob::RegisterProtocolFactory("ws", WebSocketJobFactory);
net::SocketStreamJob::RegisterProtocolFactory("wss", WebSocketJobFactory);
}
};
static base::LazyInstance<WebSocketJobInitSingleton> g_websocket_job_init =
LAZY_INSTANCE_INITIALIZER;
} // anonymous namespace
namespace net {
// static
// Makes sure the "ws"/"wss" factories are registered by forcing construction
// of the lazily created init singleton (LazyInstance builds it only once).
void WebSocketJob::EnsureInit() {
  static_cast<void>(g_websocket_job_init.Get());
}
// Creates a job in the INITIALIZED state that reports to |delegate|, with fresh
// handshake request/response handlers and all counters and flags cleared.
// The socket stream is attached separately (WebSocketJobFactory() calls
// InitSocketStream() right after construction).
WebSocketJob::WebSocketJob(SocketStream::Delegate* delegate)
: delegate_(delegate),
state_(INITIALIZED),
waiting_(false),
handshake_request_(new WebSocketHandshakeRequestHandler),
handshake_response_(new WebSocketHandshakeResponseHandler),
started_to_send_handshake_request_(false),
handshake_request_sent_(0),
response_cookies_save_index_(0),
spdy_protocol_version_(0),
save_next_cookie_running_(false),
callback_pending_(false),
weak_ptr_factory_(this),
weak_ptr_factory_for_send_pending_(this) {
}
// A job may only be destroyed after it is CLOSED and has let go of both its
// delegate and its socket stream (DetachDelegate() and OnClose() do both).
WebSocketJob::~WebSocketJob() {
  DCHECK_EQ(CLOSED, state_);
  DCHECK(delegate_ == NULL);
  DCHECK(socket_.get() == NULL);
}
void WebSocketJob::Connect() {
DCHECK(socket_.get());
DCHECK_EQ(state_, INITIALIZED);
state_ = CONNECTING;
socket_->Connect();
}
bool WebSocketJob::SendData(const char* data, int len) {
switch (state_) {
case INITIALIZED:
return false;
case CONNECTING:
return SendHandshakeRequest(data, len);
case OPEN:
{
scoped_refptr<IOBufferWithSize> buffer = new IOBufferWithSize(len);
memcpy(buffer->data(), data, len);
if (current_send_buffer_.get() || !send_buffer_queue_.empty()) {
send_buffer_queue_.push_back(buffer);
return true;
}
current_send_buffer_ = new DrainableIOBuffer(buffer.get(), len);
return SendDataInternal(current_send_buffer_->data(),
current_send_buffer_->BytesRemaining());
}
case CLOSING:
case CLOSED:
return false;
}
return false;
}
void WebSocketJob::Close() {
if (state_ == CLOSED)
return;
state_ = CLOSING;
if (current_send_buffer_.get()) {
// Will close in SendPending.
return;
}
state_ = CLOSED;
CloseInternal();
}
void WebSocketJob::RestartWithAuth(const AuthCredentials& credentials) {
state_ = CONNECTING;
socket_->RestartWithAuth(credentials);
}
// Severs this job from its delegate and socket stream. The job is marked
// CLOSED and removed from the throttle queue. Callbacks bound through either
// weak-pointer factory are cancelled (including a posted SendPending() task).
// The socket stream's delegate link and our reference to it are dropped. If a
// deferred connect callback is still held, it is discarded and the reference
// taken in OnStartOpenConnection() is released.
void WebSocketJob::DetachDelegate() {
state_ = CLOSED;
WebSocketThrottle::GetInstance()->RemoveFromQueue(this);
// Keep |this| alive until the method returns; the Release() below may drop
// what would otherwise be the last reference.
scoped_refptr<WebSocketJob> protect(this);
weak_ptr_factory_.InvalidateWeakPtrs();
weak_ptr_factory_for_send_pending_.InvalidateWeakPtrs();
delegate_ = NULL;
if (socket_.get())
socket_->DetachDelegate();
socket_ = NULL;
if (!callback_.is_null()) {
waiting_ = false;
callback_.Reset();
Release(); // Balanced with OnStartOpenConnection().
}
}
// SocketStream::Delegate hook run before the connection is opened. It records
// the socket's address list, enters the WebSocket throttle queue and notifies
// the delegate.
// Returns:
// - ERR_WS_THROTTLE_QUEUE_TOO_LARGE if the throttle rejects the job;
// - ERR_IO_PENDING if the throttle made the job wait (|callback| is kept and
//   Wakeup() resumes later);
// - otherwise, the result of TrySpdyStream().
int WebSocketJob::OnStartOpenConnection(
    SocketStream* socket, const CompletionCallback& callback) {
  DCHECK(callback_.is_null());
  state_ = CONNECTING;
  addresses_ = socket->address_list();
  if (!WebSocketThrottle::GetInstance()->PutInQueue(this))
    return ERR_WS_THROTTLE_QUEUE_TOO_LARGE;

  if (delegate_) {
    const int delegate_result =
        delegate_->OnStartOpenConnection(socket, callback);
    DCHECK_EQ(OK, delegate_result);
  }

  if (!waiting_)
    return TrySpdyStream();

  // PutInQueue() may set |waiting_| true for throttling. In this case,
  // Wakeup() will be called later, so hold on to |callback| until then.
  callback_ = callback;
  AddRef();  // Balanced when callback_ is cleared.
  return ERR_IO_PENDING;
}
// Forwards the "connected" notification (with the send allowance) to the
// delegate, unless the job has already been closed.
void WebSocketJob::OnConnected(
    SocketStream* socket, int max_pending_send_allowed) {
  if (state_ == CLOSED)
    return;
  DCHECK_EQ(CONNECTING, state_);
  if (!delegate_)
    return;
  delegate_->OnConnected(socket, max_pending_send_allowed);
}
// Called after |amount_sent| bytes have been written.
// - While CONNECTING, the bytes belong to the handshake request.
// - Once OPEN (or CLOSING), they are credited against |current_send_buffer_|.
//   The delegate only hears about a buffer once all of it has been written,
//   and it is then told the buffer's full size.
void WebSocketJob::OnSentData(SocketStream* socket, int amount_sent) {
  DCHECK_NE(INITIALIZED, state_);
  DCHECK_GT(amount_sent, 0);
  if (state_ == CLOSED)
    return;
  if (state_ == CONNECTING) {
    OnSentHandshakeRequest(socket, amount_sent);
    return;
  }
  if (!delegate_)
    return;
  DCHECK(state_ == OPEN || state_ == CLOSING);

  if (current_send_buffer_.get() == NULL) {
    VLOG(1)
        << "OnSentData current_send_buffer=NULL amount_sent=" << amount_sent;
    return;
  }
  current_send_buffer_->DidConsume(amount_sent);
  if (current_send_buffer_->BytesRemaining() > 0)
    return;  // Part of this buffer is still outstanding.

  // Report the size of the caller's original buffer, not the amount that
  // happened to be written to |socket| this time.
  const int buffer_size = current_send_buffer_->size();
  DCHECK_GT(buffer_size, 0);
  current_send_buffer_ = NULL;

  // Start the next queued buffer asynchronously. This factory only hands out
  // weak pointers for SendPending() tasks, so existing weak pointers mean one
  // is already posted.
  if (!weak_ptr_factory_for_send_pending_.HasWeakPtrs()) {
    base::MessageLoopForIO::current()->PostTask(
        FROM_HERE,
        base::Bind(&WebSocketJob::SendPending,
                   weak_ptr_factory_for_send_pending_.GetWeakPtr()));
  }
  delegate_->OnSentData(socket, buffer_size);
}
// Dispatches bytes read from the connection. During the handshake they go to
// the response parser. Once OPEN (or CLOSING), non-empty chunks are passed to
// the delegate unchanged. Anything arriving after CLOSED is dropped.
void WebSocketJob::OnReceivedData(
    SocketStream* socket, const char* data, int len) {
  DCHECK_NE(INITIALIZED, state_);
  if (state_ == CLOSED)
    return;
  if (state_ != CONNECTING) {
    DCHECK(state_ == OPEN || state_ == CLOSING);
    if (len > 0 && delegate_)
      delegate_->OnReceivedData(socket, data, len);
    return;
  }
  OnReceivedHandshakeResponse(socket, data, len);
}
// Called when the connection closes (by the socket stream directly, or via
// OnCloseSpdyStream()). The job is marked CLOSED and leaves the throttle
// queue. Callbacks bound through |weak_ptr_factory_| are cancelled, and the
// delegate and socket are dropped. A pending connect callback and its
// reference are released. The former delegate is notified last, after all of
// this state has been torn down.
// NOTE(review): unlike DetachDelegate(), this does not invalidate
// |weak_ptr_factory_for_send_pending_|, so an already-posted SendPending() may
// still run afterwards. Confirm this is intended.
void WebSocketJob::OnClose(SocketStream* socket) {
state_ = CLOSED;
WebSocketThrottle::GetInstance()->RemoveFromQueue(this);
// Keep |this| alive until the method returns; Release() and the delegate
// callback below may otherwise drop the last reference.
scoped_refptr<WebSocketJob> protect(this);
weak_ptr_factory_.InvalidateWeakPtrs();
// Clear |delegate_| before notifying it, so re-entrant calls see no delegate.
SocketStream::Delegate* delegate = delegate_;
delegate_ = NULL;
socket_ = NULL;
if (!callback_.is_null()) {
waiting_ = false;
callback_.Reset();
Release(); // Balanced with OnStartOpenConnection().
}
if (delegate)
delegate->OnClose(socket);
}
// Relays an authentication challenge to the delegate, if one is attached.
void WebSocketJob::OnAuthRequired(
    SocketStream* socket, AuthChallengeInfo* auth_info) {
  if (!delegate_)
    return;
  delegate_->OnAuthRequired(socket, auth_info);
}
// Relays an SSL certificate error, together with its |fatal| flag, to the
// delegate, if one is attached.
void WebSocketJob::OnSSLCertificateError(
    SocketStream* socket, const SSLInfo& ssl_info, bool fatal) {
  if (!delegate_)
    return;
  delegate_->OnSSLCertificateError(socket, ssl_info, fatal);
}
void WebSocketJob::OnError(const SocketStream* socket, int error) {
if (delegate_ && error != ERR_PROTOCOL_SWITCHED)
delegate_->OnError(socket, error);
}
// Completion of a SPDY stream creation that TrySpdyStream() left pending.
// The pending connect callback is completed with one of:
// - ERR_ABORTED if the job was closed meanwhile;
// - ERR_PROTOCOL_SWITCHED on success;
// - the original error otherwise, in which case the stream is discarded.
void WebSocketJob::OnCreatedSpdyStream(int result) {
  DCHECK(spdy_websocket_stream_.get());
  DCHECK(socket_.get());
  DCHECK_NE(ERR_IO_PENDING, result);

  int completion = result;
  if (state_ == CLOSED) {
    completion = ERR_ABORTED;
  } else if (result != OK) {
    spdy_websocket_stream_.reset();
  } else {
    state_ = CONNECTING;
    completion = ERR_PROTOCOL_SWITCHED;
  }
  CompleteIO(completion);
}
// The SPDY stream has sent the handshake headers. While still CONNECTING, the
// handshake request is dropped and its original length is reported to the
// delegate, mirroring OnSentHandshakeRequest() on the raw-socket path.
void WebSocketJob::OnSentSpdyHeaders() {
  DCHECK_NE(INITIALIZED, state_);
  if (state_ == CONNECTING) {
    const size_t original_length = handshake_request_->original_length();
    handshake_request_.reset();
    if (delegate_)
      delegate_->OnSentData(socket_.get(), original_length);
  }
}
// Handles the handshake response delivered as a SPDY header block. While
// CONNECTING, it is parsed into |handshake_response_| and cookie processing
// starts; in any other state it is ignored.
void WebSocketJob::OnSpdyResponseHeadersUpdated(
    const SpdyHeaderBlock& response_headers) {
  DCHECK_NE(INITIALIZED, state_);
  if (state_ == CONNECTING) {
    // TODO(toyoshim): Fallback to non-spdy connection?
    handshake_response_->ParseResponseHeaderBlock(
        response_headers, challenge_, spdy_protocol_version_);
    SaveCookiesAndNotifyHeadersComplete();
  }
}
// Frame bytes were written on the SPDY stream. Unless the job is closed or
// the stream is gone, they are routed through the common OnSentData() path.
void WebSocketJob::OnSentSpdyData(size_t bytes_sent) {
  DCHECK_NE(INITIALIZED, state_);
  DCHECK_NE(CONNECTING, state_);
  if (state_ == CLOSED || !spdy_websocket_stream_.get())
    return;
  OnSentData(socket_.get(), static_cast<int>(bytes_sent));
}
// Frame bytes arrived on the SPDY stream. They are routed through the common
// OnReceivedData() path. A null |buffer| is forwarded as an empty chunk
// (NULL, 0).
void WebSocketJob::OnReceivedSpdyData(scoped_ptr<SpdyBuffer> buffer) {
  DCHECK_NE(INITIALIZED, state_);
  DCHECK_NE(CONNECTING, state_);
  if (state_ == CLOSED || !spdy_websocket_stream_.get())
    return;

  const char* payload = NULL;
  int payload_size = 0;
  if (buffer) {
    payload = buffer->GetRemainingData();
    payload_size = static_cast<int>(buffer->GetRemainingSize());
  }
  // |buffer| stays alive until this method returns, so |payload| is valid
  // for the duration of the call.
  OnReceivedData(socket_.get(), payload, payload_size);
}
// The SPDY stream has closed. The stream is dropped first, then the common
// close path (OnClose()) runs with the current socket stream.
void WebSocketJob::OnCloseSpdyStream() {
spdy_websocket_stream_.reset();
OnClose(socket_.get());
}
// Accepts the opening handshake request while CONNECTING.
// Returns false if a handshake has already started being sent, or if |data|
// does not parse as a request. Otherwise, starts cookie handling and sending
// and returns true.
bool WebSocketJob::SendHandshakeRequest(const char* data, int len) {
  DCHECK_EQ(state_, CONNECTING);
  const bool accepted = !started_to_send_handshake_request_ &&
                        handshake_request_->ParseRequest(data, len);
  if (!accepted)
    return false;
  AddCookieHeaderAndSend();
  return true;
}
// Replaces any Cookie/Cookie2 headers in the pending handshake request with
// the cookies the socket's cookie store holds for this URL, then sends the
// request. Cookies are only added if the delegate allows reading them and a
// cookie store exists. The lookup is asynchronous: LoadCookieCallback()
// sends the request. Without a cookie lookup, the request is sent right away.
// Does nothing unless a socket and delegate are attached and the job is
// CONNECTING.
void WebSocketJob::AddCookieHeaderAndSend() {
  // GetURLForCookies() dereferences |socket_|. The permission query therefore
  // must not run without a socket; previously it ran before the null check
  // below.
  if (!socket_.get() || !delegate_)
    return;

  const bool allow =
      delegate_->CanGetCookies(socket_.get(), GetURLForCookies());

  // Re-validated after calling out to the delegate, as before.
  if (!socket_.get() || !delegate_ || state_ != CONNECTING)
    return;

  handshake_request_->RemoveHeaders(kCookieHeaders, arraysize(kCookieHeaders));
  if (allow && socket_->cookie_store()) {
    // Add cookies, including HttpOnly cookies.
    CookieOptions cookie_options;
    cookie_options.set_include_httponly();
    socket_->cookie_store()->GetCookiesWithOptionsAsync(
        GetURLForCookies(), cookie_options,
        base::Bind(&WebSocketJob::LoadCookieCallback,
                   weak_ptr_factory_.GetWeakPtr()));
  } else {
    DoSendData();
  }
}
// Completion of the cookie lookup started by AddCookieHeaderAndSend(). A
// non-empty |cookie| string is added to the request with
// AppendHeaderIfMissing("Cookie", ...). The handshake request is then sent.
void WebSocketJob::LoadCookieCallback(const std::string& cookie) {
  const bool has_cookies = !cookie.empty();
  if (has_cookies) {
    // TODO(tyoshino): Sending cookie means that connection doesn't need
    // PRIVACY_MODE_ENABLED as cookies may be server-bound and channel id
    // wouldn't negatively affect privacy anyway. Need to restart connection
    // or refactor to determine cookie status prior to connecting.
    handshake_request_->AppendHeaderIfMissing("Cookie", cookie);
  }
  DoSendData();
}
void WebSocketJob::DoSendData() {
if (spdy_websocket_stream_.get()) {
scoped_ptr<SpdyHeaderBlock> headers(new SpdyHeaderBlock);
handshake_request_->GetRequestHeaderBlock(
socket_->url(), headers.get(), &challenge_, spdy_protocol_version_);
spdy_websocket_stream_->SendRequest(headers.Pass());
} else {
const std::string& handshake_request =
handshake_request_->GetRawRequest();
handshake_request_sent_ = 0;
socket_->net_log()->AddEvent(
NetLog::TYPE_WEB_SOCKET_SEND_REQUEST_HEADERS,
base::Bind(&NetLogWebSocketHandshakeCallback, &handshake_request));
socket_->SendData(handshake_request.data(),
handshake_request.size());
}
// Just buffered in |handshake_request_|.
started_to_send_handshake_request_ = true;
}
// Accounts for |amount_sent| bytes of the raw handshake request. Once the
// whole request has been written, the delegate is told the request's
// original_length(), which is presumably the size before cookie headers were
// rewritten.
void WebSocketJob::OnSentHandshakeRequest(
    SocketStream* socket, int amount_sent) {
  DCHECK_EQ(state_, CONNECTING);
  handshake_request_sent_ += amount_sent;
  DCHECK_LE(handshake_request_sent_, handshake_request_->raw_length());
  if (handshake_request_sent_ < handshake_request_->raw_length())
    return;  // More of the request is still in flight.

  // Reset the handshake_request_ first in case this object is deleted by the
  // delegate.
  const size_t original_length = handshake_request_->original_length();
  handshake_request_.reset();
  if (delegate_)
    delegate_->OnSentData(socket, original_length);
}
// Feeds bytes received while CONNECTING to the handshake-response parser.
// Bytes that follow a complete response are frame data. They are kept in
// |received_data_after_handshake_| so NotifyHeadersComplete() can deliver
// them together with the response. Once the response is complete, it is
// logged and cookie processing starts.
void WebSocketJob::OnReceivedHandshakeResponse(
    SocketStream* socket, const char* data, int len) {
  DCHECK_EQ(state_, CONNECTING);
  if (handshake_response_->HasResponse()) {
    // The handshake response is already complete (cookies are presumably
    // still being saved), so these bytes are frame data, not handshake data.
    received_data_after_handshake_.insert(
        received_data_after_handshake_.end(), data, data + len);
    return;
  }

  size_t response_length = handshake_response_->ParseRawResponse(data, len);
  if (!handshake_response_->HasResponse()) {
    // Not complete yet; wait for more data.
    return;
  }

  // The handshake message is complete.
  std::string raw_response = handshake_response_->GetRawResponse();
  socket_->net_log()->AddEvent(
      NetLog::TYPE_WEB_SOCKET_READ_RESPONSE_HEADERS,
      base::Bind(&NetLogWebSocketHandshakeCallback, &raw_response));

  // Keep any bytes beyond the response as frame data. |len| is an int and
  // |response_length| a size_t, so compare explicitly in size_t. The previous
  // `len - response_length > 0` relied on unsigned wrap-around. It would also
  // have accepted a bogus response_length > len and built an invalid range.
  const size_t data_length = static_cast<size_t>(len);
  DCHECK_LE(response_length, data_length);
  if (response_length < data_length) {
    DCHECK(received_data_after_handshake_.empty());
    received_data_after_handshake_.assign(data + response_length, data + len);
  }
  SaveCookiesAndNotifyHeadersComplete();
}
// Requires a complete handshake response. Collects its Set-Cookie and
// Set-Cookie2 header values into |response_cookies_| and starts persisting
// them with SaveNextCookie(), which calls NotifyHeadersComplete() once every
// cookie has been handled.
void WebSocketJob::SaveCookiesAndNotifyHeadersComplete() {
  DCHECK(handshake_response_->HasResponse());
  response_cookies_save_index_ = 0;
  response_cookies_.clear();
  handshake_response_->GetHeaders(
      kSetCookieHeaders, arraysize(kSetCookieHeaders), &response_cookies_);
  SaveNextCookie();
}
// Finishes the handshake. Set-Cookie headers are stripped from the response,
// and the job switches to OPEN. The delegate receives the response text
// followed by any frame bytes that arrived with it, as one chunk. Finally the
// job leaves the throttle queue.
void WebSocketJob::NotifyHeadersComplete() {
  // Remove cookie headers, with malformed headers preserved.
  // Actual handshake should be done in Blink.
  handshake_response_->RemoveHeaders(
      kSetCookieHeaders, arraysize(kSetCookieHeaders));
  const std::string response_text = handshake_response_->GetResponse();
  handshake_response_.reset();

  std::vector<char> payload;
  payload.reserve(response_text.size() +
                  received_data_after_handshake_.size());
  payload.insert(payload.end(), response_text.begin(), response_text.end());
  payload.insert(payload.end(),
                 received_data_after_handshake_.begin(),
                 received_data_after_handshake_.end());
  received_data_after_handshake_.clear();

  state_ = OPEN;
  DCHECK(!payload.empty());
  if (delegate_)
    delegate_->OnReceivedData(socket_.get(), &payload.front(), payload.size());
  WebSocketThrottle::GetInstance()->RemoveFromQueue(this);
}
// Persists the cookies in |response_cookies_| starting at
// |response_cookies_save_index_|. Cookies the delegate rejects via
// CanSetCookie() are skipped. SetCookieWithOptionsAsync() may complete
// synchronously or later. If it completes later, this method returns with
// |callback_pending_| set, and OnCookieSaved() calls it again to continue.
// Once every cookie has been handled (or if there is no cookie store), the
// list is cleared and NotifyHeadersComplete() runs. Does nothing if the
// socket or delegate is gone, or the job is no longer CONNECTING.
void WebSocketJob::SaveNextCookie() {
if (!socket_.get() || !delegate_ || state_ != CONNECTING)
return;
callback_pending_ = false;
// Lets OnCookieSaved() tell whether it was invoked synchronously from inside
// the loop below (then it must not re-enter this method).
save_next_cookie_running_ = true;
if (socket_->cookie_store()) {
GURL url_for_cookies = GetURLForCookies();
CookieOptions options;
options.set_include_httponly();
// Loop as long as SetCookieWithOptionsAsync completes synchronously. Since
// CookieMonster's asynchronous operation APIs queue the callback to run it
// on the thread where the API was called, there won't be race. I.e. unless
// the callback is run synchronously, it won't be run in parallel with this
// method.
while (!callback_pending_ &&
response_cookies_save_index_ < response_cookies_.size()) {
std::string cookie = response_cookies_[response_cookies_save_index_];
response_cookies_save_index_++;
if (!delegate_->CanSetCookie(
socket_.get(), url_for_cookies, cookie, &options))
continue;
// Cleared again by OnCookieSaved(); if that happens synchronously the loop
// continues, otherwise it exits and waits for the callback.
callback_pending_ = true;
socket_->cookie_store()->SetCookieWithOptionsAsync(
url_for_cookies, cookie, options,
base::Bind(&WebSocketJob::OnCookieSaved,
weak_ptr_factory_.GetWeakPtr()));
}
}
save_next_cookie_running_ = false;
// An asynchronous save is outstanding; OnCookieSaved() will resume.
if (callback_pending_)
return;
response_cookies_.clear();
response_cookies_save_index_ = 0;
NotifyHeadersComplete();
}
// Completion callback for SetCookieWithOptionsAsync(). |cookie_status| is not
// used; the next cookie is attempted whether or not this one was stored.
// Clearing |callback_pending_| tells SaveNextCookie() this save has finished:
// - If SaveNextCookie() is still running (a synchronous completion), its loop
//   simply continues with the next cookie.
// - Otherwise SaveNextCookie() already returned to wait for this callback,
//   so it is resumed here.
void WebSocketJob::OnCookieSaved(bool cookie_status) {
  callback_pending_ = false;
  if (!save_next_cookie_running_)
    SaveNextCookie();
}
// Returns the socket's URL with its scheme replaced for cookie lookups:
// "https" for secure sockets, "http" otherwise. Requires |socket_| to be
// non-NULL.
GURL WebSocketJob::GetURLForCookies() const {
  // |http_scheme| must outlive ReplaceComponents(); |replacements| only
  // points at its characters.
  const std::string http_scheme(socket_->is_secure() ? "https" : "http");
  url::Replacements<char> replacements;
  replacements.SetScheme(
      http_scheme.c_str(),
      url::Component(0, static_cast<int>(http_scheme.length())));
  return socket_->url().ReplaceComponents(replacements);
}
// Returns the address list copied from the socket in OnStartOpenConnection().
const AddressList& WebSocketJob::address_list() const {
return addresses_;
}
// Tries to carry this WebSocket connection over an existing SPDY session.
// Returns:
// - ERR_FAILED if the socket stream is already gone;
// - OK to continue on the plain socket stream. This covers: no transaction
//   factory or session, WebSocket-over-SPDY disabled, no reusable SPDY
//   session, a secure URL on a non-SSL session, or synchronous stream
//   creation failure;
// - ERR_PROTOCOL_SWITCHED if a SPDY stream was created synchronously (after
//   calling OnConnected());
// - ERR_IO_PENDING if stream creation finishes later, in
//   OnCreatedSpdyStream().
int WebSocketJob::TrySpdyStream() {
  if (!socket_.get())
    return ERR_FAILED;

  // Check if we have a SPDY session available.
  HttpTransactionFactory* factory =
      socket_->context()->http_transaction_factory();
  if (!factory)
    return OK;
  scoped_refptr<HttpNetworkSession> session = factory->GetSession();
  if (!session.get() || !session->params().enable_websocket_over_spdy)
    return OK;
  SpdySessionPool* spdy_pool = session->spdy_session_pool();
  PrivacyMode privacy_mode = socket_->privacy_mode();
  const SpdySessionKey key(HostPortPair::FromURL(socket_->url()),
                           socket_->proxy_server(), privacy_mode);
  base::WeakPtr<SpdySession> spdy_session =
      spdy_pool->FindAvailableSession(key, *socket_->net_log());
  if (!spdy_session)
    return OK;

  // Forbid wss downgrade to SPDY without SSL.
  // TODO(toyoshim): Does it realize the same policy with HTTP?
  SSLInfo ssl_info;
  // Out-parameter only (not read here). It is initialized so it never holds
  // an indeterminate value, whatever GetSSLInfo() does with it.
  bool was_npn_negotiated = false;
  NextProto protocol_negotiated = kProtoUnknown;
  bool use_ssl = spdy_session->GetSSLInfo(
      &ssl_info, &was_npn_negotiated, &protocol_negotiated);
  if (socket_->is_secure() && !use_ssl)
    return OK;

  // Create SpdyWebSocketStream.
  spdy_protocol_version_ = spdy_session->GetProtocolVersion();
  spdy_websocket_stream_.reset(new SpdyWebSocketStream(spdy_session, this));

  int result = spdy_websocket_stream_->InitializeStream(
      socket_->url(), MEDIUM, *socket_->net_log());
  if (result == OK) {
    OnConnected(socket_.get(), kMaxPendingSendAllowed);
    return ERR_PROTOCOL_SWITCHED;
  }
  if (result != ERR_IO_PENDING) {
    // Stream creation failed; fall back to the plain socket stream.
    spdy_websocket_stream_.reset();
    return OK;
  }
  return ERR_IO_PENDING;
}
// Marks the job as waiting, presumably called by WebSocketThrottle while
// queueing it; Wakeup() clears the flag.
void WebSocketJob::SetWaiting() {
waiting_ = true;
}
// Returns whether the job is currently waiting (set by SetWaiting(), cleared
// by Wakeup() and when the pending connect callback is dropped).
bool WebSocketJob::IsWaiting() const {
return waiting_;
}
// Ends a throttling wait. It clears |waiting_| and posts RetryPendingIO() to
// resume the connection started in OnStartOpenConnection(). Does nothing if
// the job is not waiting.
void WebSocketJob::Wakeup() {
  if (waiting_) {
    waiting_ = false;
    DCHECK(!callback_.is_null());
    base::MessageLoopForIO::current()->PostTask(
        FROM_HERE,
        base::Bind(&WebSocketJob::RetryPendingIO,
                   weak_ptr_factory_.GetWeakPtr()));
  }
}
// Resumes a throttled connection attempt: tries SPDY and completes the
// pending connect callback with the outcome.
void WebSocketJob::RetryPendingIO() {
  const int result = TrySpdyStream();
  if (result == ERR_IO_PENDING)
    return;  // OnCreatedSpdyStream() will call CompleteIO() later.
  CompleteIO(result);
}
// Runs the pending connect callback with |result| and drops the reference
// taken in OnStartOpenConnection(). Does nothing if no callback is pending.
void WebSocketJob::CompleteIO(int result) {
  // |callback_| may be null if OnClose() or DetachDelegate() was called.
  if (callback_.is_null())
    return;
  // Copy and clear |callback_| before running it.
  CompletionCallback callback = callback_;
  callback_.Reset();
  callback.Run(result);
  Release();  // Balanced with OnStartOpenConnection().
}
bool WebSocketJob::SendDataInternal(const char* data, int length) {
if (spdy_websocket_stream_.get())
return ERR_IO_PENDING == spdy_websocket_stream_->SendData(data, length);
if (socket_.get())
return socket_->SendData(data, length);
return false;
}
// Closes whichever transports exist: the SPDY stream (if any) and then the
// socket stream (if any). Does not change |state_|.
void WebSocketJob::CloseInternal() {
if (spdy_websocket_stream_.get())
spdy_websocket_stream_->Close();
if (socket_.get())
socket_->Close();
}
void WebSocketJob::SendPending() {
if (current_send_buffer_.get())
return;
// Current buffer has been sent. Try next if any.
if (send_buffer_queue_.empty()) {
// No more data to send.
if (state_ == CLOSING)
CloseInternal();
return;
}
scoped_refptr<IOBufferWithSize> next_buffer = send_buffer_queue_.front();
send_buffer_queue_.pop_front();
current_send_buffer_ =
new DrainableIOBuffer(next_buffer.get(), next_buffer->size());
SendDataInternal(current_send_buffer_->data(),
current_send_buffer_->BytesRemaining());
}
} // namespace net