普通文本  |  423行  |  13.48 KB

// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome_frame/test/test_server.h"

#include <windows.h>
#include <objbase.h>
#include <urlmon.h>

#include "base/bind.h"
#include "base/logging.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_piece.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "base/strings/utf_string_conversions.h"
#include "chrome_frame/test/chrome_frame_test_utils.h"
#include "net/base/winsock_init.h"
#include "net/http/http_util.h"
#include "net/socket/tcp_listen_socket.h"

namespace test_server {
// printf-style template for a minimal response header block. Arguments, in
// order: status text (%hs, narrow string), content type (%hs, narrow string)
// and body length (%i, must be an int). The template already ends with the
// blank line that separates the headers from the body.
const char kDefaultHeaderTemplate[] =
    "HTTP/1.1 %hs\r\n"
    "Connection: close\r\n"
    "Content-Type: %hs\r\n"
    "Content-Length: %i\r\n"
    "\r\n";
// Status-line texts used with kDefaultHeaderTemplate.
const char kStatusOk[] = "200 OK";
const char kStatusNotFound[] = "404 Not Found";
// Content type used when a response cannot determine its own.
const char kDefaultContentType[] = "text/html; charset=UTF-8";

// Parses the request line and header block of an HTTP request.
//
// |headers| is everything up to and including the CRLF that ends the last
// header line (see OnDataReceived). The first line is split on spaces into
// method_, path_ and version_; the remaining lines are stored in headers_ and
// scanned for a Content-Length value, which is stored in content_length_.
// A Content-Length that fails to parse or is negative is ignored, leaving
// content_length_ at its previous value. Otherwise a negative length could
// make the request wait forever for body data.
void Request::ParseHeaders(const std::string& headers) {
  DCHECK(method_.length() == 0);

  size_t pos = headers.find("\r\n");
  DCHECK(pos != std::string::npos);
  if (pos != std::string::npos) {
    headers_ = headers.substr(pos + 2);

    base::StringTokenizer tokenizer(
        headers.begin(), headers.begin() + pos, " ");
    std::string* parse[] = { &method_, &path_, &version_ };
    size_t field = 0;
    // Check the slot count first so we never advance the tokenizer needlessly.
    while (field < arraysize(parse) && tokenizer.GetNext()) {
      parse[field++]->assign(tokenizer.token_begin(),
                             tokenizer.token_end());
    }
  }

  // Check for content-length in case we're being sent some data.
  net::HttpUtil::HeadersIterator it(headers_.begin(), headers_.end(),
                                    "\r\n");
  while (it.GetNext()) {
    if (LowerCaseEqualsASCII(it.name(), "content-length")) {
      int int_content_length = 0;
      if (base::StringToInt(base::StringPiece(it.values_begin(),
                                              it.values_end()),
                            &int_content_length) &&
          int_content_length >= 0) {
        content_length_ = int_content_length;
      }
      break;
    }
  }
}

// Appends newly received bytes to content_. Until the request line has been
// parsed (method_ still empty), watches for the blank line ending the header
// block. When it arrives, the headers are parsed and chopped off the front of
// content_, leaving only body bytes behind.
void Request::OnDataReceived(const std::string& data) {
  content_.append(data);

  if (!method_.empty())
    return;

  const size_t header_end = content_.find("\r\n\r\n");
  if (header_end == std::string::npos)
    return;

  // Keep the CRLF that terminates the last header line for ParseHeaders, but
  // drop the full CRLFCRLF separator from the buffer.
  ParseHeaders(content_.substr(0, header_end + 2));
  content_.erase(0, header_end + 4);
}

// Out-of-line destructor; there is nothing to release explicitly.
ResponseForPath::~ResponseForPath() {}

// Out-of-line destructor; there is nothing to release explicitly.
SimpleResponse::~SimpleResponse() {}

bool FileResponse::GetContentType(std::string* content_type) const {
  size_t length = ContentLength();
  char buffer[4096];
  void* data = NULL;

  if (length) {
    // Create a copy of the first few bytes of the file.
    // If we try and use the mapped file directly, FindMimeFromData will crash
    // 'cause it cheats and temporarily tries to write to the buffer!
    length = std::min(arraysize(buffer), length);
    memcpy(buffer, file_->data(), length);
    data = buffer;
  }

  LPOLESTR mime_type = NULL;
  FindMimeFromData(NULL, file_path_.value().c_str(), data, length, NULL,
                   FMFD_DEFAULT, &mime_type, 0);
  if (mime_type) {
    *content_type = WideToASCII(mime_type);
    ::CoTaskMemFree(mime_type);
  }

  return content_type->length() > 0;
}

// Writes the whole mapped file to |socket|. The mapping is created lazily by
// ContentLength(), so it is expected to exist by the time this is called.
void FileResponse::WriteContents(net::StreamListenSocket* socket) const {
  DCHECK(file_.get());
  if (!file_.get())
    return;

  const char* bytes = reinterpret_cast<const char*>(file_->data());
  socket->Send(bytes, file_->length(), false);
}

// Returns the size in bytes of the file at file_path_. The file is mapped
// into memory on first use. file_ is mutable, so this const method can cache
// the mapping. Returns 0 if the file cannot be mapped.
size_t FileResponse::ContentLength() const {
  if (!file_.get()) {
    file_.reset(new base::MemoryMappedFile());
    if (!file_->Initialize(file_path_)) {
      NOTREACHED();
      file_.reset();
      return 0;
    }
  }
  return file_->length();
}

bool RedirectResponse::GetCustomHeaders(std::string* headers) const {
  *headers = base::StringPrintf("HTTP/1.1 302 Found\r\n"
                                "Connection: close\r\n"
                                "Content-Length: 0\r\n"
                                "Content-Type: text/html\r\n"
                                "Location: %hs\r\n\r\n",
                                redirect_url_.c_str());
  return true;
}

// Listens on the address returned by chrome_frame_test::GetLocalIPv4Address()
// at |port|.
SimpleWebServer::SimpleWebServer(int port) {
  Construct(chrome_frame_test::GetLocalIPv4Address(), port);
}

// Listens on the caller-supplied |address| at |port|.
SimpleWebServer::SimpleWebServer(const std::string& address, int port) {
  Construct(address, port);
}

// Frees every Connection the server still owns. Registered responses are not
// deleted here; see DeleteAllResponses().
SimpleWebServer::~SimpleWebServer() {
  for (ConnectionList::iterator conn = connections_.begin();
       conn != connections_.end(); ++conn) {
    delete *conn;
  }
  connections_.clear();
}

// Shared constructor body. It requires a current MessageLoop and initializes
// Winsock. It registers the built-in quit response, remembers |address| as
// host_ and starts listening on |address|:|port|. A failure to create the
// listener is reported via LOG_IF(DFATAL).
void SimpleWebServer::Construct(const std::string& address, int port) {
  CHECK(base::MessageLoop::current())
      << "SimpleWebServer requires a message loop";
  net::EnsureWinsockInit();

  host_ = address;
  // quit_ is always the first registered response.
  AddResponse(&quit_);

  server_ = net::TCPListenSocket::CreateAndListen(address, port, this);
  LOG_IF(DFATAL, !server_.get())
      << "Failed to create listener socket at " << address << ":" << port;
}

// Registers |response| at the end of responses_. FindResponse() consults
// responses in registration order, so the first match wins.
void SimpleWebServer::AddResponse(Response* response) {
  responses_.insert(responses_.end(), response);
}

// Deletes every registered response except the built-in quit_ response,
// which is a member and is never heap-deleted. The deleted pointers are also
// removed from responses_. Leaving them in the list would let FindResponse()
// touch freed memory and make a second call double-delete them; quit_ stays
// registered.
void SimpleWebServer::DeleteAllResponses() {
  std::list<Response*>::iterator it = responses_.begin();
  while (it != responses_.end()) {
    if (*it != &quit_) {
      delete *it;
      it = responses_.erase(it);
    } else {
      ++it;
    }
  }
}

// Returns the first registered response whose Matches() accepts |request|,
// or NULL if none does.
Response* SimpleWebServer::FindResponse(const Request& request) const {
  for (std::list<Response*>::const_iterator candidate = responses_.begin();
       candidate != responses_.end(); ++candidate) {
    if ((*candidate)->Matches(request))
      return *candidate;
  }
  return NULL;
}

// Returns the tracked Connection wrapping |socket|, or NULL if none does.
Connection* SimpleWebServer::FindConnection(
    const net::StreamListenSocket* socket) const {
  for (ConnectionList::const_iterator entry = connections_.begin();
       entry != connections_.end(); ++entry) {
    Connection* candidate = *entry;
    if (candidate->IsSame(socket))
      return candidate;
  }
  return NULL;
}

// Wraps a newly accepted socket in a Connection owned by the server. Owned
// connections are deleted in DidClose() (when no response matched) or in the
// destructor.
void SimpleWebServer::DidAccept(
    net::StreamListenSocket* server,
    scoped_ptr<net::StreamListenSocket> connection) {
  Connection* accepted = new Connection(connection.Pass());
  connections_.push_back(accepted);
}

// Feeds received bytes into the connection's Request. Once the full request
// has arrived, the first matching response is sent. Its headers come from
// GetCustomHeaders() or from kDefaultHeaderTemplate. If nothing matches, a
// small 404 page is sent instead.
void SimpleWebServer::DidRead(net::StreamListenSocket* connection,
                              const char* data,
                              int len) {
  Connection* c = FindConnection(connection);
  DCHECK(c);
  // Don't dereference NULL in release builds for an untracked socket.
  if (!c)
    return;

  Request& request = c->request();
  request.OnDataReceived(std::string(data, len));
  if (!request.AllContentReceived())
    return;

  Response* response = FindResponse(request);
  if (response) {
    std::string headers;
    if (!response->GetCustomHeaders(&headers)) {
      std::string content_type;
      if (!response->GetContentType(&content_type))
        content_type = kDefaultContentType;
      // kDefaultHeaderTemplate formats the length with %i, so it must receive
      // an int. Passing a size_t through varargs is undefined where the widths
      // differ.
      headers = base::StringPrintf(
          kDefaultHeaderTemplate, kStatusOk, content_type.c_str(),
          static_cast<int>(response->ContentLength()));
    }

    connection->Send(headers, false);
    response->WriteContents(connection);
    response->IncrementAccessCounter();
  } else {
    std::string payload = "sorry, I can't find " + request.path();
    std::string headers(base::StringPrintf(kDefaultHeaderTemplate,
                                           kStatusNotFound,
                                           kDefaultContentType,
                                           static_cast<int>(payload.length())));
    connection->Send(headers, false);
    connection->Send(payload, false);
  }
}

// Notifies the Connection that its socket closed. To keep the historical list
// of connections reasonably tidy, connections whose request matched no
// response (404s) are removed and deleted here. Matched connections stay in
// the list.
void SimpleWebServer::DidClose(net::StreamListenSocket* sock) {
  Connection* c = FindConnection(sock);
  DCHECK(c);
  // Don't dereference NULL in release builds for an untracked socket.
  if (!c)
    return;

  c->OnSocketClosed();
  if (!FindResponse(c->request())) {
    // Linear search, but 404 connections are rare in these tests.
    connections_.erase(std::find(connections_.begin(), connections_.end(), c));
    delete c;
  }
}

// Starts listening on |address|:|port|. |root_dir| is stored in root_dir_.
// A failure to create the listener is logged, mirroring
// SimpleWebServer::Construct. Previously such a failure was silent and tests
// would simply never receive connections.
HTTPTestServer::HTTPTestServer(int port, const std::wstring& address,
                               base::FilePath root_dir)
    : port_(port), address_(address), root_dir_(root_dir) {
  net::EnsureWinsockInit();
  const std::string utf8_address = WideToUTF8(address);
  server_ = net::TCPListenSocket::CreateAndListen(utf8_address, port, this);
  LOG_IF(ERROR, !server_.get())
      << "Failed to create listener socket at " << utf8_address << ":"
      << port;
}

// Out-of-line destructor; members clean up after themselves.
HTTPTestServer::~HTTPTestServer() {}

// Returns an iterator to the connection wrapping |socket|, or
// connection_list_.end() if there is none. As a side effect, it erases any
// connection it passes whose socket_ has already been released. Close()
// resets socket_ once all data has been sent.
std::list<scoped_refptr<ConfigurableConnection>>::iterator
HTTPTestServer::FindConnection(const net::StreamListenSocket* socket) {
  ConnectionList::iterator cursor = connection_list_.begin();
  while (cursor != connection_list_.end()) {
    ConfigurableConnection* candidate = cursor->get();
    if (candidate->socket_ == NULL) {
      cursor = connection_list_.erase(cursor);
    } else if (candidate->socket_ == socket) {
      break;
    } else {
      ++cursor;
    }
  }
  return cursor;
}

// Returns a reference to the connection wrapping |socket|, or NULL if it is
// not (or no longer) tracked.
scoped_refptr<ConfigurableConnection> HTTPTestServer::ConnectionFromSocket(
    const net::StreamListenSocket* socket) {
  ConnectionList::iterator match = FindConnection(socket);
  if (match == connection_list_.end())
    return NULL;
  return *match;
}

// Wraps a newly accepted socket in a ConfigurableConnection. The connection
// list keeps a reference to it.
void HTTPTestServer::DidAccept(net::StreamListenSocket* server,
                               scoped_ptr<net::StreamListenSocket> socket) {
  scoped_refptr<ConfigurableConnection> accepted(
      new ConfigurableConnection(socket.Pass()));
  connection_list_.push_back(accepted);
}

// Accumulates request bytes for the connection on |socket|. Once the full
// request has arrived, dispatches it to Post() for a case-insensitive "post"
// method and to Get() otherwise. Data for untracked sockets is ignored.
void HTTPTestServer::DidRead(net::StreamListenSocket* socket,
                             const char* data,
                             int len) {
  scoped_refptr<ConfigurableConnection> connection =
      ConnectionFromSocket(socket);
  if (!connection)
    return;

  Request& request = connection->r_;
  request.OnDataReceived(std::string(data, len));
  if (!request.AllContentReceived())
    return;

  VLOG(1) << __FUNCTION__ << ": " << request.method() << " "
          << request.path();
  std::wstring path = UTF8ToWide(request.path());
  if (LowerCaseEqualsASCII(request.method(), "post"))
    this->Post(connection, path, request);
  else
    this->Get(connection, path, request);
}

// Drops the list's reference to the connection on |socket|, if it is still
// tracked.
void HTTPTestServer::DidClose(net::StreamListenSocket* socket) {
  ConnectionList::iterator match = FindConnection(socket);
  if (match == connection_list_.end())
    return;
  connection_list_.erase(match);
}

// Builds an absolute http:// URL for |path| on this server. A single leading
// '/' on |path| is dropped, and the ":port" part is omitted when port_ is 80.
std::wstring HTTPTestServer::Resolve(const std::wstring& path) {
  std::wstring relative(path);
  if (!path.empty() && path[0] == L'/')
    relative = path.substr(1);

  const bool default_port = (port_ == 80);
  if (relative.empty()) {
    return default_port
        ? base::StringPrintf(L"http://%ls", address_.c_str())
        : base::StringPrintf(L"http://%ls:%d", address_.c_str(), port_);
  }
  return default_port
      ? base::StringPrintf(L"http://%ls/%ls", address_.c_str(),
                           relative.c_str())
      : base::StringPrintf(L"http://%ls:%d/%ls", address_.c_str(), port_,
                           relative.c_str());
}

// Sends the next slice of data_, starting at cur_pos_ and at most
// options_.chunk_size_ bytes. If data remains, it re-posts itself after
// options_.timeout_ ms; otherwise it closes the connection.
void ConfigurableConnection::SendChunk() {
  // The socket may already have been released (e.g. by an earlier Close());
  // there is nothing left to send to.
  if (!socket_.get())
    return;

  const int size = static_cast<int>(data_.size());
  const char* chunk_ptr = data_.c_str() + cur_pos_;
  int bytes_to_send = size - cur_pos_;
  // A non-positive chunk size would send 0 bytes per task and re-post
  // forever without ever closing; treat it as "send the rest at once".
  if (options_.chunk_size_ > 0)
    bytes_to_send = std::min(options_.chunk_size_, bytes_to_send);

  socket_->Send(chunk_ptr, bytes_to_send);
  VLOG(1) << "Sent(" << cur_pos_ << "," << bytes_to_send << "): "
          << base::StringPiece(chunk_ptr, bytes_to_send);

  cur_pos_ += bytes_to_send;
  if (cur_pos_ < size) {
    base::MessageLoop::current()->PostDelayedTask(
        FROM_HERE, base::Bind(&ConfigurableConnection::SendChunk, this),
        base::TimeDelta::FromMilliseconds(options_.timeout_));
  } else {
    Close();
  }
}

void ConfigurableConnection::Close() {
  socket_.reset();
}

void ConfigurableConnection::Send(const std::string& headers,
                                  const std::string& content) {
  SendOptions options(SendOptions::IMMEDIATE, 0, 0);
  SendWithOptions(headers, content, options);
}

// Sends a response made of |headers| (header lines, each CRLF-terminated,
// without the final blank line) and |content|. A Content-Length header is
// added when |content| is non-empty and |headers| does not already contain
// one.
//
// Delivery depends on options.speed_:
//  - IMMEDIATE: everything is sent now and a Close() task is posted.
//  - IMMEDIATE_HEADERS_DELAYED_CONTENT: headers are sent now; the body is
//    queued in data_ and sent by SendChunk().
//  - DELAYED: headers and body are both queued in data_ and sent by
//    SendChunk().
void ConfigurableConnection::SendWithOptions(const std::string& headers,
                                             const std::string& content,
                                             const SendOptions& options) {
  std::string content_length_header;
  // Previously searched for "Context-Length:" (typo), so a caller-supplied
  // Content-Length header was always duplicated.
  if (!content.empty() &&
      std::string::npos == headers.find("Content-Length:")) {
    // Cast so the vararg matches %u.
    content_length_header = base::StringPrintf(
        "Content-Length: %u\r\n", static_cast<unsigned int>(content.size()));
  }

  // Save the options.
  options_ = options;

  if (options_.speed_ == SendOptions::IMMEDIATE) {
    socket_->Send(headers);
    // The |true| presumably appends the CRLF ending the header block (the
    // DELAYED path below appends it explicitly).
    socket_->Send(content_length_header, true);
    socket_->Send(content);
    // Post a task to close the socket since StreamListenSocket doesn't like
    // instances to go away from within its callbacks.
    base::MessageLoop::current()->PostTask(
        FROM_HERE, base::Bind(&ConfigurableConnection::Close, this));

    return;
  }

  if (options_.speed_ == SendOptions::IMMEDIATE_HEADERS_DELAYED_CONTENT) {
    socket_->Send(headers);
    socket_->Send(content_length_header, true);
    VLOG(1) << "Headers sent: " << headers << content_length_header;
    data_.append(content);
  }

  if (options_.speed_ == SendOptions::DELAYED) {
    data_ = headers;
    data_.append(content_length_header);
    data_.append("\r\n");
    // The body was previously never queued, so DELAYED responses advertised
    // a Content-Length but sent no content.
    data_.append(content);
  }

  base::MessageLoop::current()->PostDelayedTask(
      FROM_HERE, base::Bind(&ConfigurableConnection::SendChunk, this),
      base::TimeDelta::FromMilliseconds(options_.timeout_));
}

}  // namespace test_server