// Plain text | 308 lines | 9.83 KB (paste-site metadata; not part of the source)

// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "remoting/host/gnubby_auth_handler_posix.h"

#include <unistd.h>
#include <utility>

#include "base/bind.h"
#include "base/files/file_util.h"
#include "base/json/json_reader.h"
#include "base/json/json_writer.h"
#include "base/lazy_instance.h"
#include "base/stl_util.h"
#include "base/values.h"
#include "net/socket/unix_domain_listen_socket_posix.h"
#include "remoting/base/logging.h"
#include "remoting/host/gnubby_socket.h"
#include "remoting/proto/control.pb.h"
#include "remoting/protocol/client_stub.h"

namespace remoting {

namespace {

// Keys used in the JSON messages exchanged with the client.
const char kMessageType[] = "type";
const char kConnectionId[] = "connectionId";
const char kControlOption[] = "option";
const char kDataPayload[] = "data";

// Values carried under the "type" key.
const char kControlMessage[] = "control";
const char kDataMessage[] = "data";
const char kErrorMessage[] = "error";

// The only "option" value accepted in a control message.
const char kGnubbyAuthV1[] = "auth-v1";

// Type of the protocol::ExtensionMessage that carries gnubby traffic.
const char kGnubbyAuthMessage[] = "gnubby-auth";

// The name of the socket to listen for gnubby requests on. Assigned by
// GnubbyAuthHandler::SetGnubbySocketName(); while it is empty,
// CreateAuthorizationSocket() does not create a listening socket.
base::LazyInstance<base::FilePath>::Leaky g_gnubby_socket_name =
    LAZY_INSTANCE_INITIALIZER;

// STL predicate to match by a StreamListenSocket pointer.
class CompareSocket {
 public:
  explicit CompareSocket(net::StreamListenSocket* socket) : socket_(socket) {}

  bool operator()(const std::pair<int, GnubbySocket*> element) const {
    return element.second->IsSocket(socket_);
  }

