// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/websockets/websocket_handshake.h"

#include <algorithm>
#include <limits>
#include <vector>

#include "base/logging.h"
#include "base/md5.h"
#include "base/memory/ref_counted.h"
#include "base/rand_util.h"
#include "base/string_number_conversions.h"
#include "base/string_util.h"
#include "base/stringprintf.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_util.h"
namespace net {
// Default ports for ws:// and wss:// respectively. GetHostFieldValue() leaves
// the port out of the Host field when the URL uses the matching default.
const int WebSocketHandshake::kWebSocketPort(80);
const int WebSocketHandshake::kSecureWebSocketPort(443);
// Stores the connection parameters; the handshake starts out incomplete.
WebSocketHandshake::WebSocketHandshake(const GURL& url,
                                       const std::string& origin,
                                       const std::string& location,
                                       const std::string& protocol)
    : url_(url),
      origin_(origin),
      location_(location),
      protocol_(protocol),
      mode_(MODE_INCOMPLETE) {}
WebSocketHandshake::~WebSocketHandshake() {}
// True when the target URL uses the "wss" scheme.
bool WebSocketHandshake::is_secure() const {
  const bool secure = url_.SchemeIs("wss");
  return secure;
}
// Builds the client opening handshake: the request line, the header fields
// in random order, a blank line, and finally the 8-byte key3 body.
// The key material is generated on the first call and reused afterwards, so
// repeated calls advertise the same keys (and thus expect the same response).
std::string WebSocketHandshake::CreateClientHandshakeMessage() {
  if (!parameter_.get()) {
    parameter_.reset(new Parameter);
    parameter_->GenerateKeys();
  }

  // WebSocket protocol 4.1 Opening handshake.
  // The request line always comes first.
  std::string msg = "GET ";
  msg += GetResourceName();
  msg += " HTTP/1.1\r\n";

  std::vector<std::string> fields;
  fields.push_back("Upgrade: WebSocket");
  fields.push_back("Connection: Upgrade");
  fields.push_back("Host: " + GetHostFieldValue());
  fields.push_back("Origin: " + GetOriginFieldValue());
  if (!protocol_.empty())
    fields.push_back("Sec-WebSocket-Protocol: " + protocol_);
  // TODO(ukai): Add cookie if necessary.
  fields.push_back("Sec-WebSocket-Key1: " + parameter_->GetSecWebSocketKey1());
  fields.push_back("Sec-WebSocket-Key2: " + parameter_->GetSecWebSocketKey2());

  // The header fields are sent in random order. This is a Fisher-Yates
  // shuffle; std::random_shuffle was deprecated in C++14 and removed in
  // C++17. base::RandGenerator(n) returns a value in [0, n), which is the
  // same contract random_shuffle expected from its generator argument.
  for (size_t i = fields.size(); i > 1; --i) {
    const size_t j = static_cast<size_t>(base::RandGenerator(i));
    std::swap(fields[i - 1], fields[j]);
  }
  for (size_t i = 0; i < fields.size(); ++i) {
    msg += fields[i];
    msg += "\r\n";
  }
  msg += "\r\n";
  msg.append(parameter_->GetKey3());
  return msg;
}
int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) {
mode_ = MODE_INCOMPLETE;
int eoh = HttpUtil::LocateEndOfHeaders(data, len);
if (eoh < 0)
return -1;
scoped_refptr<HttpResponseHeaders> headers(
new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh)));
if (headers->response_code() != 101) {
mode_ = MODE_FAILED;
DVLOG(1) << "Bad response code: " << headers->response_code();
return eoh;
}
mode_ = MODE_NORMAL;
if (!ProcessHeaders(*headers) || !CheckResponseHeaders()) {
DVLOG(1) << "Process Headers failed: " << std::string(data, eoh);
mode_ = MODE_FAILED;
return eoh;
}
if (len < static_cast<size_t>(eoh + Parameter::kExpectedResponseSize)) {
mode_ = MODE_INCOMPLETE;
return -1;
}
uint8 expected[Parameter::kExpectedResponseSize];
parameter_->GetExpectedResponse(expected);
if (memcmp(&data[eoh], expected, Parameter::kExpectedResponseSize)) {
mode_ = MODE_FAILED;
return eoh + Parameter::kExpectedResponseSize;
}
mode_ = MODE_CONNECTED;
return eoh + Parameter::kExpectedResponseSize;
}
// Request target for the GET line: the URL path, plus "?query" when the URL
// carries a query component.
std::string WebSocketHandshake::GetResourceName() const {
  if (!url_.has_query())
    return url_.path();
  return url_.path() + "?" + url_.query();
}
std::string WebSocketHandshake::GetHostFieldValue() const {
// url_.host() is expected to be encoded in punnycode here.
std::string host = StringToLowerASCII(url_.host());
if (url_.has_port()) {
bool secure = is_secure();
int port = url_.EffectiveIntPort();
if ((!secure &&
port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) ||
(secure &&
port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) {
host += ":";
host += base::IntToString(port);
}
}
return host;
}
// Value for the Origin header: |origin_| converted to ASCII lowercase.
// Lowercasing is safe because, per
// http://tools.ietf.org/html/draft-abarth-origin-00, an Origin value carries
// no path or query portion.
//
// TODO(satorux): Should a default port (80 for http://, 443 for https://)
// be trimmed here, or can we assume the library's client already did so?
std::string WebSocketHandshake::GetOriginFieldValue() const {
  const std::string lowered = StringToLowerASCII(origin_);
  return lowered;
}
/* static */
// Looks up header |name| in |headers| and requires that it occurs exactly
// once. On success, stores its value in |*value| and returns true. Returns
// false (leaving |*value| untouched) when the header is absent or repeated.
bool WebSocketHandshake::GetSingleHeader(const HttpResponseHeaders& headers,
                                         const std::string& name,
                                         std::string* value) {
  void* iter = NULL;
  std::string found;
  if (!headers.EnumerateHeader(&iter, name, &found))
    return false;  // Not present at all.

  // Continue the same enumeration: a second hit means a duplicate header.
  std::string duplicate;
  if (headers.EnumerateHeader(&iter, name, &duplicate))
    return false;

  *value = found;
  return true;
}
// Validates the Upgrade and Connection headers of the server response. It
// also captures the Sec-WebSocket-Origin and Sec-WebSocket-Location headers
// (and Sec-WebSocket-Protocol, when the client asked for a protocol) into
// the ws_* members. Every header read here must appear exactly once.
bool WebSocketHandshake::ProcessHeaders(const HttpResponseHeaders& headers) {
  std::string upgrade;
  if (!GetSingleHeader(headers, "upgrade", &upgrade) || upgrade != "WebSocket")
    return false;

  std::string connection;
  if (!GetSingleHeader(headers, "connection", &connection) ||
      !LowerCaseEqualsASCII(connection, "upgrade"))
    return false;

  if (!GetSingleHeader(headers, "sec-websocket-origin", &ws_origin_) ||
      !GetSingleHeader(headers, "sec-websocket-location", &ws_location_))
    return false;

  // If |protocol_| is not specified by client, we don't care if there's
  // protocol field or not as specified in the spec.
  if (protocol_.empty())
    return true;
  return GetSingleHeader(headers, "sec-websocket-protocol", &ws_protocol_);
}
// Compares the values captured by ProcessHeaders() with what the client
// sent. The echoed origin must equal |origin_| lowercased, the location must
// match exactly, and the protocol must match whenever one was requested.
bool WebSocketHandshake::CheckResponseHeaders() const {
  DCHECK(mode_ == MODE_NORMAL);
  return LowerCaseEqualsASCII(origin_, ws_origin_.c_str()) &&
         location_ == ws_location_ &&
         (protocol_.empty() || protocol_ == ws_protocol_);
}
namespace {

// Unsigned counterpart of base::RandInt(). Returns a uniformly distributed
// random value in the inclusive range [min, max].
// base::RandInt() cannot be used because |max| may exceed INT_MAX. It would
// then be negative as an int, and DCHECK(min <= max) would fail there.
uint32 RandUint32(uint32 min, uint32 max) {
  DCHECK(min <= max);
  // At most 2^32, so it always fits in 64 bits.
  const uint64 range = static_cast<uint64>(max) - min + 1;

  // Taking |number % range| directly over-represents small results whenever
  // 2^64 is not a multiple of |range|. Reject draws that land in the final
  // partial block so that every residue is equally likely. With
  // range <= 2^32, a rejection happens with probability < 2^-32.
  const uint64 max_acceptable_value =
      (std::numeric_limits<uint64>::max() / range) * range - 1;
  uint64 number;
  do {
    number = base::RandUint64();
  } while (number > max_acceptable_value);

  const uint32 result = min + static_cast<uint32>(number % range);
  DCHECK(result >= min && result <= max);
  return result;
}

}  // namespace
// Random source for all key generation in Parameter. It returns a value in
// [min, max], defaults to RandUint32, and can be swapped out through
// SetRandomNumberGenerator().
uint32 (*WebSocketHandshake::Parameter::rand_)(uint32 min, uint32 max) =
    &RandUint32;
namespace {

// Characters that may be sprinkled into Sec-WebSocket-Key1/2: U+0021-U+002F
// and U+003A-U+007E, i.e. printable ASCII minus space and the digits.
// The table is populated lazily by the Parameter constructor.
// It is file-private (anonymous namespace) so the name does not leak into
// namespace net with external linkage.
uint8 randomCharacterInSecWebSocketKey[(0x2F - 0x21 + 1) + (0x7E - 0x3A + 1)];

}  // namespace
// Zeroes the challenge numbers. On first construction, it also fills the
// shared table of filler characters used by GenerateSecWebSocketKey().
WebSocketHandshake::Parameter::Parameter()
    : number_1_(0), number_2_(0) {
  // A populated table never starts with NUL ('!' is its first entry).
  if (randomCharacterInSecWebSocketKey[0] != '\0')
    return;
  uint8* out = randomCharacterInSecWebSocketKey;
  for (int ch = 0x21; ch <= 0x7E; ++ch) {
    if (ch >= 0x30 && ch <= 0x39)
      continue;  // Digits would corrupt the encoded number; skip them.
    *out++ = static_cast<uint8>(ch);
  }
}
WebSocketHandshake::Parameter::~Parameter() {
}
// Produces key1/number1, then key2/number2, then key3. Every step draws from
// rand_, so this order fixes the output for a given generator sequence.
void WebSocketHandshake::Parameter::GenerateKeys() {
  GenerateSecWebSocketKey(&number_1_, &key_1_);
  GenerateSecWebSocketKey(&number_2_, &key_2_);
  GenerateKey3();
}
// Writes |number| into buf[0..3] as a 32-bit big-endian (most significant
// byte first) integer.
static void SetChallengeNumber(uint8* buf, uint32 number) {
  buf[0] = static_cast<uint8>((number >> 24) & 0xFF);
  buf[1] = static_cast<uint8>((number >> 16) & 0xFF);
  buf[2] = static_cast<uint8>((number >> 8) & 0xFF);
  buf[3] = static_cast<uint8>(number & 0xFF);
}
// Fills |expected| (kExpectedResponseSize bytes) with the MD5 digest of the
// challenge: number_1_ and number_2_ as big-endian 32-bit values, followed
// by the kKey3Size bytes of key_3_.
void WebSocketHandshake::Parameter::GetExpectedResponse(uint8 *expected) const {
  uint8 challenge[kExpectedResponseSize];
  SetChallengeNumber(challenge, number_1_);
  SetChallengeNumber(challenge + 4, number_2_);
  memcpy(challenge + 8, key_3_.data(), kKey3Size);

  MD5Digest digest;
  MD5Sum(challenge, kExpectedResponseSize, &digest);
  memcpy(expected, digest.a, kExpectedResponseSize);
}
/* static */
// Replaces the generator used by every later key generation. |rand| must
// return a value in the inclusive range [min, max].
void WebSocketHandshake::Parameter::SetRandomNumberGenerator(
    uint32 (*rand)(uint32 min, uint32 max)) {
  rand_ = rand;
}
// Generates one Sec-WebSocket-Key value and its hidden number.
// A random space count (1-12) and a number are chosen so that
// number * spaces fits in 32 bits. The product is written in decimal, 1-12
// filler characters are inserted at random positions, and then |spaces|
// space characters are inserted anywhere except the first or last position.
// The draws from rand_ occur in a fixed order, so a deterministic generator
// yields a deterministic key.
void WebSocketHandshake::Parameter::GenerateSecWebSocketKey(
    uint32* number, std::string* key) {
  const uint32 spaces = rand_(1, 12);
  const uint32 max_number = 4294967295U / spaces;
  *number = rand_(0, max_number);
  const uint32 product = *number * spaces;
  std::string encoded = base::StringPrintf("%u", product);

  const int filler_count = rand_(1, 12);
  for (int k = 0; k < filler_count; ++k) {
    const int at = rand_(0, encoded.length());
    const int pick = rand_(0, sizeof(randomCharacterInSecWebSocketKey) - 1);
    encoded.insert(at, 1, static_cast<char>(randomCharacterInSecWebSocketKey[pick]));
  }

  for (uint32 k = 0; k < spaces; ++k) {
    const int at = rand_(1, encoded.length() - 1);
    encoded.insert(at, 1, ' ');
  }
  *key = encoded;
}
void WebSocketHandshake::Parameter::GenerateKey3() {
key_3_.clear();
for (int i = 0; i < 8; i++) {
key_3_.append(1, rand_(0, 255));
}
}
} // namespace net