/*
 * Copyright (C) 2016 The Android Open Source Project
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *  * Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *  * Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in
 *    the documentation and/or other materials provided with the
 *    distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
 * OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
 * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
 * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

#include "tcp.h"

#include <android-base/parseint.h>
#include <android-base/stringprintf.h>

namespace tcp {

// Protocol version this host advertises; it is formatted as two decimal digits after "FB" in
// the handshake and is the minimum version accepted from the peer.
constexpr int kProtocolVersion{1};
// Number of bytes exchanged in each direction during the handshake ("FB" + two digits).
constexpr size_t kHandshakeLength{4};
// Timeout value handed to the socket when waiting for the peer's handshake reply
// (milliseconds, going by the name).
constexpr int kHandshakeTimeoutMs{2000};

// Decodes the 8-byte big-endian message length stored at |buffer| into a 64-bit number.
// |buffer| must point to at least 8 readable bytes.
static uint64_t ExtractMessageLength(const void* buffer) {
    const uint8_t* bytes = reinterpret_cast<const uint8_t*>(buffer);
    uint64_t length = 0;
    // The most significant byte comes first, so make room for each new byte by shifting the
    // bytes accumulated so far up by 8 bits.
    for (int index = 0; index < 8; ++index) {
        length = (length << 8) | bytes[index];
    }
    return length;
}

// Serializes |length| into |buffer| as an 8-byte big-endian message length.
// |buffer| must have room for at least 8 bytes.
static void EncodeMessageLength(uint64_t length, void* buffer) {
    uint8_t* out = reinterpret_cast<uint8_t*>(buffer);
    // Fill from the last (least significant) byte backwards, consuming 8 bits per step.
    for (int index = 7; index >= 0; --index) {
        out[index] = static_cast<uint8_t>(length & 0xFF);
        length >>= 8;
    }
}

// Transport implementation that exchanges messages over a Socket. Each message on the wire is
// preceded by an 8-byte big-endian length (see Read() and Write()).
class TcpTransport : public Transport {
  public:
    // Factory function so we can return nullptr if initialization fails. Takes ownership of
    // |socket|; on failure |error| is filled with a description.
    static std::unique_ptr<TcpTransport> NewTransport(std::unique_ptr<Socket> socket,
                                                      std::string* error);

    ~TcpTransport() override = default;

    // Reads up to |length| bytes of the current message into |data|. Returns the number of
    // bytes read, or -1 on failure (in which case the socket is closed).
    ssize_t Read(void* data, size_t length) override;

    // Sends |length| bytes from |data| as one length-prefixed message. Returns |length| on
    // success, or -1 on failure (in which case the socket is closed).
    ssize_t Write(const void* data, size_t length) override;

    // Closes and releases the socket. Returns the socket's Close() result, or 0 if the socket
    // was already released.
    int Close() override;

  private:
    // Private so instances are only created through NewTransport(); explicit so a socket
    // pointer is never silently converted into a transport.
    explicit TcpTransport(std::unique_ptr<Socket> sock) : socket_(std::move(sock)) {}

    // Connects to the device and performs the initial handshake. Returns false and fills |error|
    // on failure.
    bool InitializeProtocol(std::string* error);

    // Underlying connection; reset to nullptr once Close() has run.
    std::unique_ptr<Socket> socket_;
    // Bytes of the current incoming message that Read() has not yet delivered. Zero means the
    // next Read() starts by reading a new length header.
    uint64_t message_bytes_left_ = 0;

    DISALLOW_COPY_AND_ASSIGN(TcpTransport);
};

std::unique_ptr<TcpTransport> TcpTransport::NewTransport(std::unique_ptr<Socket> socket,
                                                         std::string* error) {
    std::unique_ptr<TcpTransport> transport(new TcpTransport(std::move(socket)));

    if (!transport->InitializeProtocol(error)) {
        return nullptr;
    }

    return transport;
}

// These error strings are checked in tcp_test.cpp and should be kept in sync.
bool TcpTransport::InitializeProtocol(std::string* error) {
    std::string handshake_message(android::base::StringPrintf("FB%02d", kProtocolVersion));

    if (!socket_->Send(handshake_message.c_str(), kHandshakeLength)) {
        *error = android::base::StringPrintf("Failed to send initialization message (%s)",
                                             Socket::GetErrorMessage().c_str());
        return false;
    }

    char buffer[kHandshakeLength + 1];
    buffer[kHandshakeLength] = '\0';
    if (socket_->ReceiveAll(buffer, kHandshakeLength, kHandshakeTimeoutMs) != kHandshakeLength) {
        *error = android::base::StringPrintf(
                "No initialization message received (%s). Target may not support TCP fastboot",
                Socket::GetErrorMessage().c_str());
        return false;
    }

    if (memcmp(buffer, "FB", 2) != 0) {
        *error = "Unrecognized initialization message. Target may not support TCP fastboot";
        return false;
    }

    int version = 0;
    if (!android::base::ParseInt(buffer + 2, &version) || version < kProtocolVersion) {
        *error = android::base::StringPrintf("Unknown TCP protocol version %s (host version %02d)",
                                             buffer + 2, kProtocolVersion);
        return false;
    }

    error->clear();
    return true;
}

// Reads up to |length| bytes of the current message into |data|. If no message is in progress,
// the next 8-byte length header is consumed first. Returns the number of bytes read, or -1 if
// the socket is already closed or a receive fails (the socket is closed on failure).
ssize_t TcpTransport::Read(void* data, size_t length) {
    if (!socket_) {
        return -1;
    }

    // Unless we're mid-message, read the next 8-byte message length.
    if (message_bytes_left_ == 0) {
        char header[8];
        const bool got_header = socket_->ReceiveAll(header, 8, 0) == 8;
        if (!got_header) {
            Close();
            return -1;
        }
        message_bytes_left_ = ExtractMessageLength(header);
    }

    // Never read past the end of the current message, even if the caller asked for more.
    const size_t chunk = (length > message_bytes_left_) ? message_bytes_left_ : length;

    const ssize_t received = socket_->ReceiveAll(data, chunk, 0);
    if (received == -1) {
        Close();
        return received;
    }

    message_bytes_left_ -= received;
    return received;
}

// Sends |length| bytes from |data| as a single message, preceded by its 8-byte big-endian
// length. Returns |length| on success, or -1 if the socket is already closed or the send fails
// (the socket is closed on failure).
ssize_t TcpTransport::Write(const void* data, size_t length) {
    if (!socket_) {
        return -1;
    }

    char length_prefix[8];
    EncodeMessageLength(length, length_prefix);

    // Use multi-buffer writes for better performance: the header and the payload go out in one
    // Send() call.
    if (socket_->Send(std::vector<cutils_socket_buffer_t>{{length_prefix, 8}, {data, length}})) {
        return length;
    }

    Close();
    return -1;
}

// Closes the underlying socket and releases it. Returns the socket's Close() result, or 0 when
// the socket has already been released.
int TcpTransport::Close() {
    int status = 0;

    if (socket_ != nullptr) {
        status = socket_->Close();
        // Drop the socket so later Read()/Write()/Close() calls see it as closed.
        socket_.reset();
    }

    return status;
}

std::unique_ptr<Transport> Connect(const std::string& hostname, int port, std::string* error) {
    return internal::Connect(Socket::NewClient(Socket::Protocol::kTcp, hostname, port, error),
                             error);
}

namespace internal {

std::unique_ptr<Transport> Connect(std::unique_ptr<Socket> sock, std::string* error) {
    if (sock == nullptr) {
        // If Socket creation failed |error| is already set.
        return nullptr;
    }

    return TcpTransport::NewTransport(std::move(sock), error);
}

}  // namespace internal

}  // namespace tcp