// Copyright (c) 2012 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "chrome_frame/test/test_server.h" #include <windows.h> #include <objbase.h> #include <urlmon.h> #include "base/bind.h" #include "base/logging.h" #include "base/strings/string_number_conversions.h" #include "base/strings/string_piece.h" #include "base/strings/string_util.h" #include "base/strings/stringprintf.h" #include "base/strings/utf_string_conversions.h" #include "chrome_frame/test/chrome_frame_test_utils.h" #include "net/base/winsock_init.h" #include "net/http/http_util.h" #include "net/socket/tcp_listen_socket.h" namespace test_server { const char kDefaultHeaderTemplate[] = "HTTP/1.1 %hs\r\n" "Connection: close\r\n" "Content-Type: %hs\r\n" "Content-Length: %i\r\n\r\n"; const char kStatusOk[] = "200 OK"; const char kStatusNotFound[] = "404 Not Found"; const char kDefaultContentType[] = "text/html; charset=UTF-8"; void Request::ParseHeaders(const std::string& headers) { DCHECK(method_.length() == 0); size_t pos = headers.find("\r\n"); DCHECK(pos != std::string::npos); if (pos != std::string::npos) { headers_ = headers.substr(pos + 2); base::StringTokenizer tokenizer( headers.begin(), headers.begin() + pos, " "); std::string* parse[] = { &method_, &path_, &version_ }; int field = 0; while (tokenizer.GetNext() && field < arraysize(parse)) { parse[field++]->assign(tokenizer.token_begin(), tokenizer.token_end()); } } // Check for content-length in case we're being sent some data. net::HttpUtil::HeadersIterator it(headers_.begin(), headers_.end(), "\r\n"); while (it.GetNext()) { if (LowerCaseEqualsASCII(it.name(), "content-length")) { int int_content_length; base::StringToInt(base::StringPiece(it.values_begin(), it.values_end()), &int_content_length); content_length_ = int_content_length; break; } } } void Request::OnDataReceived(const std::string& data) { content_ += data; if (method_.length() == 0) { size_t index = content_.find("\r\n\r\n"); if (index != std::string::npos) { // Parse the headers before returning and chop them of the // data buffer we've already received. std::string headers(content_.substr(0, index + 2)); ParseHeaders(headers); content_.erase(0, index + 4); } } } ResponseForPath::~ResponseForPath() { } SimpleResponse::~SimpleResponse() { } bool FileResponse::GetContentType(std::string* content_type) const { size_t length = ContentLength(); char buffer[4096]; void* data = NULL; if (length) { // Create a copy of the first few bytes of the file. // If we try and use the mapped file directly, FindMimeFromData will crash // 'cause it cheats and temporarily tries to write to the buffer! length = std::min(arraysize(buffer), length); memcpy(buffer, file_->data(), length); data = buffer; } LPOLESTR mime_type = NULL; FindMimeFromData(NULL, file_path_.value().c_str(), data, length, NULL, FMFD_DEFAULT, &mime_type, 0); if (mime_type) { *content_type = WideToASCII(mime_type); ::CoTaskMemFree(mime_type); } return content_type->length() > 0; } void FileResponse::WriteContents(net::StreamListenSocket* socket) const { DCHECK(file_.get()); if (file_.get()) { socket->Send(reinterpret_cast<const char*>(file_->data()), file_->length(), false); } } size_t FileResponse::ContentLength() const { if (file_.get() == NULL) { file_.reset(new base::MemoryMappedFile()); if (!file_->Initialize(file_path_)) { NOTREACHED(); file_.reset(); } } return file_.get() ? file_->length() : 0; } bool RedirectResponse::GetCustomHeaders(std::string* headers) const { *headers = base::StringPrintf("HTTP/1.1 302 Found\r\n" "Connection: close\r\n" "Content-Length: 0\r\n" "Content-Type: text/html\r\n" "Location: %hs\r\n\r\n", redirect_url_.c_str()); return true; } SimpleWebServer::SimpleWebServer(int port) { Construct(chrome_frame_test::GetLocalIPv4Address(), port); } SimpleWebServer::SimpleWebServer(const std::string& address, int port) { Construct(address, port); } SimpleWebServer::~SimpleWebServer() { ConnectionList::const_iterator it; for (it = connections_.begin(); it != connections_.end(); ++it) delete (*it); connections_.clear(); } void SimpleWebServer::Construct(const std::string& address, int port) { CHECK(base::MessageLoop::current()) << "SimpleWebServer requires a message loop"; net::EnsureWinsockInit(); AddResponse(&quit_); host_ = address; server_ = net::TCPListenSocket::CreateAndListen(address, port, this); LOG_IF(DFATAL, !server_.get()) << "Failed to create listener socket at " << address << ":" << port; } void SimpleWebServer::AddResponse(Response* response) { responses_.push_back(response); } void SimpleWebServer::DeleteAllResponses() { std::list<Response*>::const_iterator it; for (it = responses_.begin(); it != responses_.end(); ++it) { if ((*it) != &quit_) delete (*it); } } Response* SimpleWebServer::FindResponse(const Request& request) const { std::list<Response*>::const_iterator it; for (it = responses_.begin(); it != responses_.end(); it++) { Response* response = (*it); if (response->Matches(request)) { return response; } } return NULL; } Connection* SimpleWebServer::FindConnection( const net::StreamListenSocket* socket) const { ConnectionList::const_iterator it; for (it = connections_.begin(); it != connections_.end(); it++) { if ((*it)->IsSame(socket)) { return (*it); } } return NULL; } void SimpleWebServer::DidAccept( net::StreamListenSocket* server, scoped_ptr<net::StreamListenSocket> connection) { connections_.push_back(new Connection(connection.Pass())); } void SimpleWebServer::DidRead(net::StreamListenSocket* connection, const char* data, int len) { Connection* c = FindConnection(connection); DCHECK(c); Request& r = c->request(); std::string str(data, len); r.OnDataReceived(str); if (r.AllContentReceived()) { const Request& request = c->request(); Response* response = FindResponse(request); if (response) { std::string headers; if (!response->GetCustomHeaders(&headers)) { std::string content_type; if (!response->GetContentType(&content_type)) content_type = kDefaultContentType; headers = base::StringPrintf(kDefaultHeaderTemplate, kStatusOk, content_type.c_str(), response->ContentLength()); } connection->Send(headers, false); response->WriteContents(connection); response->IncrementAccessCounter(); } else { std::string payload = "sorry, I can't find " + request.path(); std::string headers(base::StringPrintf(kDefaultHeaderTemplate, kStatusNotFound, kDefaultContentType, payload.length())); connection->Send(headers, false); connection->Send(payload, false); } } } void SimpleWebServer::DidClose(net::StreamListenSocket* sock) { // To keep the historical list of connections reasonably tidy, we delete // 404's when the connection ends. Connection* c = FindConnection(sock); DCHECK(c); c->OnSocketClosed(); if (!FindResponse(c->request())) { // extremely inefficient, but in one line and not that common... :) connections_.erase(std::find(connections_.begin(), connections_.end(), c)); delete c; } } HTTPTestServer::HTTPTestServer(int port, const std::wstring& address, base::FilePath root_dir) : port_(port), address_(address), root_dir_(root_dir) { net::EnsureWinsockInit(); server_ = net::TCPListenSocket::CreateAndListen(WideToUTF8(address), port, this); } HTTPTestServer::~HTTPTestServer() { } std::list<scoped_refptr<ConfigurableConnection>>::iterator HTTPTestServer::FindConnection(const net::StreamListenSocket* socket) { ConnectionList::iterator it; // Scan through the list searching for the desired socket. Along the way, // erase any connections for which the corresponding socket has already been // forgotten about as a result of all data having been sent. for (it = connection_list_.begin(); it != connection_list_.end(); ) { ConfigurableConnection* connection = it->get(); if (connection->socket_ == NULL) { connection_list_.erase(it++); continue; } if (connection->socket_ == socket) break; ++it; } return it; } scoped_refptr<ConfigurableConnection> HTTPTestServer::ConnectionFromSocket( const net::StreamListenSocket* socket) { ConnectionList::iterator it = FindConnection(socket); if (it != connection_list_.end()) return *it; return NULL; } void HTTPTestServer::DidAccept(net::StreamListenSocket* server, scoped_ptr<net::StreamListenSocket> socket) { connection_list_.push_back(new ConfigurableConnection(socket.Pass())); } void HTTPTestServer::DidRead(net::StreamListenSocket* socket, const char* data, int len) { scoped_refptr<ConfigurableConnection> connection = ConnectionFromSocket(socket); if (connection) { std::string str(data, len); connection->r_.OnDataReceived(str); if (connection->r_.AllContentReceived()) { VLOG(1) << __FUNCTION__ << ": " << connection->r_.method() << " " << connection->r_.path(); std::wstring path = UTF8ToWide(connection->r_.path()); if (LowerCaseEqualsASCII(connection->r_.method(), "post")) this->Post(connection, path, connection->r_); else this->Get(connection, path, connection->r_); } } } void HTTPTestServer::DidClose(net::StreamListenSocket* socket) { ConnectionList::iterator it = FindConnection(socket); if (it != connection_list_.end()) connection_list_.erase(it); } std::wstring HTTPTestServer::Resolve(const std::wstring& path) { // Remove the first '/' if needed. std::wstring stripped_path = path; if (path.size() && path[0] == L'/') stripped_path = path.substr(1); if (port_ == 80) { if (stripped_path.empty()) { return base::StringPrintf(L"http://%ls", address_.c_str()); } else { return base::StringPrintf(L"http://%ls/%ls", address_.c_str(), stripped_path.c_str()); } } else { if (stripped_path.empty()) { return base::StringPrintf(L"http://%ls:%d", address_.c_str(), port_); } else { return base::StringPrintf(L"http://%ls:%d/%ls", address_.c_str(), port_, stripped_path.c_str()); } } } void ConfigurableConnection::SendChunk() { int size = (int)data_.size(); const char* chunk_ptr = data_.c_str() + cur_pos_; int bytes_to_send = std::min(options_.chunk_size_, size - cur_pos_); socket_->Send(chunk_ptr, bytes_to_send); VLOG(1) << "Sent(" << cur_pos_ << "," << bytes_to_send << "): " << base::StringPiece(chunk_ptr, bytes_to_send); cur_pos_ += bytes_to_send; if (cur_pos_ < size) { base::MessageLoop::current()->PostDelayedTask( FROM_HERE, base::Bind(&ConfigurableConnection::SendChunk, this), base::TimeDelta::FromMilliseconds(options_.timeout_)); } else { Close(); } } void ConfigurableConnection::Close() { socket_.reset(); } void ConfigurableConnection::Send(const std::string& headers, const std::string& content) { SendOptions options(SendOptions::IMMEDIATE, 0, 0); SendWithOptions(headers, content, options); } void ConfigurableConnection::SendWithOptions(const std::string& headers, const std::string& content, const SendOptions& options) { std::string content_length_header; if (!content.empty() && std::string::npos == headers.find("Context-Length:")) { content_length_header = base::StringPrintf("Content-Length: %u\r\n", content.size()); } // Save the options. options_ = options; if (options_.speed_ == SendOptions::IMMEDIATE) { socket_->Send(headers); socket_->Send(content_length_header, true); socket_->Send(content); // Post a task to close the socket since StreamListenSocket doesn't like // instances to go away from within its callbacks. base::MessageLoop::current()->PostTask( FROM_HERE, base::Bind(&ConfigurableConnection::Close, this)); return; } if (options_.speed_ == SendOptions::IMMEDIATE_HEADERS_DELAYED_CONTENT) { socket_->Send(headers); socket_->Send(content_length_header, true); VLOG(1) << "Headers sent: " << headers << content_length_header; data_.append(content); } if (options_.speed_ == SendOptions::DELAYED) { data_ = headers; data_.append(content_length_header); data_.append("\r\n"); } base::MessageLoop::current()->PostDelayedTask( FROM_HERE, base::Bind(&ConfigurableConnection::SendChunk, this), base::TimeDelta::FromMilliseconds(options.timeout_)); } } // namespace test_server