 private:
  net::StreamListenSocket* socket_;
};

// Authentication callback for the gnubby listen socket: accepts a connection
// only when the peer's uid equals this process's uid, and logs refusals.
bool MatchUid(const net::UnixDomainServerSocket::Credentials& credentials) {
  if (credentials.user_id == getuid())
    return true;

  HOST_LOG << "Refused socket connection from uid " << credentials.user_id;
  return false;
}

// Returns the command code of a gnubby message: the first byte of |data| as a
// value in the range 0-255, or -1 if |data| is empty. Used for logging.
int GetCommandCode(const std::string& data) {
  if (data.empty())
    return -1;
  // Convert through unsigned char so bytes >= 0x80 are not sign-extended on
  // platforms where char is signed (0x83 would otherwise log as 4294967171).
  return static_cast<unsigned char>(data[0]);
}

// Creates a string of byte data from a ListValue of numbers. Returns true if
// all of the list elements are numbers.
bool ConvertListValueToString(base::ListValue* bytes, std::string* out) {
  out->clear();

  unsigned int byte_count = bytes->GetSize();
  if (byte_count != 0) {
    out->reserve(byte_count);
    for (unsigned int i = 0; i < byte_count; i++) {
      int value;
      if (!bytes->GetInteger(i, &value))
        return false;
      out->push_back(static_cast<char>(value));
    }
  }
  return true;
}

}  // namespace

GnubbyAuthHandlerPosix::GnubbyAuthHandlerPosix(
    protocol::ClientStub* client_stub)
    : client_stub_(client_stub),
      last_connection_id_(0) {
  // Every host->client message is sent through |client_stub_|.
  DCHECK(client_stub_);
}

GnubbyAuthHandlerPosix::~GnubbyAuthHandlerPosix() {
  // The map owns its GnubbySocket values (see DidClose and
  // SendErrorAndCloseActiveSocket); free each one, then empty the map.
  for (ActiveSockets::iterator it = active_sockets_.begin();
       it != active_sockets_.end(); ++it) {
    delete it->second;
  }
  active_sockets_.clear();
}

// static
// Factory for the POSIX implementation; the caller owns the returned handler.
scoped_ptr<GnubbyAuthHandler> GnubbyAuthHandler::Create(
    protocol::ClientStub* client_stub) {
  scoped_ptr<GnubbyAuthHandler> handler(
      new GnubbyAuthHandlerPosix(client_stub));
  return handler.Pass();
}

// static
// Records the socket path used by later CreateAuthorizationSocket() calls.
void GnubbyAuthHandler::SetGnubbySocketName(
    const base::FilePath& gnubby_socket_name) {
  base::FilePath& stored_name = g_gnubby_socket_name.Get();
  stored_name = gnubby_socket_name;
}

// Handles a gnubby-auth JSON message from the client. The message must be a
// JSON object with a "type" of:
//   "control" - with option "auth-v1": start listening on the gnubby socket.
//   "data"    - forward the "data" byte list to the connection "connectionId";
//               malformed data sends an error and closes that connection.
//   "error"   - send an error to connection "connectionId" and close it.
// Anything else is logged and ignored.
void GnubbyAuthHandlerPosix::DeliverClientMessage(const std::string& message) {
  DCHECK(CalledOnValidThread());

  scoped_ptr<base::Value> value(base::JSONReader::Read(message));
  base::DictionaryValue* client_message;
  // Report unparsable JSON or a non-object payload instead of dropping it
  // without a trace.
  if (!value || !value->GetAsDictionary(&client_message)) {
    LOG(ERROR) << "Invalid gnubby-auth message";
    return;
  }

  std::string type;
  if (!client_message->GetString(kMessageType, &type)) {
    LOG(ERROR) << "Invalid gnubby-auth message";
    return;
  }

  if (type == kControlMessage) {
    std::string option;
    if (client_message->GetString(kControlOption, &option) &&
        option == kGnubbyAuthV1) {
      CreateAuthorizationSocket();
    } else {
      LOG(ERROR) << "Invalid gnubby-auth control option";
    }
  } else if (type == kDataMessage) {
    ActiveSockets::iterator iter = GetSocketForMessage(client_message);
    if (iter == active_sockets_.end()) {
      LOG(ERROR) << "Unknown gnubby-auth data connection";
      return;
    }
    base::ListValue* bytes;
    std::string response;
    if (client_message->GetList(kDataPayload, &bytes) &&
        ConvertListValueToString(bytes, &response)) {
      HOST_LOG << "Sending gnubby response: " << GetCommandCode(response);
      iter->second->SendResponse(response);
    } else {
      LOG(ERROR) << "Invalid gnubby data";
      SendErrorAndCloseActiveSocket(iter);
    }
  } else if (type == kErrorMessage) {
    ActiveSockets::iterator iter = GetSocketForMessage(client_message);
    if (iter == active_sockets_.end()) {
      LOG(ERROR) << "Unknown gnubby-auth error connection";
      return;
    }
    HOST_LOG << "Sending gnubby error";
    SendErrorAndCloseActiveSocket(iter);
  } else {
    LOG(ERROR) << "Unknown gnubby-auth message type: " << type;
  }
}

void GnubbyAuthHandlerPosix::DeliverHostDataMessage(
    int connection_id,
    const std::string& data) const {
  DCHECK(CalledOnValidThread());

  base::DictionaryValue request;
  request.SetString(kMessageType, kDataMessage);
  request.SetInteger(kConnectionId, connection_id);

  base::ListValue* bytes = new base::ListValue();
  for (std::string::const_iterator i = data.begin(); i != data.end(); ++i) {
    bytes->AppendInteger(static_cast<unsigned char>(*i));
  }
  request.Set(kDataPayload, bytes);

  std::string request_json;
  if (!base::JSONWriter::Write(&request, &request_json)) {
    LOG(ERROR) << "Failed to create request json";
    return;
  }

  protocol::ExtensionMessage message;
  message.set_type(kGnubbyAuthMessage);
  message.set_data(request_json);

  client_stub_->DeliverHostMessage(message);
}

// Test helper: true if |socket| belongs to one of the active connections.
bool GnubbyAuthHandlerPosix::HasActiveSocketForTesting(
    net::StreamListenSocket* socket) const {
  ActiveSockets::const_iterator match = std::find_if(
      active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
  return match != active_sockets_.end();
}

// Test helper: returns the connection id of the active connection wrapping
// |socket|, or -1 if there is none. Ids are assigned from 1 upward in
// DidAccept(), so -1 never collides with a real id.
int GnubbyAuthHandlerPosix::GetConnectionIdForTesting(
    net::StreamListenSocket* socket) const {
  ActiveSockets::const_iterator iter = std::find_if(
      active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
  // Dereferencing end() is undefined behavior; report "not found" instead.
  if (iter == active_sockets_.end())
    return -1;
  return iter->first;
}

// Test helper: returns the GnubbySocket (still owned by this handler) that
// wraps |socket|, or NULL if |socket| is not an active connection.
GnubbySocket* GnubbyAuthHandlerPosix::GetGnubbySocketForTesting(
    net::StreamListenSocket* socket) const {
  ActiveSockets::const_iterator iter = std::find_if(
      active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
  // Dereferencing end() is undefined behavior; report "not found" instead.
  if (iter == active_sockets_.end())
    return NULL;
  return iter->second;
}

// Called when a local client connects to the gnubby socket. Assigns the next
// connection id and starts tracking the connection; when the GnubbySocket's
// timeout fires, RequestTimedOut() is invoked with that id.
void GnubbyAuthHandlerPosix::DidAccept(
    net::StreamListenSocket* server,
    scoped_ptr<net::StreamListenSocket> socket) {
  DCHECK(CalledOnValidThread());

  const int connection_id = ++last_connection_id_;
  GnubbySocket* gnubby_socket = new GnubbySocket(
      socket.Pass(),
      base::Bind(&GnubbyAuthHandlerPosix::RequestTimedOut,
                 base::Unretained(this), connection_id));
  active_sockets_[connection_id] = gnubby_socket;
}

// Accumulates request bytes for the connection owning |socket|. An oversized
// request gets an error reply and the connection is closed. A complete request
// is forwarded to the client.
void GnubbyAuthHandlerPosix::DidRead(net::StreamListenSocket* socket,
                                     const char* data,
                                     int len) {
  DCHECK(CalledOnValidThread());

  ActiveSockets::iterator iter = std::find_if(
      active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
  if (iter == active_sockets_.end()) {
    LOG(ERROR) << "Received data for unknown connection";
    return;
  }

  GnubbySocket* gnubby_socket = iter->second;
  gnubby_socket->AddRequestData(data, len);

  if (gnubby_socket->IsRequestTooLarge()) {
    SendErrorAndCloseActiveSocket(iter);
    return;
  }

  if (!gnubby_socket->IsRequestComplete())
    return;  // Wait for more data.

  std::string request_data;
  gnubby_socket->GetAndClearRequestData(&request_data);
  ProcessGnubbyRequest(iter->first, request_data);
}

// Stops tracking the connection owning |socket| and frees its GnubbySocket.
// Sockets that are not active connections are ignored.
void GnubbyAuthHandlerPosix::DidClose(net::StreamListenSocket* socket) {
  DCHECK(CalledOnValidThread());

  ActiveSockets::iterator iter = std::find_if(
      active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
  if (iter == active_sockets_.end())
    return;

  GnubbySocket* closed_socket = iter->second;
  delete closed_socket;
  active_sockets_.erase(iter);
}

void GnubbyAuthHandlerPosix::CreateAuthorizationSocket() {
  DCHECK(CalledOnValidThread());

  if (!g_gnubby_socket_name.Get().empty()) {
    // If the file already exists, a socket in use error is returned.
    base::DeleteFile(g_gnubby_socket_name.Get(), false);

    HOST_LOG << "Listening for gnubby requests on "
             << g_gnubby_socket_name.Get().value();

    auth_socket_ = net::deprecated::UnixDomainListenSocket::CreateAndListen(
        g_gnubby_socket_name.Get().value(), this, base::Bind(MatchUid));
    if (!auth_socket_.get()) {
      LOG(ERROR) << "Failed to open socket for gnubby requests";
    }
  } else {
    HOST_LOG << "No gnubby socket name specified";
  }
}

// Forwards a complete request from local connection |connection_id| to the
// client. Only the command byte is logged, not the request body.
void GnubbyAuthHandlerPosix::ProcessGnubbyRequest(
    int connection_id,
    const std::string& request_data) {
  HOST_LOG << "Received gnubby request: " << GetCommandCode(request_data);
  DeliverHostDataMessage(connection_id, request_data);
}

// Looks up the active connection named by the message's "connectionId" field.
// Returns active_sockets_.end() if the field is missing or not an integer, or
// if no such connection exists.
GnubbyAuthHandlerPosix::ActiveSockets::iterator
GnubbyAuthHandlerPosix::GetSocketForMessage(base::DictionaryValue* message) {
  int connection_id;
  if (!message->GetInteger(kConnectionId, &connection_id))
    return active_sockets_.end();
  return active_sockets_.find(connection_id);
}

// Sends an SSH error on the connection at |iter|, then frees its GnubbySocket
// and removes it from |active_sockets_|. |iter| is invalid afterwards.
void GnubbyAuthHandlerPosix::SendErrorAndCloseActiveSocket(
    const ActiveSockets::iterator& iter) {
  GnubbySocket* doomed_socket = iter->second;
  doomed_socket->SendSshError();

  delete doomed_socket;
  active_sockets_.erase(iter);
}

// Timeout callback bound in DidAccept(): sends an error on connection
// |connection_id| and closes it, if that connection is still active.
void GnubbyAuthHandlerPosix::RequestTimedOut(int connection_id) {
  HOST_LOG << "Gnubby request timed out";

  ActiveSockets::iterator iter = active_sockets_.find(connection_id);
  if (iter == active_sockets_.end())
    return;  // No such connection (e.g. it was already closed).
  SendErrorAndCloseActiveSocket(iter);
}

}  // namespace remoting