//
// Copyright (C) 2014 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

#include "shill/vpn/third_party_vpn_driver.h"

#include <algorithm>

#include <fcntl.h>
#include <unistd.h>

#include <base/posix/eintr_wrapper.h>
#include <base/strings/string_number_conversions.h>
#include <base/strings/string_split.h>
#if defined(__ANDROID__)
#include <dbus/service_constants.h>
#else
#include <chromeos/dbus/service_constants.h>
#endif  // __ANDROID__

#include "shill/connection.h"
#include "shill/control_interface.h"
#include "shill/device_info.h"
#include "shill/error.h"
#include "shill/file_io.h"
#include "shill/ipconfig.h"
#include "shill/logging.h"
#include "shill/manager.h"
#include "shill/property_accessor.h"
#include "shill/store_interface.h"
#include "shill/virtual_device.h"
#include "shill/vpn/vpn_service.h"

namespace shill {

namespace Logging {

// Logging scope used by this file; set to the VPN scope.
// NOTE(review): presumably picked up by the SLOG() macro used throughout this
// file -- confirm against shill/logging.h.
static auto kModuleLogScope = ScopeLogger::kVPN;
// Returns a fixed identifier string for ThirdPartyVpnDriver objects. The
// argument |v| is not inspected.
static std::string ObjectID(const ThirdPartyVpnDriver* v) {
  return "(third_party_vpn_driver)";
}

}  // namespace Logging

namespace {

// Largest MTU value SetParameters() accepts from the client (65535).
const int32_t kConstantMaxMtu = (1 << 16) - 1;
// Timeout handed to StartConnectTimeout() by Connect(): five minutes.
const int32_t kConnectTimeoutSeconds = 60 * 5;

// Renders |address| as a string of '0'/'1' characters, most significant bit of
// each byte first, and truncates the result to the address prefix length.
// Two addresses that produce the same string are treated as the same CIDR
// block by ProcessIPArrayCIDR().
std::string IPAddressFingerprint(const IPAddress& address) {
  // Binary expansion of every 4-bit value, indexed by that value.
  static const std::string kNibbleBits[] = {
      "0000", "0001", "0010", "0011", "0100", "0101", "0110", "0111",
      "1000", "1001", "1010", "1011", "1100", "1101", "1110", "1111"};

  std::string bits;
  const size_t byte_count = address.address().GetLength();
  const uint8_t* bytes = address.address().GetConstData();
  for (size_t index = 0; index != byte_count; ++index) {
    const uint8_t octet = bytes[index];
    bits.append(kNibbleBits[octet >> 4]);   // High nibble first.
    bits.append(kNibbleBits[octet & 0xf]);  // Then the low nibble.
  }
  return bits.substr(0, address.prefix());
}

}  // namespace

// Property table handed to the VPNDriver base class by the constructor.
// NOTE(review): the second field of each entry is presumably a flags value
// (0 meaning "no flags") -- confirm against VPNDriver::Property.
const VPNDriver::Property ThirdPartyVpnDriver::kProperties[] = {
  { kProviderHostProperty, 0 },
  { kProviderTypeProperty, 0 },
  { kExtensionNameProperty, 0 },
  { kConfigurationNameProperty, 0 }
};

// The driver instance that currently owns the tunnel, or nullptr when there is
// none. It is set in ClaimInterface() once the tun device has been opened and
// reset in Cleanup(); Connect() and ClaimInterface() CHECK that it is unset.
ThirdPartyVpnDriver* ThirdPartyVpnDriver::active_client_ = nullptr;

// Constructs a driver in the disconnected state: no tun descriptor is open
// (tun_fd_ == -1) and SetParameters() calls are not yet expected.
// |dispatcher|, |manager| and the kProperties table are forwarded to the
// VPNDriver base class; the remaining pointers are stored for later use.
ThirdPartyVpnDriver::ThirdPartyVpnDriver(ControlInterface* control,
                                         EventDispatcher* dispatcher,
                                         Metrics* metrics, Manager* manager,
                                         DeviceInfo* device_info)
    : VPNDriver(dispatcher, manager, kProperties, arraysize(kProperties)),
      control_(control),
      dispatcher_(dispatcher),
      metrics_(metrics),
      device_info_(device_info),
      tun_fd_(-1),
      parameters_expected_(false) {
  // All file operations on the tun descriptor go through this FileIO instance.
  file_io_ = FileIO::GetInstance();
}

// Releases any device, service, descriptor and active-client state still held
// by this driver, leaving the service (if any) in the idle state.
ThirdPartyVpnDriver::~ThirdPartyVpnDriver() {
  Cleanup(Service::kStateIdle, Service::kFailureUnknown,
          Service::kErrorDetailsNone);
}

// Registers the base-class properties with |store| and then registers
// kObjectPathSuffixProperty as a derived string whose writes are routed to
// SetExtensionId() and whose clears are routed to ClearExtensionId().
void ThirdPartyVpnDriver::InitPropertyStore(PropertyStore* store) {
  VPNDriver::InitPropertyStore(store);

  StringAccessor extension_id_accessor(
      new CustomWriteOnlyAccessor<ThirdPartyVpnDriver, std::string>(
          this, &ThirdPartyVpnDriver::SetExtensionId,
          &ThirdPartyVpnDriver::ClearExtensionId, nullptr));
  store->RegisterDerivedString(kObjectPathSuffixProperty,
                               extension_id_accessor);
}

// Loads the base-class state from |storage| and, if no adaptor has been
// created yet, also reads the object path suffix and creates the adaptor.
// Returns the result of VPNDriver::Load().
bool ThirdPartyVpnDriver::Load(StoreInterface* storage,
                               const std::string& storage_id) {
  const bool base_loaded = VPNDriver::Load(storage, storage_id);
  if (adaptor_interface_) {
    // The adaptor already exists, so the stored suffix is not re-read.
    return base_loaded;
  }

  storage->GetString(storage_id, kObjectPathSuffixProperty,
                     &object_path_suffix_);
  adaptor_interface_.reset(control_->CreateThirdPartyVpnAdaptor(this));
  return base_loaded;
}

// Saves the base-class state to |storage| and additionally persists the object
// path suffix under kObjectPathSuffixProperty. Returns the result of
// VPNDriver::Save(); the outcome of writing the suffix is not reported.
bool ThirdPartyVpnDriver::Save(StoreInterface* storage,
                               const std::string& storage_id,
                               bool save_credentials) {
  const bool base_saved =
      VPNDriver::Save(storage, storage_id, save_credentials);

  storage->SetString(storage_id, kObjectPathSuffixProperty,
                     object_path_suffix_);

  return base_saved;
}

// Clear handler for kObjectPathSuffixProperty. Clearing is never allowed, so
// this always reports Error::kNotSupported through |error|.
void ThirdPartyVpnDriver::ClearExtensionId(Error* error) {
  const std::string reason("Clearing extension id is not supported.");
  error->Populate(Error::kNotSupported, reason);
}

// Write handler for kObjectPathSuffixProperty. The extension id can be set
// only once: the first call stores |value| as the object path suffix, creates
// the adaptor and returns true. Any later call populates |error| with
// Error::kAlreadyExists and returns false.
bool ThirdPartyVpnDriver::SetExtensionId(const std::string& value,
                                         Error* error) {
  if (adaptor_interface_) {
    error->Populate(Error::kAlreadyExists, "Extension ID is set");
    return false;
  }

  object_path_suffix_ = value;
  adaptor_interface_.reset(control_->CreateThirdPartyVpnAdaptor(this));
  return true;
}

// Handles a connection state report from the VPN client.
//
// Only the active client may call this; any other caller gets "Unexpected
// call" appended to |error_message|. A kStateFailure report marks the service
// as failed and tears the connection down. A kStateOnline report is accepted
// but ignored. Every other state, or any report while there is no service,
// appends "Invalid argument" to |error_message|.
void ThirdPartyVpnDriver::UpdateConnectionState(
    Service::ConnectState connection_state, std::string* error_message) {
  if (active_client_ != this) {
    error_message->append("Unexpected call");
    return;
  }

  if (!service_) {
    error_message->append("Invalid argument");
    return;
  }

  if (connection_state == Service::kStateFailure) {
    service_->SetState(connection_state);
    Cleanup(Service::kStateFailure, Service::kFailureUnknown,
            Service::kErrorDetailsNone);
    return;
  }

  // We expect "failure" and "connected" messages from the client, but state is
  // only updated for "failure". The "connected" message (kStateOnline) is
  // silently accepted; anything else is rejected.
  if (connection_state != Service::kStateOnline) {
    error_message->append("Invalid argument");
  }
}

// Writes |ip_packet| to the tun device on behalf of the VPN client.
//
// Appends "Unexpected call" to |error_message| if this driver is not the
// active client, and "Device not open" if no tun descriptor is open. If the
// write does not report exactly ip_packet.size() bytes, appends "Partial
// write" and emits a kError platform message.
void ThirdPartyVpnDriver::SendPacket(const std::vector<uint8_t>& ip_packet,
                                     std::string* error_message) {
  if (active_client_ != this) {
    error_message->append("Unexpected call");
    return;
  }
  if (tun_fd_ < 0) {
    error_message->append("Device not open");
    return;
  }

  const auto bytes_written =
      file_io_->Write(tun_fd_, ip_packet.data(), ip_packet.size());
  if (bytes_written == static_cast<ssize_t>(ip_packet.size())) {
    return;
  }

  error_message->append("Partial write");
  adaptor_interface_->EmitPlatformMessage(
      static_cast<uint32_t>(PlatformMessage::kError));
}

// Copies the IPv4 address stored under |key| in |parameters| into |target|.
//
// If the key is absent, an error is appended to |error_message| only when
// |mandatory| is set. If the key is present but its value is not an IPv4
// address, an error is appended and |target| is left untouched.
void ThirdPartyVpnDriver::ProcessIp(
    const std::map<std::string, std::string>& parameters, const char* key,
    std::string* target, bool mandatory, std::string* error_message) {
  // TODO(kaliamoorthi): Add IPV6 support.
  const auto entry = parameters.find(key);
  if (entry == parameters.end()) {
    if (mandatory) {
      error_message->append(key).append(" is missing;");
    }
    return;
  }

  const std::string& candidate = entry->second;
  if (IPAddress(candidate).family() != IPAddress::kFamilyIPv4) {
    error_message->append(key).append(" is not a valid IP;");
    return;
  }
  *target = candidate;
}

// Parses the |delimiter|-separated list of IPv4 addresses stored under |key|
// in |parameters| and stores the valid entries in |target|.
//
// Each entry that is not an IPv4 address is dropped and reported in
// |warning_message|. If the key is present but no entry survives, an error is
// appended to |error_message| and |target| is left untouched. If the key is
// absent, an error is appended only when |mandatory| is set.
void ThirdPartyVpnDriver::ProcessIPArray(
    const std::map<std::string, std::string>& parameters, const char* key,
    char delimiter, std::vector<std::string>* target, bool mandatory,
    std::string* error_message, std::string* warning_message) {
  const auto entry = parameters.find(key);
  if (entry == parameters.end()) {
    if (mandatory) {
      error_message->append(key).append(" is missing;");
    }
    return;
  }

  const std::vector<std::string> candidates = base::SplitString(
      entry->second, std::string{delimiter}, base::TRIM_WHITESPACE,
      base::SPLIT_WANT_ALL);

  // Keep the valid IPv4 entries in their original order.
  std::vector<std::string> accepted;
  for (const std::string& candidate : candidates) {
    if (IPAddress(candidate).family() == IPAddress::kFamilyIPv4) {
      accepted.push_back(candidate);
    } else {
      warning_message->append(candidate + " for " + key + " is invalid;");
    }
  }

  if (accepted.empty()) {
    error_message->append(key).append(" has no valid values or is empty;");
    return;
  }
  target->swap(accepted);
}

void ThirdPartyVpnDriver::ProcessIPArrayCIDR(
    const std::map<std::string, std::string>& parameters, const char* key,
    char delimiter, std::vector<std::string>* target, bool mandatory,
    std::string* error_message, std::string* warning_message) {
  std::vector<std::string> string_array;
  IPAddress address(IPAddress::kFamilyIPv4);
  auto it = parameters.find(key);
  if (it != parameters.end()) {
    string_array = base::SplitString(
        parameters.at(key), std::string{delimiter}, base::TRIM_WHITESPACE,
        base::SPLIT_WANT_ALL);

    // Eliminate invalid IPs
    for (auto value = string_array.begin(); value != string_array.end();) {
      if (!address.SetAddressAndPrefixFromString(*value)) {
        warning_message->append(*value + " for " + key + " is invalid;");
        value = string_array.erase(value);
        continue;
      }
      const std::string cidr_key = IPAddressFingerprint(address);
      if (known_cidrs_.find(cidr_key) != known_cidrs_.end()) {
        warning_message->append("Duplicate entry for " + *value + " in " + key +
                                " found;");
        value = string_array.erase(value);
      } else {
        known_cidrs_.insert(cidr_key);
        ++value;
      }
    }

    if (!string_array.empty()) {
      target->swap(string_array);
    } else {
      error_message->append(key).append(" has no valid values or is empty;");
    }
  } else if (mandatory) {
    error_message->append(key).append(" is missing;");
  }
}

// Splits the |delimiter|-separated search domain list stored under |key| in
// |parameters| and stores the pieces in |target|. The individual domains are
// not validated.
//
// If the key is present but splitting yields nothing, an error is appended to
// |error_message| and |target| is left untouched. If the key is absent, an
// error is appended only when |mandatory| is set.
void ThirdPartyVpnDriver::ProcessSearchDomainArray(
    const std::map<std::string, std::string>& parameters, const char* key,
    char delimiter, std::vector<std::string>* target, bool mandatory,
    std::string* error_message) {
  const auto entry = parameters.find(key);
  if (entry == parameters.end()) {
    if (mandatory) {
      error_message->append(key).append(" is missing;");
    }
    return;
  }

  std::vector<std::string> domains = base::SplitString(
      entry->second, std::string{delimiter}, base::TRIM_WHITESPACE,
      base::SPLIT_WANT_ALL);
  if (domains.empty()) {
    error_message->append(key).append(" has no valid values or is empty;");
    return;
  }
  target->swap(domains);
}

// Parses the integer stored under |key| in |parameters| and stores it in
// |target| when it lies within [|min_value|, |max_value|].
//
// If the value does not parse or is out of range, an error is appended to
// |error_message| and |target| is left untouched. If the key is absent, an
// error is appended only when |mandatory| is set.
void ThirdPartyVpnDriver::ProcessInt32(
    const std::map<std::string, std::string>& parameters, const char* key,
    int32_t* target, int32_t min_value, int32_t max_value, bool mandatory,
    std::string* error_message) {
  const auto entry = parameters.find(key);
  if (entry == parameters.end()) {
    if (mandatory) {
      error_message->append(key).append(" is missing;");
    }
    return;
  }

  int32_t parsed = 0;
  const bool is_number = base::StringToInt(entry->second, &parsed);
  if (is_number && min_value <= parsed && parsed <= max_value) {
    *target = parsed;
    return;
  }
  error_message->append(key).append(" not in expected range;");
}

void ThirdPartyVpnDriver::SetParameters(
    const std::map<std::string, std::string>& parameters,
    std::string* error_message, std::string* warning_message) {
  // TODO(kaliamoorthi): Add IPV6 support.
  if (!parameters_expected_ || active_client_ != this) {
    error_message->append("Unexpected call");
    return;
  }

  ip_properties_ = IPConfig::Properties();
  ip_properties_.address_family = IPAddress::kFamilyIPv4;

  ProcessIp(parameters, kAddressParameterThirdPartyVpn, &ip_properties_.address,
            true, error_message);
  ProcessIp(parameters, kBroadcastAddressParameterThirdPartyVpn,
            &ip_properties_.broadcast_address, false, error_message);

  ip_properties_.gateway = ip_properties_.address;

  ProcessInt32(parameters, kSubnetPrefixParameterThirdPartyVpn,
               &ip_properties_.subnet_prefix, 0, 32, true, error_message);
  ProcessInt32(parameters, kMtuParameterThirdPartyVpn, &ip_properties_.mtu,
               IPConfig::kMinIPv4MTU, kConstantMaxMtu, false, error_message);

  ProcessSearchDomainArray(parameters, kDomainSearchParameterThirdPartyVpn,
                           kNonIPDelimiter, &ip_properties_.domain_search,
                           false, error_message);
  ProcessIPArray(parameters, kDnsServersParameterThirdPartyVpn, kIPDelimiter,
                 &ip_properties_.dns_servers, true, error_message,
                 warning_message);

  known_cidrs_.clear();

  ProcessIPArrayCIDR(parameters, kExclusionListParameterThirdPartyVpn,
                     kIPDelimiter, &ip_properties_.exclusion_list, true,
                     error_message, warning_message);
  if (!ip_properties_.exclusion_list.empty()) {
    // The first excluded IP is used to find the default gateway. The logic that
    // finds the default gateway does not work for default route "0.0.0.0/0".
    // Hence, this code ensures that the first IP is not default.
    IPAddress address(ip_properties_.address_family);
    address.SetAddressAndPrefixFromString(ip_properties_.exclusion_list[0]);
    if (address.IsDefault() && !address.prefix()) {
      if (ip_properties_.exclusion_list.size() > 1) {
        swap(ip_properties_.exclusion_list[0],
             ip_properties_.exclusion_list[1]);
      } else {
        // When there is only a single entry which is a default address, it can
        // be cleared since the default behavior is to not route any traffic to
        // the tunnel interface.
        ip_properties_.exclusion_list.clear();
      }
    }
  }

  std::vector<std::string> inclusion_list;
  ProcessIPArrayCIDR(parameters, kInclusionListParameterThirdPartyVpn,
                     kIPDelimiter, &inclusion_list, true, error_message,
                     warning_message);

  IPAddress ip_address(ip_properties_.address_family);
  IPConfig::Route route;
  route.gateway = ip_properties_.gateway;
  for (auto value = inclusion_list.begin(); value != inclusion_list.end();
       ++value) {
    ip_address.SetAddressAndPrefixFromString(*value);
    ip_address.IntoString(&route.host);
    IPAddress::GetAddressMaskFromPrefix(
        ip_address.family(), ip_address.prefix()).IntoString(&route.netmask);
    ip_properties_.routes.push_back(route);
  }

  if (error_message->empty()) {
    ip_properties_.user_traffic_only = true;
    ip_properties_.default_route = false;
    ip_properties_.blackhole_ipv6 = true;
    device_->SelectService(service_);
    device_->UpdateIPConfig(ip_properties_);
    device_->SetLooseRouting(true);
    StopConnectTimeout();
    parameters_expected_ = false;
  }
}

// Input callback for the tun descriptor: copies the |data->len| bytes at
// |data->buf| into a vector and forwards them to the adaptor as a received
// packet.
void ThirdPartyVpnDriver::OnInput(InputData* data) {
  // TODO(kaliamoorthi): This is not efficient, transfer the descriptor over to
  // chrome browser or use a pipe in between. Avoid using DBUS for packet
  // transfer.
  std::vector<uint8_t> ip_packet;
  ip_packet.assign(data->buf, data->buf + data->len);
  adaptor_interface_->EmitPacketReceived(ip_packet);
}

// Error callback for the tun descriptor: logs |error| and emits a kError
// platform message. CHECK-fails if this driver is not the active client.
void ThirdPartyVpnDriver::OnInputError(const std::string& error) {
  LOG(ERROR) << error;
  CHECK_EQ(active_client_, this);

  const uint32_t error_code = static_cast<uint32_t>(PlatformMessage::kError);
  adaptor_interface_->EmitPlatformMessage(error_code);
}

// Tears down all per-connection state and returns the driver to the
// disconnected state. Safe to call when nothing is connected.
//
// |state| is the state the service is left in. When it is kStateFailure the
// service instead receives |error_details| and |failure| through
// SetErrorDetails()/SetFailure(); otherwise |failure| and |error_details| are
// only logged. If this driver was the active client, a kDisconnected platform
// message is emitted and the active client is reset.
void ThirdPartyVpnDriver::Cleanup(Service::ConnectState state,
                                  Service::ConnectFailure failure,
                                  const std::string& error_details) {
  SLOG(this, 2) << __func__ << "(" << Service::ConnectStateToString(state)
                << ", " << error_details << ")";
  StopConnectTimeout();

  // Drop the virtual device first, remembering its index so the underlying
  // interface can be deleted after the device reference is released.
  int interface_index = -1;
  if (device_) {
    interface_index = device_->interface_index();
    device_->DropConnection();
    device_->SetEnabled(false);
    device_ = nullptr;
  }
  if (interface_index >= 0) {
    device_info_->DeleteInterface(interface_index);
  }
  tunnel_interface_.clear();

  if (service_) {
    if (state == Service::kStateFailure) {
      service_->SetErrorDetails(error_details);
      service_->SetFailure(failure);
    } else {
      service_->SetState(state);
    }
    service_ = nullptr;
  }

  // -1 is the "not open" sentinel used throughout this class and 0 is a valid
  // descriptor, so every non-negative value must be closed.
  if (tun_fd_ >= 0) {
    file_io_->Close(tun_fd_);
    tun_fd_ = -1;
  }
  io_handler_.reset();

  if (active_client_ == this) {
    adaptor_interface_->EmitPlatformMessage(
        static_cast<uint32_t>(PlatformMessage::kDisconnected));
    active_client_ = nullptr;
  }
  parameters_expected_ = false;
}

// Starts connecting |service|: arms the connect timeout, resets the IP
// properties, moves the service to kStateConfiguring and requests a tunnel
// interface. CHECK-fails if no adaptor exists or if a client is already
// active.
//
// On success the connection continues in ClaimInterface(). If the tunnel
// interface cannot be created, |error| is populated with
// Error::kInternalError and the connection is cleaned up as a failure.
void ThirdPartyVpnDriver::Connect(const VPNServiceRefPtr& service,
                                  Error* error) {
  SLOG(this, 2) << __func__;
  CHECK(adaptor_interface_);
  CHECK(!active_client_);

  StartConnectTimeout(kConnectTimeoutSeconds);
  ip_properties_ = IPConfig::Properties();
  service_ = service;
  service_->SetState(Service::kStateConfiguring);

  if (device_info_->CreateTunnelInterface(&tunnel_interface_)) {
    // Wait for the ClaimInterface callback to continue the connection process.
    return;
  }

  Error::PopulateAndLog(FROM_HERE, error, Error::kInternalError,
                        "Could not create tunnel interface.");
  Cleanup(Service::kStateFailure, Service::kFailureInternal,
          "Unable to create tun interface");
}

// Claims the tunnel interface |link_name| if it is the one requested by
// Connect(). Returns false, without side effects, for any other interface and
// true once the interface has been claimed, even if opening it then fails.
//
// On a claim, a VirtualDevice is created and enabled and the tun device is
// opened. If that succeeds, an input handler is installed on the descriptor,
// this driver becomes the active client, SetParameters() calls become expected
// and a kConnected platform message is emitted. If it fails, the connection is
// cleaned up as a failure. CHECK-fails if a client is already active or a
// device already exists.
bool ThirdPartyVpnDriver::ClaimInterface(const std::string& link_name,
                                         int interface_index) {
  if (link_name != tunnel_interface_) {
    return false;
  }
  CHECK(!active_client_);

  SLOG(this, 2) << "Claiming " << link_name << " for third party VPN tunnel";

  CHECK(!device_);
  device_ = new VirtualDevice(control_, dispatcher(), metrics_, manager(),
                              link_name, interface_index, Technology::kVPN);
  device_->SetEnabled(true);

  tun_fd_ = device_info_->OpenTunnelInterface(tunnel_interface_);
  if (tun_fd_ < 0) {
    Cleanup(Service::kStateFailure, Service::kFailureInternal,
            "Unable to open tun interface");
    return true;
  }

  auto on_input =
      base::Bind(&ThirdPartyVpnDriver::OnInput, base::Unretained(this));
  auto on_input_error =
      base::Bind(&ThirdPartyVpnDriver::OnInputError, base::Unretained(this));
  io_handler_.reset(
      dispatcher_->CreateInputHandler(tun_fd_, on_input, on_input_error));

  active_client_ = this;
  parameters_expected_ = true;
  adaptor_interface_->EmitPlatformMessage(
      static_cast<uint32_t>(PlatformMessage::kConnected));
  return true;
}

// Disconnects the VPN if this driver is the active client, leaving the service
// in the idle state; otherwise does nothing. CHECK-fails if no adaptor exists.
void ThirdPartyVpnDriver::Disconnect() {
  SLOG(this, 2) << __func__;
  CHECK(adaptor_interface_);
  if (active_client_ != this) {
    return;
  }
  Cleanup(Service::kStateIdle, Service::kFailureUnknown,
          Service::kErrorDetailsNone);
}

// Returns the provider type string for third party VPNs.
std::string ThirdPartyVpnDriver::GetProviderType() const {
  std::string provider_type(kProviderThirdPartyVpn);
  return provider_type;
}

// Tears the VPN connection down as an internal failure, recording that the
// underlying network was disconnected as the error details.
void ThirdPartyVpnDriver::OnConnectionDisconnected() {
  const std::string error_details("Underlying network disconnected.");
  Cleanup(Service::kStateFailure, Service::kFailureInternal, error_details);
}

// Called when the connect timeout expires: runs the base-class handling, emits
// a kError platform message and tears the connection down as a connect
// failure.
void ThirdPartyVpnDriver::OnConnectTimeout() {
  SLOG(this, 2) << __func__;
  VPNDriver::OnConnectTimeout();

  const uint32_t error_code = static_cast<uint32_t>(PlatformMessage::kError);
  adaptor_interface_->EmitPlatformMessage(error_code);

  const std::string error_details("Connection timed out");
  Cleanup(Service::kStateFailure, Service::kFailureConnect, error_details);
}

}  // namespace shill